Bläddra i källkod

refactor layer creation

previous layer creation was not ideal because:

1. it required reading the input file multiple times, once to calculate
   the sha256 checksum, another to write it to disk, and potentially one
   more to decode the underlying gguf
2. it used io.ReadSeeker, which is prone to user error. if the file
   isn't reset correctly or in the right place, it could end up reading
   an empty file

there is also some brittleness when reading existing layers: if handled
incorrectly, writing the inherited layers will error reading an already
closed file

this commit aims to fix these issues by restructuring layer creation.

1. it will now write the layer to a temporary file and the hash function
   in a single pass, then move the file to its final location on Commit
2. layers are read only once, when copied to the destination. the
   exception is raw model files, which still require a second read to
   decode the model metadata
Michael Yang 1 år sedan
förälder
incheckning
70a93057cd
3 ändrade filer med 188 tillägg och 193 borttagningar
  1. 45 193
      server/images.go
  2. 109 0
      server/layers.go
  3. 34 0
      server/manifests.go

+ 45 - 193
server/images.go

@@ -19,8 +19,6 @@ import (
 	"strings"
 	"strings"
 	"text/template"
 	"text/template"
 
 
-	"golang.org/x/exp/slices"
-
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/llm"
 	"github.com/jmorganca/ollama/llm"
 	"github.com/jmorganca/ollama/parser"
 	"github.com/jmorganca/ollama/parser"
@@ -128,22 +126,10 @@ func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
 type ManifestV2 struct {
 type ManifestV2 struct {
 	SchemaVersion int      `json:"schemaVersion"`
 	SchemaVersion int      `json:"schemaVersion"`
 	MediaType     string   `json:"mediaType"`
 	MediaType     string   `json:"mediaType"`
-	Config        Layer    `json:"config"`
+	Config        *Layer   `json:"config"`
 	Layers        []*Layer `json:"layers"`
 	Layers        []*Layer `json:"layers"`
 }
 }
 
 
-type Layer struct {
-	MediaType string `json:"mediaType"`
-	Digest    string `json:"digest"`
-	Size      int64  `json:"size"`
-	From      string `json:"from,omitempty"`
-}
-
-type LayerReader struct {
-	Layer
-	io.Reader
-}
-
 type ConfigV2 struct {
 type ConfigV2 struct {
 	ModelFormat string `json:"model_format"`
 	ModelFormat string `json:"model_format"`
 	ModelFamily string `json:"model_family"`
 	ModelFamily string `json:"model_family"`
@@ -301,11 +287,14 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 	config := ConfigV2{
 	config := ConfigV2{
 		OS:           "linux",
 		OS:           "linux",
 		Architecture: "amd64",
 		Architecture: "amd64",
+		RootFS: RootFS{
+			Type: "layers",
+		},
 	}
 	}
 
 
 	deleteMap := make(map[string]struct{})
 	deleteMap := make(map[string]struct{})
 
 
-	var layers []*LayerReader
+	var layers Layers
 
 
 	params := make(map[string][]string)
 	params := make(map[string][]string)
 	fromParams := make(map[string]any)
 	fromParams := make(map[string]any)
@@ -386,13 +375,12 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 						}
 						}
 					}
 					}
 
 
-					layer, err := GetLayerWithBufferFromLayer(layer)
+					layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
 					if err != nil {
 					if err != nil {
 						return err
 						return err
 					}
 					}
 
 
-					layer.From = modelpath.GetShortTagname()
-					layers = append(layers, layer)
+					layers.Add(layer)
 				}
 				}
 
 
 				deleteMap[manifest.Config.Digest] = struct{}{}
 				deleteMap[manifest.Config.Digest] = struct{}{}
@@ -412,13 +400,12 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 			config.FileType = ggml.FileType()
 			config.FileType = ggml.FileType()
 
 
 			bin.Seek(0, io.SeekStart)
 			bin.Seek(0, io.SeekStart)
-			layer, err := CreateLayer(bin)
+			layer, err := NewLayer(bin, mediatype)
 			if err != nil {
 			if err != nil {
 				return err
 				return err
 			}
 			}
 
 
-			layer.MediaType = mediatype
-			layers = append(layers, layer)
+			layers.Add(layer)
 		case "adapter":
 		case "adapter":
 			if strings.HasPrefix(c.Args, "@") {
 			if strings.HasPrefix(c.Args, "@") {
 				blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
 				blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
@@ -436,41 +423,32 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 			}
 			}
 			defer bin.Close()
 			defer bin.Close()
 
 
-			layer, err := CreateLayer(bin)
+			layer, err := NewLayer(bin, mediatype)
 			if err != nil {
 			if err != nil {
 				return err
 				return err
 			}
 			}
 
 
-			if layer.Size > 0 {
-				layer.MediaType = mediatype
-				layers = append(layers, layer)
-			}
+			layers.Add(layer)
 		case "license":
 		case "license":
 			fn(api.ProgressResponse{Status: "creating license layer"})
 			fn(api.ProgressResponse{Status: "creating license layer"})
-			layer, err := CreateLayer(strings.NewReader(c.Args))
+
+			bin := strings.NewReader(c.Args)
+			layer, err := NewLayer(bin, mediatype)
 			if err != nil {
 			if err != nil {
 				return err
 				return err
 			}
 			}
 
 
-			if layer.Size > 0 {
-				layer.MediaType = mediatype
-				layers = append(layers, layer)
-			}
+			layers.Add(layer)
 		case "template", "system":
 		case "template", "system":
 			fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)})
 			fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)})
 
 
-			// remove duplicate layers
-			layers = removeLayerFromLayers(layers, mediatype)
-
-			layer, err := CreateLayer(strings.NewReader(c.Args))
+			bin := strings.NewReader(c.Args)
+			layer, err := NewLayer(bin, mediatype)
 			if err != nil {
 			if err != nil {
 				return err
 				return err
 			}
 			}
 
 
-			if layer.Size > 0 {
-				layer.MediaType = mediatype
-				layers = append(layers, layer)
-			}
+			layers.Replace(layer)
 		default:
 		default:
 			params[c.Name] = append(params[c.Name], c.Args)
 			params[c.Name] = append(params[c.Name], c.Args)
 		}
 		}
@@ -502,164 +480,62 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 		}
 		}
 
 
 		fn(api.ProgressResponse{Status: "creating config layer"})
 		fn(api.ProgressResponse{Status: "creating config layer"})
-		layer, err := CreateLayer(bytes.NewReader(b.Bytes()))
+		layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
 
 
-		layer.MediaType = "application/vnd.ollama.image.params"
-		layers = append(layers, layer)
+		layers.Replace(layer)
 	}
 	}
 
 
-	digests, err := getLayerDigests(layers)
-	if err != nil {
-		return err
+	digests := make([]string, len(layers.items))
+	for i, layer := range layers.items {
+		digests[i] = layer.Digest
 	}
 	}
 
 
-	configLayer, err := createConfigLayer(config, digests)
-	if err != nil {
-		return err
-	}
+	config.RootFS.DiffIDs = digests
 
 
-	layers = append(layers, configLayer)
-	delete(deleteMap, configLayer.Digest)
-
-	if err := SaveLayers(layers, fn, false); err != nil {
+	var b bytes.Buffer
+	if err := json.NewEncoder(&b).Encode(config); err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	var contentLayers []*Layer
-	for _, layer := range layers {
-		contentLayers = append(contentLayers, &layer.Layer)
-		delete(deleteMap, layer.Digest)
-	}
-
-	fn(api.ProgressResponse{Status: "writing manifest"})
-	if err := CreateManifest(name, configLayer, contentLayers); err != nil {
+	configLayer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
+	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
-		if err := deleteUnusedLayers(nil, deleteMap, false); err != nil {
-			return err
-		}
-	}
-
-	fn(api.ProgressResponse{Status: "success"})
-	return nil
-}
-
-func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
-	return slices.DeleteFunc(layers, func(layer *LayerReader) bool {
-		return layer.MediaType == mediaType
-	})
-}
+	delete(deleteMap, configLayer.Digest)
 
 
-func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force bool) error {
-	// Write each of the layers to disk
-	for _, layer := range layers {
-		fp, err := GetBlobsPath(layer.Digest)
+	for _, layer := range append(layers.items, configLayer) {
+		committed, err := layer.Commit()
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
 
 
-		_, err = os.Stat(fp)
-		if os.IsNotExist(err) || force {
-			fn(api.ProgressResponse{Status: fmt.Sprintf("writing layer %s", layer.Digest)})
-
-			out, err := os.Create(fp)
-			if err != nil {
-				log.Printf("couldn't create %s", fp)
-				return err
-			}
-			defer out.Close()
-
-			if _, err = io.Copy(out, layer.Reader); err != nil {
-				return err
-			}
-
-		} else {
-			fn(api.ProgressResponse{Status: fmt.Sprintf("using already created layer %s", layer.Digest)})
+		status := "writing layer"
+		if !committed {
+			status = "using already created layer"
 		}
 		}
-	}
 
 
-	return nil
-}
+		fn(api.ProgressResponse{Status: fmt.Sprintf("%s %s", status, layer.Digest)})
 
 
-func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
-	mp := ParseModelPath(name)
-	manifest := ManifestV2{
-		SchemaVersion: 2,
-		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
-		Config: Layer{
-			MediaType: cfg.MediaType,
-			Size:      cfg.Size,
-			Digest:    cfg.Digest,
-		},
-		Layers: layers,
-	}
-
-	manifestJSON, err := json.Marshal(manifest)
-	if err != nil {
-		return err
+		delete(deleteMap, layer.Digest)
 	}
 	}
 
 
-	fp, err := mp.GetManifestPath()
-	if err != nil {
-		return err
-	}
-	if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
+	fn(api.ProgressResponse{Status: "writing manifest"})
+	if err := WriteManifest(name, configLayer, layers.items); err != nil {
 		return err
 		return err
 	}
 	}
-	return os.WriteFile(fp, manifestJSON, 0o644)
-}
-
-func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
-	fp, err := GetBlobsPath(layer.Digest)
-	if err != nil {
-		return nil, err
-	}
-
-	file, err := os.Open(fp)
-	if err != nil {
-		return nil, fmt.Errorf("could not open blob: %w", err)
-	}
-	defer file.Close()
-
-	newLayer, err := CreateLayer(file)
-	if err != nil {
-		return nil, err
-	}
-	newLayer.MediaType = layer.MediaType
-	return newLayer, nil
-}
 
 
-func getLayerDigests(layers []*LayerReader) ([]string, error) {
-	var digests []string
-	for _, l := range layers {
-		if l.Digest == "" {
-			return nil, fmt.Errorf("layer is missing a digest")
+	if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
+		if err := deleteUnusedLayers(nil, deleteMap, false); err != nil {
+			return err
 		}
 		}
-		digests = append(digests, l.Digest)
-	}
-	return digests, nil
-}
-
-// CreateLayer creates a Layer object from a given file
-func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
-	digest, size := GetSHA256Digest(f)
-	f.Seek(0, io.SeekStart)
-
-	layer := &LayerReader{
-		Layer: Layer{
-			MediaType: "application/vnd.docker.image.rootfs.diff.tar",
-			Digest:    digest,
-			Size:      size,
-		},
-		Reader: f,
 	}
 	}
 
 
-	return layer, nil
+	fn(api.ProgressResponse{Status: "success"})
+	return nil
 }
 }
 
 
 func CopyModel(src, dest string) error {
 func CopyModel(src, dest string) error {
@@ -929,7 +805,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 
 
 	var layers []*Layer
 	var layers []*Layer
 	layers = append(layers, manifest.Layers...)
 	layers = append(layers, manifest.Layers...)
-	layers = append(layers, &manifest.Config)
+	layers = append(layers, manifest.Config)
 
 
 	for _, layer := range layers {
 	for _, layer := range layers {
 		if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
 		if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
@@ -1000,7 +876,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 
 
 	var layers []*Layer
 	var layers []*Layer
 	layers = append(layers, manifest.Layers...)
 	layers = append(layers, manifest.Layers...)
-	layers = append(layers, &manifest.Config)
+	layers = append(layers, manifest.Config)
 
 
 	for _, layer := range layers {
 	for _, layer := range layers {
 		if err := downloadBlob(
 		if err := downloadBlob(
@@ -1088,30 +964,6 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptio
 	return m, err
 	return m, err
 }
 }
 
 
-func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) {
-	config.RootFS = RootFS{
-		Type:    "layers",
-		DiffIDs: layers,
-	}
-
-	configJSON, err := json.Marshal(config)
-	if err != nil {
-		return nil, err
-	}
-
-	digest, size := GetSHA256Digest(bytes.NewBuffer(configJSON))
-
-	layer := &LayerReader{
-		Layer: Layer{
-			MediaType: "application/vnd.docker.container.image.v1+json",
-			Digest:    digest,
-			Size:      size,
-		},
-		Reader: bytes.NewBuffer(configJSON),
-	}
-	return layer, nil
-}
-
 // GetSHA256Digest returns the SHA256 hash of a given buffer and returns it, and the size of buffer
 // GetSHA256Digest returns the SHA256 hash of a given buffer and returns it, and the size of buffer
 func GetSHA256Digest(r io.Reader) (string, int64) {
 func GetSHA256Digest(r io.Reader) (string, int64) {
 	h := sha256.New()
 	h := sha256.New()

+ 109 - 0
server/layers.go

@@ -0,0 +1,109 @@
+package server
+
+import (
+	"crypto/sha256"
+	"fmt"
+	"io"
+	"os"
+	"runtime"
+	"strings"
+
+	"golang.org/x/exp/slices"
+)
+
+type Layers struct {
+	items []*Layer
+}
+
+func (ls *Layers) Add(layer *Layer) {
+	if layer.Size > 0 {
+		ls.items = append(ls.items, layer)
+	}
+}
+
+func (ls *Layers) Replace(layer *Layer) {
+	if layer.Size > 0 {
+		mediatype := layer.MediaType
+		layers := slices.DeleteFunc(ls.items, func(l *Layer) bool {
+			return l.MediaType == mediatype
+		})
+
+		ls.items = append(layers, layer)
+	}
+}
+
+type Layer struct {
+	MediaType string `json:"mediaType"`
+	Digest    string `json:"digest"`
+	Size      int64  `json:"size"`
+	From      string `json:"from,omitempty"`
+
+	tempFileName string
+}
+
+func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
+	blobs, err := GetBlobsPath("")
+	if err != nil {
+		return nil, err
+	}
+
+	delimiter := ":"
+	if runtime.GOOS == "windows" {
+		delimiter = "-"
+	}
+
+	pattern := strings.Join([]string{"sha256", "*-partial"}, delimiter)
+	temp, err := os.CreateTemp(blobs, pattern)
+	if err != nil {
+		return nil, err
+	}
+	defer temp.Close()
+
+	sha256sum := sha256.New()
+	n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
+	if err != nil {
+		return nil, err
+	}
+
+	return &Layer{
+		MediaType:    mediatype,
+		Digest:       fmt.Sprintf("sha256:%x", sha256sum.Sum(nil)),
+		Size:         n,
+		tempFileName: temp.Name(),
+	}, nil
+}
+
+func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
+	blob, err := GetBlobsPath(digest)
+	if err != nil {
+		return nil, err
+	}
+
+	fi, err := os.Stat(blob)
+	if err != nil {
+		return nil, err
+	}
+
+	return &Layer{
+		MediaType: mediatype,
+		Digest:    digest,
+		Size:      fi.Size(),
+		From:      from,
+	}, nil
+}
+
+func (l *Layer) Commit() (bool, error) {
+	// always remove temp
+	defer os.Remove(l.tempFileName)
+
+	blob, err := GetBlobsPath(l.Digest)
+	if err != nil {
+		return false, err
+	}
+
+	if _, err := os.Stat(blob); err != nil {
+		return true, os.Rename(l.tempFileName, blob)
+	}
+
+	return false, nil
+}

+ 34 - 0
server/manifests.go

@@ -0,0 +1,34 @@
+package server
+
+import (
+	"bytes"
+	"encoding/json"
+	"os"
+	"path/filepath"
+)
+
+func WriteManifest(name string, config *Layer, layers []*Layer) error {
+	manifest := ManifestV2{
+		SchemaVersion: 2,
+		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
+		Config:        config,
+		Layers:        layers,
+	}
+
+	var b bytes.Buffer
+	if err := json.NewEncoder(&b).Encode(manifest); err != nil {
+		return err
+	}
+
+	modelpath := ParseModelPath(name)
+	manifestPath, err := modelpath.GetManifestPath()
+	if err != nil {
+		return err
+	}
+
+	if err := os.MkdirAll(filepath.Dir(manifestPath), 0755); err != nil {
+		return err
+	}
+
+	return os.WriteFile(manifestPath, b.Bytes(), 0644)
+}