瀏覽代碼

weaving in

Blake Mizerany 1 年之前
父節點
當前提交
844217bcf1
共有 3 個文件被更改,包括 126 次插入、4 次刪除
  1. 23 4
      client/registry/registry.go
  2. 75 0
      server/cache.go
  3. 28 0
      server/routes.go

+ 23 - 4
client/registry/registry.go

@@ -31,9 +31,19 @@ type Client struct {
 	BaseURL string
 	BaseURL string
 
 
 	Logger *slog.Logger
 	Logger *slog.Logger
+
+	// NameFill is a string that is used to fill in the missing parts of
+	// a name when it is not fully qualified. It is used to make a name
+	// fully qualified before pushing or pulling it. The default is
+	// "registry.ollama.ai/library/_:latest".
+	//
+	// Most users can ignore this field. It is intended for use by
+	// clients that need to push or pull names to registries other than
+	// registry.ollama.ai, and for testing.
+	NameFill string
 }
 }
 
 
-func (c *Client) logger() *slog.Logger {
+func (c *Client) log() *slog.Logger {
 	return cmp.Or(c.Logger, slog.Default())
 	return cmp.Or(c.Logger, slog.Default())
 }
 }
 
 
@@ -92,12 +102,12 @@ type Cache interface {
 // layers that are not already in the cache. It returns an error if any part
 // layers that are not already in the cache. It returns an error if any part
 // of the process fails, specifically:
 // of the process fails, specifically:
 func (c *Client) Pull(ctx context.Context, cache Cache, name string) error {
 func (c *Client) Pull(ctx context.Context, cache Cache, name string) error {
-	mn := model.ParseName(name)
+	mn := parseNameFill(name, c.NameFill)
 	if !mn.IsFullyQualified() {
 	if !mn.IsFullyQualified() {
 		return fmt.Errorf("ollama: pull: invalid name: %s", name)
 		return fmt.Errorf("ollama: pull: invalid name: %s", name)
 	}
 	}
 
 
-	log := c.logger().With("name", name)
+	log := c.log().With("name", name)
 
 
 	pr, err := ollama.Do[*apitype.PullResponse](ctx, c.oclient(), "GET", "/v1/pull/"+name, nil)
 	pr, err := ollama.Do[*apitype.PullResponse](ctx, c.oclient(), "GET", "/v1/pull/"+name, nil)
 	if err != nil {
 	if err != nil {
@@ -211,6 +221,14 @@ func (nopSeeker) Seek(int64, int) (int64, error) {
 	return 0, nil
 	return 0, nil
 }
 }
 
 
+func parseNameFill(name, fill string) model.Name {
+	f := model.ParseNameBare(fill)
+	if !f.IsFullyQualified() {
+		panic(fmt.Errorf("invalid fill: %q", fill))
+	}
+	return model.Merge(model.ParseNameBare(name), f)
+}
+
 // Push pushes a manifest to the server and responds to the server's
 // Push pushes a manifest to the server and responds to the server's
 // requests for layer uploads, if any, and finally commits the manifest for
 // requests for layer uploads, if any, and finally commits the manifest for
 // name. It returns an error if any part of the process fails, specifically:
 // name. It returns an error if any part of the process fails, specifically:
@@ -218,7 +236,7 @@ func (nopSeeker) Seek(int64, int) (int64, error) {
 // If the server requests layers not found in the cache, ErrLayerNotFound is
 // If the server requests layers not found in the cache, ErrLayerNotFound is
 // returned.
 // returned.
 func (c *Client) Push(ctx context.Context, cache Cache, name string) error {
 func (c *Client) Push(ctx context.Context, cache Cache, name string) error {
-	mn := model.ParseName(name)
+	mn := parseNameFill(name, c.NameFill)
 	if !mn.IsFullyQualified() {
 	if !mn.IsFullyQualified() {
 		return fmt.Errorf("ollama: push: invalid name: %s", name)
 		return fmt.Errorf("ollama: push: invalid name: %s", name)
 	}
 	}
@@ -259,6 +277,7 @@ func (c *Client) Push(ctx context.Context, cache Cache, name string) error {
 			}
 			}
 			defer f.Close()
 			defer f.Close()
 
 
+			c.log().Info("pushing layer", "digest", need.Digest, "start", need.Start, "end", need.End)
 			cp, err := PushLayer(ctx, f, need.URL, need.Start, need.End)
 			cp, err := PushLayer(ctx, f, need.URL, need.Start, need.End)
 			if err != nil {
 			if err != nil {
 				return fmt.Errorf("PushLayer: %w: %s", err, need.Digest)
 				return fmt.Errorf("PushLayer: %w: %s", err, need.Digest)

+ 75 - 0
server/cache.go

@@ -0,0 +1,75 @@
+package server
+
+import (
+	"cmp"
+	"fmt"
+	"os"
+	"path/filepath"
+
+	"github.com/ollama/ollama/client/registry"
+	"github.com/ollama/ollama/types/model"
+)
+
// cache is a simple demo disk cache rooted at dir. It implements
// registry.Cache but does not validate anything — in particular, blob
// contents are never checked against their digests.
type cache struct {
	// dir is the cache root; blobs are stored under dir/blobs and
	// manifests under dir/manifests.
	dir string
}
+
+func defaultCache() registry.Cache {
+	homeDir, _ := os.UserHomeDir()
+	if homeDir == "" {
+		panic("could not determine home directory")
+	}
+	modelsDir := cmp.Or(
+		os.Getenv("OLLAMA_MODELS"),
+		filepath.Join(homeDir, ".ollama", "models"),
+	)
+	return &cache{modelsDir}
+}
+
// invalidDigest builds the error reported whenever a caller supplies a
// malformed digest string.
func invalidDigest(d string) error {
	return fmt.Errorf("invalid digest: %s", d)
}
+
+func (c *cache) OpenLayer(d model.Digest) (registry.ReadAtSeekCloser, error) {
+	return os.Open(c.LayerFile(d))
+}
+
+func (c *cache) LayerFile(d model.Digest) string {
+	return filepath.Join(c.dir, "blobs", d.String())
+}
+
+func (c *cache) PutLayerFile(d model.Digest, fromPath string) error {
+	if !d.IsValid() {
+		return invalidDigest(d.String())
+	}
+	bfile := c.LayerFile(d)
+	dir, _ := filepath.Split(bfile)
+	if err := os.MkdirAll(dir, 0755); err != nil {
+		return err
+	}
+	return os.Rename(fromPath, bfile)
+}
+
+func (c *cache) ManifestData(name model.Name) []byte {
+	if !name.IsFullyQualified() {
+		return nil
+	}
+	data, err := os.ReadFile(filepath.Join(c.dir, "manifests", name.Filepath()))
+	if err != nil {
+		return nil
+	}
+	return data
+}
+
+func (c *cache) SetManifestData(name model.Name, data []byte) error {
+	if !name.IsFullyQualified() {
+		return fmt.Errorf("invalid name: %s", name)
+	}
+	filep := filepath.Join(c.dir, "manifests", name.Filepath())
+	dir, _ := filepath.Split(filep)
+	if err := os.MkdirAll(dir, 0755); err != nil {
+		return err
+	}
+	return os.WriteFile(filep, data, 0644)
+}

+ 28 - 0
server/routes.go

@@ -17,6 +17,7 @@ import (
 	"path/filepath"
 	"path/filepath"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
+	"sync"
 	"syscall"
 	"syscall"
 	"time"
 	"time"
 
 
@@ -25,6 +26,7 @@ import (
 	"golang.org/x/exp/slices"
 	"golang.org/x/exp/slices"
 
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/client/registry"
 	"github.com/ollama/ollama/gpu"
 	"github.com/ollama/ollama/gpu"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/openai"
 	"github.com/ollama/ollama/openai"
@@ -33,6 +35,14 @@ import (
 	"github.com/ollama/ollama/version"
 	"github.com/ollama/ollama/version"
 )
 )
 
 
// experiments returns the experiment flags enabled via the
// comma-separated OLLAMA_EXPERIMENT environment variable, computed
// once on first use. An unset/empty variable yields no flags; the
// original code returned [""] in that case (strings.Split on "" is
// [""]), making the empty flag spuriously "enabled".
var experiments = sync.OnceValue(func() []string {
	env := os.Getenv("OLLAMA_EXPERIMENT")
	if env == "" {
		return nil
	}
	return strings.Split(env, ",")
})

// useExperimental reports whether the named experiment flag is enabled.
func useExperimental(flag string) bool {
	return slices.Contains(experiments(), flag)
}

// useExperiemntal is a misspelled alias kept so existing call sites
// keep compiling.
//
// Deprecated: use useExperimental.
func useExperiemntal(flag string) bool {
	return useExperimental(flag)
}
+
 var mode string = gin.DebugMode
 var mode string = gin.DebugMode
 
 
 type Server struct {
 type Server struct {
@@ -444,6 +454,24 @@ func (s *Server) PullModelHandler(c *gin.Context) {
 		return
 		return
 	}
 	}
 
 
+	if useExperiemntal("pull") {
+		rc := &registry.Client{
+			BaseURL: os.Getenv("OLLAMA_REGISTRY_BASE_URL"),
+		}
+		modelsDir, err := modelsDir()
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+		cache := &cache{dir: modelsDir}
+		// TODO(bmizerany): progress updates
+		if err := rc.Pull(c.Request.Context(), cache, model); err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+		return
+	}
+
 	ch := make(chan any)
 	ch := make(chan any)
 	go func() {
 	go func() {
 		defer close(ch)
 		defer close(ch)