浏览代码

Merge pull request #348 from jmorganca/cross-repo-mount

cross repo blob mount
Michael Yang 1 年之前
父节点
当前提交
5d9a4cd251
共有 1 个文件被更改,包括 24 次插入8 次删除
  1. 24 8
      server/images.go

+ 24 - 8
server/images.go

@@ -13,6 +13,7 @@ import (
 	"log"
 	"net/http"
 	"os"
+	"path"
 	"path/filepath"
 	"reflect"
 	"strconv"
@@ -94,6 +95,7 @@ type Layer struct {
 	MediaType string `json:"mediaType"`
 	Digest    string `json:"digest"`
 	Size      int    `json:"size"`
+	From      string `json:"from,omitempty"`
 }
 
 type LayerReader struct {
@@ -270,7 +272,8 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 		case "model":
 			fn(api.ProgressResponse{Status: "looking for model"})
 			embed.model = c.Args
-			mf, err := GetManifest(ParseModelPath(c.Args))
+			mp := ParseModelPath(c.Args)
+			mf, err := GetManifest(mp)
 			if err != nil {
 				modelFile, err := filenameWithPath(path, c.Args)
 				if err != nil {
@@ -328,6 +331,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 					if err != nil {
 						return err
 					}
+					newLayer.From = mp.GetNamespaceRepository()
 					layers = append(layers, newLayer)
 				}
 			}
@@ -451,8 +455,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 	}
 	layers = append(layers, cfg)
 
-	err = SaveLayers(layers, fn, false)
-	if err != nil {
+	if err := SaveLayers(layers, fn, false); err != nil {
 		return err
 	}
 
@@ -926,14 +929,24 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 			Total:  layer.Size,
 		})
 
-		location, err := startUpload(ctx, mp, regOpts)
+		location, err := startUpload(ctx, mp, layer, regOpts)
 		if err != nil {
 			log.Printf("couldn't start upload: %v", err)
 			return err
 		}
 
-		err = uploadBlobChunked(ctx, mp, location, layer, regOpts, fn)
-		if err != nil {
+		if strings.HasPrefix(path.Base(location), "sha256:") {
+			layer.Digest = path.Base(location)
+			fn(api.ProgressResponse{
+				Status:    "using existing layer",
+				Digest:    layer.Digest,
+				Total:     layer.Size,
+				Completed: layer.Size,
+			})
+			continue
+		}
+
+		if err := uploadBlobChunked(ctx, mp, location, layer, regOpts, fn); err != nil {
 			log.Printf("error uploading blob: %v", err)
 			return err
 		}
@@ -1093,8 +1106,11 @@ func GetSHA256Digest(r io.Reader) (string, int) {
 	return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n)
 }
 
-func startUpload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (string, error) {
+func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (string, error) {
 	url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository())
+	if layer.From != "" {
+		url = fmt.Sprintf("%s/v2/%s/blobs/uploads/?mount=%s&from=%s", mp.Registry, mp.GetNamespaceRepository(), layer.Digest, layer.From)
+	}
 
 	resp, err := makeRequest(ctx, "POST", url, nil, nil, regOpts)
 	if err != nil {
@@ -1104,7 +1120,7 @@ func startUpload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (s
 	defer resp.Body.Close()
 
 	// Check for success
-	if resp.StatusCode != http.StatusAccepted {
+	if resp.StatusCode != http.StatusAccepted && resp.StatusCode != http.StatusCreated {
 		body, _ := io.ReadAll(resp.Body)
 		return "", fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
 	}