Przeglądaj źródła

update push to use model.Name

Michael Yang 11 miesięcy temu
rodzic
commit
3e24edd9ed
3 zmienionych plików z 49 dodań i 53 usunięć
  1. 17 18
      server/images.go
  2. 8 18
      server/routes.go
  3. 24 17
      server/upload.go

+ 17 - 18
server/images.go

@@ -16,6 +16,7 @@ import (
 	"net/http"
 	"net/url"
 	"os"
+	"path"
 	"path/filepath"
 	"runtime"
 	"slices"
@@ -795,45 +796,42 @@ func PruneDirectory(path string) error {
 	return nil
 }
 
-func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
-	mp := ParseModelPath(name)
+func PushModel(ctx context.Context, name model.Name, opts registryOptions, fn func(api.ProgressResponse)) error {
 	fn(api.ProgressResponse{Status: "retrieving manifest"})
 
-	if mp.ProtocolScheme == "http" && !regOpts.Insecure {
-		return errors.New("insecure protocol http")
-	}
-
-	manifest, _, err := GetManifest(mp)
+	m, err := ParseNamedManifest(name)
 	if err != nil {
-		fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
 		return err
 	}
 
-	var layers []Layer
-	layers = append(layers, manifest.Layers...)
-	if manifest.Config.Digest != "" {
-		layers = append(layers, manifest.Config)
+	scheme := "https"
+	if opts.Insecure {
+		scheme = "http"
 	}
 
-	for _, layer := range layers {
-		if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
+	baseURL, err := url.Parse(fmt.Sprintf("%s://%s", scheme, path.Join(name.Host, "v2", name.Namespace, name.Model)))
+	if err != nil {
+		return err
+	}
+
+	for _, layer := range append(m.Layers, m.Config) {
+		if err := uploadBlob(ctx, uploadOptions{name: name, baseURL: baseURL, layer: layer, regOpts: &opts, fn: fn}); err != nil {
 			slog.Info(fmt.Sprintf("error uploading blob: %v", err))
 			return err
 		}
 	}
 
 	fn(api.ProgressResponse{Status: "pushing manifest"})
-	requestURL := mp.BaseURL()
-	requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
+	requestURL := baseURL.JoinPath("manifests", name.Tag)
 
-	manifestJSON, err := json.Marshal(manifest)
+	manifestJSON, err := json.Marshal(m)
 	if err != nil {
 		return err
 	}
 
 	headers := make(http.Header)
 	headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
-	resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
+	resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), &opts)
 	if err != nil {
 		return err
 	}
@@ -1105,6 +1103,7 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header
 		return nil, err
 	}
 
+	slog.Debug("request upstream", "method", method, "request", requestURL.Redacted(), "status", resp.StatusCode)
 	return resp, nil
 }
 

+ 8 - 18
server/routes.go

@@ -514,24 +514,18 @@ func (s *Server) PullHandler(c *gin.Context) {
 }
 
 func (s *Server) PushHandler(c *gin.Context) {
-	var req api.PushRequest
-	err := c.ShouldBindJSON(&req)
-	switch {
-	case errors.Is(err, io.EOF):
+	var r api.PushRequest
+	if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
 		return
-	case err != nil:
+	} else if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
-	var model string
-	if req.Model != "" {
-		model = req.Model
-	} else if req.Name != "" {
-		model = req.Name
-	} else {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
+	n := model.ParseName(cmp.Or(r.Model, r.Name))
+	if !n.IsValid() {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
 		return
 	}
 
@@ -542,19 +536,15 @@ func (s *Server) PushHandler(c *gin.Context) {
 			ch <- r
 		}
 
-		regOpts := &registryOptions{
-			Insecure: req.Insecure,
-		}
-
 		ctx, cancel := context.WithCancel(c.Request.Context())
 		defer cancel()
 
-		if err := PushModel(ctx, model, regOpts, fn); err != nil {
+		if err := PushModel(ctx, n, registryOptions{Insecure: r.Insecure}, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
 
-	if req.Stream != nil && !*req.Stream {
+	if r.Stream != nil && !*r.Stream {
 		waitForStream(c, ch)
 		return
 	}

+ 24 - 17
server/upload.go

@@ -21,6 +21,7 @@ import (
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/format"
+	"github.com/ollama/ollama/types/model"
 )
 
 var blobUploadManager sync.Map
@@ -108,7 +109,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *reg
 		offset += size
 	}
 
-	slog.Info(fmt.Sprintf("uploading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size)))
+	slog.Info("uploading blob", "digest", b.Digest, "size", format.HumanBytes(b.Total), "parts", len(b.Parts), "size per part", format.HumanBytes(b.Parts[0].Size))
 
 	requestURL, err = url.Parse(location)
 	if err != nil {
@@ -362,40 +363,46 @@ func (p *progressWriter) Rollback() {
 	p.written = 0
 }
 
-func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
-	requestURL := mp.BaseURL()
-	requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
+type uploadOptions struct {
+	name    model.Name
+	baseURL *url.URL
+	layer   Layer
+	regOpts *registryOptions
+	fn      func(api.ProgressResponse)
+}
+
+func uploadBlob(ctx context.Context, opts uploadOptions) error {
+	requestURL := opts.baseURL.JoinPath("blobs", opts.layer.Digest)
 
-	resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
+	resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts.regOpts)
 	switch {
 	case errors.Is(err, os.ErrNotExist):
 	case err != nil:
 		return err
 	default:
 		defer resp.Body.Close()
-		fn(api.ProgressResponse{
-			Status:    fmt.Sprintf("pushing %s", layer.Digest[7:19]),
-			Digest:    layer.Digest,
-			Total:     layer.Size,
-			Completed: layer.Size,
+		opts.fn(api.ProgressResponse{
+			Status:    fmt.Sprintf("pushing %s", opts.layer.Digest[7:19]),
+			Digest:    opts.layer.Digest,
+			Total:     opts.layer.Size,
+			Completed: opts.layer.Size,
 		})
 
 		return nil
 	}
 
-	data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})
+	data, ok := blobUploadManager.LoadOrStore(opts.layer.Digest, &blobUpload{Layer: opts.layer})
 	upload := data.(*blobUpload)
 	if !ok {
-		requestURL := mp.BaseURL()
-		requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
-		if err := upload.Prepare(ctx, requestURL, opts); err != nil {
-			blobUploadManager.Delete(layer.Digest)
+		requestURL := opts.baseURL.JoinPath("blobs", "uploads")
+		if err := upload.Prepare(ctx, requestURL, opts.regOpts); err != nil {
+			blobUploadManager.Delete(opts.layer.Digest)
 			return err
 		}
 
 		//nolint:contextcheck
-		go upload.Run(context.Background(), opts)
+		go upload.Run(context.Background(), opts.regOpts)
 	}
 
-	return upload.Wait(ctx, fn)
+	return upload.Wait(ctx, opts.fn)
 }