Browse Source

cache and reuse intermediate blobs

particularly useful for zipfiles and f16s
Michael Yang, cách đây 11 tháng
mục cha
commit
3520c0e4d5
4 tập tin đã thay đổi với 53 bổ sung và 18 xóa
  1. 24 3
      server/images.go
  2. 1 1
      server/layer.go
  3. 9 14
      server/model.go
  4. 19 0
      server/routes.go

+ 24 - 3
server/images.go

@@ -340,7 +340,24 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
 					return err
 				}
 			} else if strings.HasPrefix(c.Args, "@") {
-				blobpath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
+				digest := strings.TrimPrefix(c.Args, "@")
+				if ib, ok := intermediateBlobs.Load(digest); ok {
+					p, err := GetBlobsPath(ib.(string))
+					if err != nil {
+						return err
+					}
+
+					if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) {
+						// pass
+					} else if err != nil {
+						return err
+					} else {
+						fn(api.ProgressResponse{Status: fmt.Sprintf("using cached layer %s", ib.(string))})
+						digest = ib.(string)
+					}
+				}
+
+				blobpath, err := GetBlobsPath(digest)
 				if err != nil {
 					return err
 				}
@@ -351,14 +368,14 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
 				}
 				defer blob.Close()
 
-				baseLayers, err = parseFromFile(ctx, blob, fn)
+				baseLayers, err = parseFromFile(ctx, blob, digest, fn)
 				if err != nil {
 					return err
 				}
 			} else if file, err := os.Open(realpath(modelFileDir, c.Args)); err == nil {
 				defer file.Close()
 
-				baseLayers, err = parseFromFile(ctx, file, fn)
+				baseLayers, err = parseFromFile(ctx, file, "", fn)
 				if err != nil {
 					return err
 				}
@@ -398,10 +415,14 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
 							return err
 						}
 
+						f16digest := baseLayer.Layer.Digest
+
 						baseLayer.Layer, err = NewLayer(temp, baseLayer.Layer.MediaType)
 						if err != nil {
 							return err
 						}
+
+						intermediateBlobs.Store(f16digest, baseLayer.Layer.Digest)
 					}
 				}
 

+ 1 - 1
server/layer.go

@@ -80,7 +80,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
 	}, nil
 }
 
-func (l *Layer) Open() (io.ReadCloser, error) {
+func (l *Layer) Open() (io.ReadSeekCloser, error) {
 	blob, err := GetBlobsPath(l.Digest)
 	if err != nil {
 		return nil, err

+ 9 - 14
server/model.go

@@ -10,6 +10,7 @@ import (
 	"net/http"
 	"os"
 	"path/filepath"
+	"sync"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/convert"
@@ -17,6 +18,8 @@ import (
 	"github.com/ollama/ollama/types/model"
 )
 
+var intermediateBlobs sync.Map
+
 type layerWithGGML struct {
 	*Layer
 	*llm.GGML
@@ -76,7 +79,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
 	return layers, nil
 }
 
-func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
+func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
 	stat, err := file.Stat()
 	if err != nil {
 		return nil, err
@@ -169,12 +172,7 @@ func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResp
 		return nil, fmt.Errorf("aaa: %w", err)
 	}
 
-	blobpath, err := GetBlobsPath(layer.Digest)
-	if err != nil {
-		return nil, err
-	}
-
-	bin, err := os.Open(blobpath)
+	bin, err := layer.Open()
 	if err != nil {
 		return nil, err
 	}
@@ -185,16 +183,13 @@ func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResp
 		return nil, err
 	}
 
-	layer, err = NewLayerFromLayer(layer.Digest, layer.MediaType, "")
-	if err != nil {
-		return nil, err
-	}
-
 	layers = append(layers, &layerWithGGML{layer, ggml})
+
+	intermediateBlobs.Store(digest, layer.Digest)
 	return layers, nil
 }
 
-func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
+func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
 	sr := io.NewSectionReader(file, 0, 512)
 	contentType, err := detectContentType(sr)
 	if err != nil {
@@ -205,7 +200,7 @@ func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressRespo
 	case "gguf", "ggla":
 		// noop
 	case "application/zip":
-		return parseFromZipFile(ctx, file, fn)
+		return parseFromZipFile(ctx, file, digest, fn)
 	default:
 		return nil, fmt.Errorf("unsupported content type: %s", contentType)
 	}

+ 19 - 0
server/routes.go

@@ -841,6 +841,25 @@ func (s *Server) HeadBlobHandler(c *gin.Context) {
 }
 
 func (s *Server) CreateBlobHandler(c *gin.Context) {
+	ib, ok := intermediateBlobs.Load(c.Param("digest"))
+	if ok {
+		p, err := GetBlobsPath(ib.(string))
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) {
+			intermediateBlobs.Delete(c.Param("digest"))
+		} else if err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		} else {
+			c.Status(http.StatusOK)
+			return
+		}
+	}
+
 	path, err := GetBlobsPath(c.Param("digest"))
 	if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})