Ver Fonte

Merge pull request #1250 from jmorganca/mxyng/create-layer

refactor layer creation
Michael Yang há 1 ano atrás
pai
commit
d3479c07a1
7 ficheiros alterados com 197 adições e 223 exclusões
  1. 4 3
      go.mod
  2. 0 2
      go.sum
  3. 1 1
      llm/ggml.go
  4. 45 193
      server/images.go
  5. 109 0
      server/layers.go
  6. 34 0
      server/manifests.go
  7. 4 24
      server/routes.go

+ 4 - 3
go.mod

@@ -5,14 +5,15 @@ go 1.20
 require (
 	github.com/emirpasic/gods v1.18.1
 	github.com/gin-gonic/gin v1.9.1
-	github.com/mattn/go-runewidth v0.0.14
-	github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db
 	github.com/olekukonko/tablewriter v0.0.5
 	github.com/spf13/cobra v1.7.0
 	golang.org/x/sync v0.3.0
 )
 
-require github.com/rivo/uniseg v0.2.0 // indirect
+require (
+	github.com/mattn/go-runewidth v0.0.14 // indirect
+	github.com/rivo/uniseg v0.2.0 // indirect
+)
 
 require (
 	github.com/bytedance/sonic v1.9.1 // indirect

+ 0 - 2
go.sum

@@ -63,8 +63,6 @@ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
 github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
 github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
 github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
-github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ=
-github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw=
 github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=

+ 1 - 1
llm/ggml.go

@@ -179,7 +179,7 @@ const (
 	FILE_MAGIC_GGUF_BE = 0x47475546
 )
 
-func DecodeGGML(r io.ReadSeeker) (*GGML, error) {
+func DecodeGGML(r io.Reader) (*GGML, error) {
 	var ggml GGML
 	binary.Read(r, binary.LittleEndian, &ggml.magic)
 

+ 45 - 193
server/images.go

@@ -19,8 +19,6 @@ import (
 	"strings"
 	"text/template"
 
-	"golang.org/x/exp/slices"
-
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/llm"
 	"github.com/jmorganca/ollama/parser"
@@ -131,22 +129,10 @@ func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
 type ManifestV2 struct {
 	SchemaVersion int      `json:"schemaVersion"`
 	MediaType     string   `json:"mediaType"`
-	Config        Layer    `json:"config"`
+	Config        *Layer   `json:"config"`
 	Layers        []*Layer `json:"layers"`
 }
 
-type Layer struct {
-	MediaType string `json:"mediaType"`
-	Digest    string `json:"digest"`
-	Size      int64  `json:"size"`
-	From      string `json:"from,omitempty"`
-}
-
-type LayerReader struct {
-	Layer
-	io.Reader
-}
-
 type ConfigV2 struct {
 	ModelFormat string `json:"model_format"`
 	ModelFamily string `json:"model_family"`
@@ -304,11 +290,14 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 	config := ConfigV2{
 		OS:           "linux",
 		Architecture: "amd64",
+		RootFS: RootFS{
+			Type: "layers",
+		},
 	}
 
 	deleteMap := make(map[string]struct{})
 
-	var layers []*LayerReader
+	var layers Layers
 
 	params := make(map[string][]string)
 	fromParams := make(map[string]any)
@@ -389,13 +378,12 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 						}
 					}
 
-					layer, err := GetLayerWithBufferFromLayer(layer)
+					layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
 					if err != nil {
 						return err
 					}
 
-					layer.From = modelpath.GetShortTagname()
-					layers = append(layers, layer)
+					layers.Add(layer)
 				}
 
 				deleteMap[manifest.Config.Digest] = struct{}{}
@@ -415,13 +403,12 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 			config.FileType = ggml.FileType()
 
 			bin.Seek(0, io.SeekStart)
-			layer, err := CreateLayer(bin)
+			layer, err := NewLayer(bin, mediatype)
 			if err != nil {
 				return err
 			}
 
-			layer.MediaType = mediatype
-			layers = append(layers, layer)
+			layers.Add(layer)
 		case "adapter":
 			if strings.HasPrefix(c.Args, "@") {
 				blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
@@ -439,41 +426,32 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 			}
 			defer bin.Close()
 
-			layer, err := CreateLayer(bin)
+			layer, err := NewLayer(bin, mediatype)
 			if err != nil {
 				return err
 			}
 
-			if layer.Size > 0 {
-				layer.MediaType = mediatype
-				layers = append(layers, layer)
-			}
+			layers.Add(layer)
 		case "license":
 			fn(api.ProgressResponse{Status: "creating license layer"})
-			layer, err := CreateLayer(strings.NewReader(c.Args))
+
+			bin := strings.NewReader(c.Args)
+			layer, err := NewLayer(bin, mediatype)
 			if err != nil {
 				return err
 			}
 
-			if layer.Size > 0 {
-				layer.MediaType = mediatype
-				layers = append(layers, layer)
-			}
+			layers.Add(layer)
 		case "template", "system":
 			fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)})
 
-			// remove duplicate layers
-			layers = removeLayerFromLayers(layers, mediatype)
-
-			layer, err := CreateLayer(strings.NewReader(c.Args))
+			bin := strings.NewReader(c.Args)
+			layer, err := NewLayer(bin, mediatype)
 			if err != nil {
 				return err
 			}
 
-			if layer.Size > 0 {
-				layer.MediaType = mediatype
-				layers = append(layers, layer)
-			}
+			layers.Replace(layer)
 		default:
 			params[c.Name] = append(params[c.Name], c.Args)
 		}
@@ -505,164 +483,62 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 		}
 
 		fn(api.ProgressResponse{Status: "creating config layer"})
-		layer, err := CreateLayer(bytes.NewReader(b.Bytes()))
+		layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
 		if err != nil {
 			return err
 		}
 
-		layer.MediaType = "application/vnd.ollama.image.params"
-		layers = append(layers, layer)
+		layers.Replace(layer)
 	}
 
-	digests, err := getLayerDigests(layers)
-	if err != nil {
-		return err
+	digests := make([]string, len(layers.items))
+	for i, layer := range layers.items {
+		digests[i] = layer.Digest
 	}
 
-	configLayer, err := createConfigLayer(config, digests)
-	if err != nil {
-		return err
-	}
+	config.RootFS.DiffIDs = digests
 
-	layers = append(layers, configLayer)
-	delete(deleteMap, configLayer.Digest)
-
-	if err := SaveLayers(layers, fn, false); err != nil {
+	var b bytes.Buffer
+	if err := json.NewEncoder(&b).Encode(config); err != nil {
 		return err
 	}
 
-	var contentLayers []*Layer
-	for _, layer := range layers {
-		contentLayers = append(contentLayers, &layer.Layer)
-		delete(deleteMap, layer.Digest)
-	}
-
-	fn(api.ProgressResponse{Status: "writing manifest"})
-	if err := CreateManifest(name, configLayer, contentLayers); err != nil {
+	configLayer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
+	if err != nil {
 		return err
 	}
 
-	if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
-		if err := deleteUnusedLayers(nil, deleteMap, false); err != nil {
-			return err
-		}
-	}
-
-	fn(api.ProgressResponse{Status: "success"})
-	return nil
-}
-
-func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
-	return slices.DeleteFunc(layers, func(layer *LayerReader) bool {
-		return layer.MediaType == mediaType
-	})
-}
+	delete(deleteMap, configLayer.Digest)
 
-func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force bool) error {
-	// Write each of the layers to disk
-	for _, layer := range layers {
-		fp, err := GetBlobsPath(layer.Digest)
+	for _, layer := range append(layers.items, configLayer) {
+		committed, err := layer.Commit()
 		if err != nil {
 			return err
 		}
 
-		_, err = os.Stat(fp)
-		if os.IsNotExist(err) || force {
-			fn(api.ProgressResponse{Status: fmt.Sprintf("writing layer %s", layer.Digest)})
-
-			out, err := os.Create(fp)
-			if err != nil {
-				log.Printf("couldn't create %s", fp)
-				return err
-			}
-			defer out.Close()
-
-			if _, err = io.Copy(out, layer.Reader); err != nil {
-				return err
-			}
-
-		} else {
-			fn(api.ProgressResponse{Status: fmt.Sprintf("using already created layer %s", layer.Digest)})
+		status := "writing layer"
+		if !committed {
+			status = "using already created layer"
 		}
-	}
 
-	return nil
-}
+		fn(api.ProgressResponse{Status: fmt.Sprintf("%s %s", status, layer.Digest)})
 
-func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
-	mp := ParseModelPath(name)
-	manifest := ManifestV2{
-		SchemaVersion: 2,
-		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
-		Config: Layer{
-			MediaType: cfg.MediaType,
-			Size:      cfg.Size,
-			Digest:    cfg.Digest,
-		},
-		Layers: layers,
-	}
-
-	manifestJSON, err := json.Marshal(manifest)
-	if err != nil {
-		return err
+		delete(deleteMap, layer.Digest)
 	}
 
-	fp, err := mp.GetManifestPath()
-	if err != nil {
-		return err
-	}
-	if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
+	fn(api.ProgressResponse{Status: "writing manifest"})
+	if err := WriteManifest(name, configLayer, layers.items); err != nil {
 		return err
 	}
-	return os.WriteFile(fp, manifestJSON, 0o644)
-}
-
-func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
-	fp, err := GetBlobsPath(layer.Digest)
-	if err != nil {
-		return nil, err
-	}
-
-	file, err := os.Open(fp)
-	if err != nil {
-		return nil, fmt.Errorf("could not open blob: %w", err)
-	}
-	defer file.Close()
-
-	newLayer, err := CreateLayer(file)
-	if err != nil {
-		return nil, err
-	}
-	newLayer.MediaType = layer.MediaType
-	return newLayer, nil
-}
 
-func getLayerDigests(layers []*LayerReader) ([]string, error) {
-	var digests []string
-	for _, l := range layers {
-		if l.Digest == "" {
-			return nil, fmt.Errorf("layer is missing a digest")
+	if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
+		if err := deleteUnusedLayers(nil, deleteMap, false); err != nil {
+			return err
 		}
-		digests = append(digests, l.Digest)
-	}
-	return digests, nil
-}
-
-// CreateLayer creates a Layer object from a given file
-func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
-	digest, size := GetSHA256Digest(f)
-	f.Seek(0, io.SeekStart)
-
-	layer := &LayerReader{
-		Layer: Layer{
-			MediaType: "application/vnd.docker.image.rootfs.diff.tar",
-			Digest:    digest,
-			Size:      size,
-		},
-		Reader: f,
 	}
 
-	return layer, nil
+	fn(api.ProgressResponse{Status: "success"})
+	return nil
 }
 
 func CopyModel(src, dest string) error {
@@ -932,7 +808,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 
 	var layers []*Layer
 	layers = append(layers, manifest.Layers...)
-	layers = append(layers, &manifest.Config)
+	layers = append(layers, manifest.Config)
 
 	for _, layer := range layers {
 		if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
@@ -1003,7 +879,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 
 	var layers []*Layer
 	layers = append(layers, manifest.Layers...)
-	layers = append(layers, &manifest.Config)
+	layers = append(layers, manifest.Config)
 
 	for _, layer := range layers {
 		if err := downloadBlob(
@@ -1091,30 +967,6 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptio
 	return m, err
 }
 
-func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) {
-	config.RootFS = RootFS{
-		Type:    "layers",
-		DiffIDs: layers,
-	}
-
-	configJSON, err := json.Marshal(config)
-	if err != nil {
-		return nil, err
-	}
-
-	digest, size := GetSHA256Digest(bytes.NewBuffer(configJSON))
-
-	layer := &LayerReader{
-		Layer: Layer{
-			MediaType: "application/vnd.docker.container.image.v1+json",
-			Digest:    digest,
-			Size:      size,
-		},
-		Reader: bytes.NewBuffer(configJSON),
-	}
-	return layer, nil
-}
-
 // GetSHA256Digest returns the SHA256 hash of a given buffer and returns it, and the size of buffer
 func GetSHA256Digest(r io.Reader) (string, int64) {
 	h := sha256.New()

+ 109 - 0
server/layers.go

@@ -0,0 +1,109 @@
+package server
+
+import (
+	"crypto/sha256"
+	"fmt"
+	"io"
+	"os"
+	"runtime"
+	"strings"
+
+	"golang.org/x/exp/slices"
+)
+
+type Layers struct {
+	items []*Layer
+}
+
+func (ls *Layers) Add(layer *Layer) {
+	if layer.Size > 0 {
+		ls.items = append(ls.items, layer)
+	}
+}
+
+func (ls *Layers) Replace(layer *Layer) {
+	if layer.Size > 0 {
+		mediatype := layer.MediaType
+		layers := slices.DeleteFunc(ls.items, func(l *Layer) bool {
+			return l.MediaType == mediatype
+		})
+
+		ls.items = append(layers, layer)
+	}
+}
+
+type Layer struct {
+	MediaType string `json:"mediaType"`
+	Digest    string `json:"digest"`
+	Size      int64  `json:"size"`
+	From      string `json:"from,omitempty"`
+
+	tempFileName string
+}
+
+func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
+	blobs, err := GetBlobsPath("")
+	if err != nil {
+		return nil, err
+	}
+
+	delimiter := ":"
+	if runtime.GOOS == "windows" {
+		delimiter = "-"
+	}
+
+	pattern := strings.Join([]string{"sha256", "*-partial"}, delimiter)
+	temp, err := os.CreateTemp(blobs, pattern)
+	if err != nil {
+		return nil, err
+	}
+	defer temp.Close()
+
+	sha256sum := sha256.New()
+	n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
+	if err != nil {
+		return nil, err
+	}
+
+	return &Layer{
+		MediaType:    mediatype,
+		Digest:       fmt.Sprintf("sha256:%x", sha256sum.Sum(nil)),
+		Size:         n,
+		tempFileName: temp.Name(),
+	}, nil
+}
+
+func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
+	blob, err := GetBlobsPath(digest)
+	if err != nil {
+		return nil, err
+	}
+
+	fi, err := os.Stat(blob)
+	if err != nil {
+		return nil, err
+	}
+
+	return &Layer{
+		MediaType: mediatype,
+		Digest:    digest,
+		Size:      fi.Size(),
+		From:      from,
+	}, nil
+}
+
+func (l *Layer) Commit() (bool, error) {
+	// always remove temp
+	defer os.Remove(l.tempFileName)
+
+	blob, err := GetBlobsPath(l.Digest)
+	if err != nil {
+		return false, err
+	}
+
+	if _, err := os.Stat(blob); err != nil {
+		return true, os.Rename(l.tempFileName, blob)
+	}
+
+	return false, nil
+}

+ 34 - 0
server/manifests.go

@@ -0,0 +1,34 @@
+package server
+
+import (
+	"bytes"
+	"encoding/json"
+	"os"
+	"path/filepath"
+)
+
+func WriteManifest(name string, config *Layer, layers []*Layer) error {
+	manifest := ManifestV2{
+		SchemaVersion: 2,
+		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
+		Config:        config,
+		Layers:        layers,
+	}
+
+	var b bytes.Buffer
+	if err := json.NewEncoder(&b).Encode(manifest); err != nil {
+		return err
+	}
+
+	modelpath := ParseModelPath(name)
+	manifestPath, err := modelpath.GetManifestPath()
+	if err != nil {
+		return err
+	}
+
+	if err := os.MkdirAll(filepath.Dir(manifestPath), 0755); err != nil {
+		return err
+	}
+
+	return os.WriteFile(manifestPath, b.Bytes(), 0644)
+}

+ 4 - 24
server/routes.go

@@ -2,7 +2,6 @@ package server
 
 import (
 	"context"
-	"crypto/sha256"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -735,37 +734,18 @@ func HeadBlobHandler(c *gin.Context) {
 }
 
 func CreateBlobHandler(c *gin.Context) {
-	targetPath, err := GetBlobsPath(c.Param("digest"))
+	layer, err := NewLayer(c.Request.Body, "")
 	if err != nil {
 		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
-	hash := sha256.New()
-	temp, err := os.CreateTemp(filepath.Dir(targetPath), c.Param("digest")+"-")
-	if err != nil {
-		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-	defer temp.Close()
-	defer os.Remove(temp.Name())
-
-	if _, err := io.Copy(temp, io.TeeReader(c.Request.Body, hash)); err != nil {
-		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	if fmt.Sprintf("sha256:%x", hash.Sum(nil)) != c.Param("digest") {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "digest does not match body"})
-		return
-	}
-
-	if err := temp.Close(); err != nil {
-		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+	if layer.Digest != c.Param("digest") {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("digest mismatch, expected %q, got %q", c.Param("digest"), layer.Digest)})
 		return
 	}
 
-	if err := os.Rename(temp.Name(), targetPath); err != nil {
+	if _, err := layer.Commit(); err != nil {
 		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}