Blake Mizerany 1 år sedan
förälder
incheckning
8afe873f17
1 ändrade filer med 33 tillägg och 12 borttagningar
  1. 33 12
      client/registry/registry.go

+ 33 - 12
client/registry/registry.go

@@ -17,6 +17,7 @@ import (
 
 	"github.com/ollama/ollama/client/ollama"
 	"github.com/ollama/ollama/client/registry/apitype"
+	"github.com/ollama/ollama/types/model"
 	"golang.org/x/exp/constraints"
 	"golang.org/x/sync/errgroup"
 )
@@ -54,20 +55,23 @@ type Cache interface {
 	//
 	// If the digest is invalid, or the layer does not exist, the empty
 	// string is returned.
-	LayerFile(digest string) string
+	LayerFile(model.Digest) string
 
 	// OpenLayer opens the layer file for the given model digest and
 	// returns it, or an if any. The caller is responsible for closing
 	// the returned file.
-	OpenLayer(digest string) (ReadAtSeekCloser, error)
+	OpenLayer(model.Digest) (ReadAtSeekCloser, error)
 
 	// PutLayerFile moves the layer file at fromPath to the cache for
 	// the given model digest. It is a hack intended to short circuit a
 	// file copy operation.
 	//
+	// The file returned is expected to exist for the lifetime of the
+	// cache.
+	//
 	// TODO(bmizerany): remove this; find a better way. Once we move
 	// this into a build package, we should be able to get rid of this.
-	PutLayerFile(digest, fromPath string) error
+	PutLayerFile(_ model.Digest, fromPath string) error
 
 	// SetManifestData sets the provided manifest data for the given
 	// model name. If the manifest data is empty, the manifest is
@@ -75,19 +79,24 @@ type Cache interface {
 	//
 	// It is an error to call SetManifestData with a name that is not
 	// complete.
-	SetManifestData(name string, data []byte) error
+	SetManifestData(model.Name, []byte) error
 
 	// ManifestData returns the manifest data for the given model name.
 	//
 	// If the name incomplete, or the manifest does not exist, the empty
 	// string is returned.
-	ManifestData(name string) []byte
+	ManifestData(name model.Name) []byte
 }
 
 // Pull pulls the manifest for name, and downloads any of its required
 // layers that are not already in the cache. It returns an error if any part
 // of the process fails, specifically:
 func (c *Client) Pull(ctx context.Context, cache Cache, name string) error {
+	mn := model.ParseName(name)
+	if !mn.IsFullyQualified() {
+		return fmt.Errorf("ollama: pull: invalid name: %s", name)
+	}
+
 	log := c.logger().With("name", name)
 
 	pr, err := ollama.Do[*apitype.PullResponse](ctx, c.oclient(), "GET", "/v1/pull/"+name, nil)
@@ -101,10 +110,14 @@ func (c *Client) Pull(ctx context.Context, cache Cache, name string) error {
 
 	// download required layers we do not already have
 	for _, l := range pr.Manifest.Layers {
-		if cache.LayerFile(l.Digest) != "" {
+		d, err := model.ParseDigest(l.Digest)
+		if err != nil {
+			return fmt.Errorf("ollama: reading manifest: %w: %s", err, l.Digest)
+		}
+		if cache.LayerFile(d) != "" {
 			continue
 		}
-		err := func() error {
+		err = func() error {
 			log := log.With("digest", l.Digest, "mediaType", l.MediaType, "size", l.Size)
 			log.Debug("starting download")
 
@@ -170,7 +183,7 @@ func (c *Client) Pull(ctx context.Context, cache Cache, name string) error {
 			}
 
 			tmpFile.Close() // release our hold on the file before moving it
-			return cache.PutLayerFile(l.Digest, tmpFile.Name())
+			return cache.PutLayerFile(d, tmpFile.Name())
 		}()
 		if err != nil {
 			return fmt.Errorf("ollama: pull: %w", err)
@@ -187,7 +200,7 @@ func (c *Client) Pull(ctx context.Context, cache Cache, name string) error {
 	}
 
 	// TODO(bmizerany): remove dep on model.Name
-	return cache.SetManifestData(name, data)
+	return cache.SetManifestData(mn, data)
 }
 
 type nopSeeker struct {
@@ -205,7 +218,11 @@ func (nopSeeker) Seek(int64, int) (int64, error) {
 // If the server requests layers not found in the cache, ErrLayerNotFound is
 // returned.
 func (c *Client) Push(ctx context.Context, cache Cache, name string) error {
-	manifest := cache.ManifestData(name)
+	mn := model.ParseName(name)
+	if !mn.IsFullyQualified() {
+		return fmt.Errorf("ollama: push: invalid name: %s", name)
+	}
+	manifest := cache.ManifestData(mn)
 	if len(manifest) == 0 {
 		return fmt.Errorf("manifest not found: %s", name)
 	}
@@ -232,7 +249,11 @@ func (c *Client) Push(ctx context.Context, cache Cache, name string) error {
 	var g errgroup.Group
 	for _, need := range pr.Needs {
 		g.Go(func() error {
-			f, err := cache.OpenLayer(need.Digest)
+			nd, err := model.ParseDigest(need.Digest)
+			if err != nil {
+				return fmt.Errorf("ParseDigest: %w: %s", err, need.Digest)
+			}
+			f, err := cache.OpenLayer(nd)
 			if err != nil {
 				return fmt.Errorf("OpenLayer: %w: %s", err, need.Digest)
 			}
@@ -266,7 +287,7 @@ func (c *Client) Push(ctx context.Context, cache Cache, name string) error {
 		}
 	}
 
-	return cache.SetManifestData(name, manifest)
+	return cache.SetManifestData(mn, manifest)
 }
 
 func PushLayer(ctx context.Context, body io.ReaderAt, url string, start, end int64) (*apitype.CompletePart, error) {