소스 검색

update pull handler to use model.Name

Michael Yang 8 달 전
부모
커밋
6761aca1e1
4개의 변경된 파일43개의 추가작업 그리고 85개의 파일을 삭제
  1. 6 5
      server/download.go
  2. 28 65
      server/images.go
  3. 1 1
      server/model.go
  4. 8 14
      server/routes.go

+ 6 - 5
server/download.go

@@ -24,6 +24,7 @@ import (
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/format"
+	"github.com/ollama/ollama/types/model"
 )
 
 const maxRetries = 6
@@ -451,15 +452,16 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
 	}
 }
 
-type downloadOpts struct {
-	mp      ModelPath
+type downloadOptions struct {
+	name    model.Name
+	baseURL *url.URL
 	digest  string
 	regOpts *registryOptions
 	fn      func(api.ProgressResponse)
 }
 
 // downloadBlob downloads a blob from the registry and stores it in the blobs directory
-func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
+func downloadBlob(ctx context.Context, opts downloadOptions) (cacheHit bool, _ error) {
 	fp, err := GetBlobsPath(opts.digest)
 	if err != nil {
 		return false, err
@@ -484,8 +486,7 @@ func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ erro
 	data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
 	download := data.(*blobDownload)
 	if !ok {
-		requestURL := opts.mp.BaseURL()
-		requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
+		requestURL := opts.baseURL.JoinPath("blobs", opts.digest)
 		if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
 			blobDownloadManager.Delete(opts.digest)
 			return false, err

+ 28 - 65
server/images.go

@@ -797,8 +797,6 @@ func PruneDirectory(path string) error {
 }
 
 func PushModel(ctx context.Context, name model.Name, opts registryOptions, fn func(api.ProgressResponse)) error {
-	fn(api.ProgressResponse{Status: "retrieving manifest"})
-
 	m, err := ParseNamedManifest(name)
 	if err != nil {
 		return err
@@ -842,118 +840,83 @@ func PushModel(ctx context.Context, name model.Name, opts registryOptions, fn fu
 	return nil
 }
 
-func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
-	mp := ParseModelPath(name)
+func PullModel(ctx context.Context, name model.Name, opts *registryOptions, fn func(api.ProgressResponse)) error {
+	mm, _ := ParseNamedManifest(name)
 
-	// build deleteMap to prune unused layers
-	deleteMap := make(map[string]struct{})
-	manifest, _, err := GetManifest(mp)
-	if errors.Is(err, os.ErrNotExist) {
-		// noop
-	} else if err != nil && !errors.Is(err, os.ErrNotExist) {
-		return err
-	} else {
-		for _, l := range manifest.Layers {
-			deleteMap[l.Digest] = struct{}{}
-		}
-		if manifest.Config.Digest != "" {
-			deleteMap[manifest.Config.Digest] = struct{}{}
-		}
+	scheme := "https"
+	if opts.Insecure {
+		scheme = "http"
 	}
 
-	if mp.ProtocolScheme == "http" && !regOpts.Insecure {
-		return errors.New("insecure protocol http")
+	baseURL, err := url.Parse(fmt.Sprintf("%s://%s", scheme, path.Join(name.Host, "v2", name.Namespace, name.Model)))
+	if err != nil {
+		return err
 	}
 
 	fn(api.ProgressResponse{Status: "pulling manifest"})
-
-	manifest, err = pullModelManifest(ctx, mp, regOpts)
+	m, err := pullModelManifest(ctx, name, baseURL, opts)
 	if err != nil {
 		return fmt.Errorf("pull model manifest: %s", err)
 	}
 
-	var layers []Layer
-	layers = append(layers, manifest.Layers...)
-	if manifest.Config.Digest != "" {
-		layers = append(layers, manifest.Config)
-	}
+	layers := append(m.Layers, m.Config)
 
 	skipVerify := make(map[string]bool)
 	for _, layer := range layers {
-		cacheHit, err := downloadBlob(ctx, downloadOpts{
-			mp:      mp,
+		hit, err := downloadBlob(ctx, downloadOptions{
+			name:    name,
+			baseURL: baseURL,
 			digest:  layer.Digest,
-			regOpts: regOpts,
+			regOpts: opts,
 			fn:      fn,
 		})
 		if err != nil {
 			return err
 		}
-		skipVerify[layer.Digest] = cacheHit
-		delete(deleteMap, layer.Digest)
+
+		skipVerify[layer.Digest] = hit
 	}
-	delete(deleteMap, manifest.Config.Digest)
 
 	fn(api.ProgressResponse{Status: "verifying sha256 digest"})
 	for _, layer := range layers {
-		if skipVerify[layer.Digest] {
-			continue
-		}
-		if err := verifyBlob(layer.Digest); err != nil {
-			if errors.Is(err, errDigestMismatch) {
+		if !skipVerify[layer.Digest] {
+			if err := verifyBlob(layer.Digest); errors.Is(err, errDigestMismatch) {
 				// something went wrong, delete the blob
 				fp, err := GetBlobsPath(layer.Digest)
 				if err != nil {
 					return err
 				}
+
 				if err := os.Remove(fp); err != nil {
 					// log this, but return the original error
 					slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
 				}
+			} else if err != nil {
+				return err
 			}
-			return err
 		}
 	}
 
 	fn(api.ProgressResponse{Status: "writing manifest"})
-
-	manifestJSON, err := json.Marshal(manifest)
-	if err != nil {
+	if err := WriteManifest(name, m.Config, m.Layers); err != nil {
 		return err
 	}
 
-	fp, err := mp.GetManifestPath()
-	if err != nil {
-		return err
-	}
-	if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
-		return err
-	}
-
-	err = os.WriteFile(fp, manifestJSON, 0o644)
-	if err != nil {
-		slog.Info(fmt.Sprintf("couldn't write to %s", fp))
-		return err
-	}
-
-	if !envconfig.NoPrune() && len(deleteMap) > 0 {
-		fn(api.ProgressResponse{Status: "removing unused layers"})
-		if err := deleteUnusedLayers(deleteMap); err != nil {
-			fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't remove unused layers: %v", err)})
-		}
+	if !envconfig.NoPrune() && mm != nil {
+		fn(api.ProgressResponse{Status: "pruning old layers"})
+		_ = mm.RemoveLayers()
 	}
 
 	fn(api.ProgressResponse{Status: "success"})
-
 	return nil
 }
 
-func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
-	requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
+func pullModelManifest(ctx context.Context, name model.Name, baseURL *url.URL, opts *registryOptions) (*Manifest, error) {
+	requestURL := baseURL.JoinPath("manifests", name.Tag)
 
 	headers := make(http.Header)
 	headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
-	resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts)
+	resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
 	if err != nil {
 		return nil, err
 	}

+ 1 - 1
server/model.go

@@ -34,7 +34,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
 	m, err := ParseNamedManifest(name)
 	switch {
 	case errors.Is(err, os.ErrNotExist):
-		if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
+		if err := PullModel(ctx, name, &registryOptions{}, fn); err != nil {
 			return nil, err
 		}
 

+ 8 - 14
server/routes.go

@@ -464,24 +464,22 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 }
 
 func (s *Server) PullHandler(c *gin.Context) {
-	var req api.PullRequest
-	err := c.ShouldBindJSON(&req)
-	switch {
-	case errors.Is(err, io.EOF):
+	var r api.PullRequest
+	if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
 		return
-	case err != nil:
+	} else if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
-	name := model.ParseName(cmp.Or(req.Model, req.Name))
-	if !name.IsValid() {
+	n := model.ParseName(cmp.Or(r.Model, r.Name))
+	if !n.IsValid() {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
 		return
 	}
 
-	if err := checkNameExists(name); err != nil {
+	if err := checkNameExists(n); err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
@@ -493,19 +491,15 @@ func (s *Server) PullHandler(c *gin.Context) {
 			ch <- r
 		}
 
-		regOpts := &registryOptions{
-			Insecure: req.Insecure,
-		}
-
 		ctx, cancel := context.WithCancel(c.Request.Context())
 		defer cancel()
 
-		if err := PullModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
+		if err := PullModel(ctx, n, &registryOptions{Insecure: r.Insecure}, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
 
-	if req.Stream != nil && !*req.Stream {
+	if r.Stream != nil && !*r.Stream {
 		waitForStream(c, ch)
 		return
 	}