浏览代码

handle intermediate blobs

Michael Yang 1 年之前
父节点
当前提交
dc474f9b83
共有 2 个文件被更改,包括 101 次插入9 次删除
  1. 99 8
      server/images.go
  2. 2 1
      server/layer.go

+ 99 - 8
server/images.go

@@ -410,10 +410,17 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
 					if err != nil {
 						return err
 					}
+					metadataLayer.Intermediate = true
 					metadataLayer.MergeBase = baseLayer.Digest
 
 					layers = append(layers, metadataLayer)
 
+					metadataPath, err := GetBlobsPath(metadataLayer.Digest)
+					if err != nil {
+						return err
+					}
+					defer os.Remove(metadataPath)
+
 					stat, err := f.Stat()
 					if err != nil {
 						return err
@@ -424,9 +431,16 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
 					if err != nil {
 						return err
 					}
+					dataLayer.Intermediate = true
 					dataLayer.MergeBase = baseLayer.Digest
 
 					layers = append(layers, dataLayer)
+
+					dataPath, err := GetBlobsPath(dataLayer.Digest)
+					if err != nil {
+						return err
+					}
+					defer os.Remove(dataPath)
 					continue
 				}
 
@@ -813,6 +827,49 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 	layers = append(layers, manifest.Layers...)
 	layers = append(layers, manifest.Config)
 
+	for _, layer := range layers {
+		if !layer.Intermediate {
+			continue
+		}
+
+		switch layer.MediaType {
+		case "application/vnd.ollama.image.model+metadata", "application/vnd.ollama.image.model+data":
+			if _, err := GetBlobsPath(layer.MergeBase); errors.Is(err, os.ErrNotExist) {
+				filename, err := GetBlobsPath(layer.MergeBase)
+				if err != nil {
+					return err
+				}
+
+				f, err := os.Open(filename)
+				if err != nil {
+					return err
+				}
+				defer f.Close()
+
+				ggml, size, err := llm.DecodeGGML(f)
+				if err != nil {
+					return err
+				}
+
+				if _, err := f.Seek(0, io.SeekStart); err != nil {
+					return err
+				}
+
+				metadata := io.NewSectionReader(f, 0, ggml.Offset())
+				if _, err := NewLayer(metadata, "application/vnd.ollama.image.model+metadata"); err != nil {
+					return err
+				}
+
+				data := io.NewSectionReader(f, ggml.Offset(), size)
+				if _, err := NewLayer(data, "application/vnd.ollama.image.model+metadata"); err != nil {
+					return err
+				}
+			} else if err != nil {
+				return err
+			}
+		}
+	}
+
 	for _, layer := range layers {
 		if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
 			slog.Info(fmt.Sprintf("error uploading blob: %v", err))
@@ -882,6 +939,27 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 	layers = append(layers, manifest.Config)
 
 	for _, layer := range layers {
+		if layer.Intermediate {
+			filename, err := GetBlobsPath(layer.MergeBase)
+			if err != nil {
+				return err
+			}
+
+			if _, err := os.Stat(filename); errors.Is(err, os.ErrNotExist) {
+				// pass
+			} else if err != nil {
+				return err
+			} else {
+				fn(api.ProgressResponse{
+					Status:    fmt.Sprintf("pulling %s", layer.Digest[7:19]),
+					Digest:    layer.Digest,
+					Total:     layer.Size,
+					Completed: layer.Size,
+				})
+				continue
+			}
+		}
+
 		if err := downloadBlob(
 			ctx,
 			downloadOpts{
@@ -902,16 +980,27 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 
 	mergedLayers := make(map[string]mergedLayer)
 	for _, layer := range manifest.Layers {
-		merged := mergedLayers[layer.MergeBase]
-		if layer.MediaType == "application/vnd.ollama.image.model+metadata" {
-			merged.Metadata = layer
-		} else if layer.MediaType == "application/vnd.ollama.image.model+data" {
-			merged.Data = layer
+		filename, err := GetBlobsPath(layer.MergeBase)
+		if err != nil {
+			return err
+		}
+
+		if _, err := os.Stat(filename); errors.Is(err, os.ErrNotExist) {
+			merged := mergedLayers[layer.MergeBase]
+			if layer.MediaType == "application/vnd.ollama.image.model+metadata" {
+				merged.Metadata = layer
+			} else if layer.MediaType == "application/vnd.ollama.image.model+data" {
+				merged.Data = layer
+			} else {
+				continue
+			}
+
+			mergedLayers[layer.MergeBase] = merged
+		} else if err != nil {
+			return err
 		} else {
 			continue
 		}
-
-		mergedLayers[layer.MergeBase] = merged
 	}
 
 	for _, mergedLayer := range mergedLayers {
@@ -935,7 +1024,9 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 
 	fn(api.ProgressResponse{Status: "verifying sha256 digest"})
 	for _, layer := range layers {
-		if err := verifyBlob(layer.Digest); err != nil {
+		if err := verifyBlob(layer.Digest); errors.Is(err, os.ErrNotExist) && layer.Intermediate {
+			// pass
+		} else if err != nil {
 			if errors.Is(err, errDigestMismatch) {
 				// something went wrong, delete the blob
 				fp, err := GetBlobsPath(layer.Digest)

+ 2 - 1
server/layer.go

@@ -13,7 +13,8 @@ type Layer struct {
 	Size      int64  `json:"size"`
 	From      string `json:"from,omitempty"`
 
-	MergeBase string `json:"merge_base,omitempty"`
+	Intermediate bool   `json:"intermediate,omitempty"`
+	MergeBase    string `json:"merge_base,omitempty"`
 
 	message string
 }