Michael Yang 1 год назад
Родитель
Commit
b535afe35c
3 изменённых файла: 34 добавлено, 85 удалено
  1. 21 41
      server/images.go
  2. 13 12
      server/model.go
  3. 0 32
      types/ordered/map.go

+ 21 - 41
server/images.go

@@ -30,7 +30,6 @@ import (
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
-	"github.com/ollama/ollama/types/ordered"
 	"github.com/ollama/ollama/version"
 )
 
@@ -316,7 +315,7 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
 
 		switch c.Name {
 		case "model", "adapter":
-			var baseLayers *ordered.Map[*Layer, *llm.GGML]
+			var baseLayers []*layerWithGGML
 			if name := model.ParseName(c.Args); name.IsValid() {
 				baseLayers, err = parseFromModel(ctx, name, fn)
 				if err != nil {
@@ -349,70 +348,51 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
 				return fmt.Errorf("invalid model reference: %s", c.Args)
 			}
 
-			var err2 error
-			var tempfiles []*os.File
-
-			// TODO(mxyng): replace with rangefunc
-			baseLayers.Items()(func(layer *Layer, ggml *llm.GGML) bool {
-				if quantization != "" && ggml != nil && ggml.Name() == "gguf" {
+			for _, baseLayer := range baseLayers {
+				if quantization != "" && baseLayer.GGML != nil && baseLayer.GGML.Name() == "gguf" {
 					ftype, err := llm.ParseFileType(quantization)
 					if err != nil {
-						err2 = err
-						return false
+						return err
 					}
 
-					filetype := ggml.KV().FileType()
+					filetype := baseLayer.GGML.KV().FileType()
 					if !slices.Contains([]string{"F16", "F32"}, filetype) {
-						err2 = errors.New("quantization is only supported for F16 and F32 models")
-						return false
+						return errors.New("quantization is only supported for F16 and F32 models")
 					}
 
 					fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", filetype, quantization)})
 
-					blob, err := GetBlobsPath(layer.Digest)
+					blob, err := GetBlobsPath(baseLayer.Digest)
 					if err != nil {
-						err2 = err
-						return false
+						return err
 					}
 
 					temp, err := os.CreateTemp(filepath.Dir(blob), quantization)
 					if err != nil {
-						err2 = err
-						return false
+						return err
 					}
-					tempfiles = append(tempfiles, temp)
+					defer temp.Close()
+					defer os.Remove(temp.Name())
 
 					if err := llm.Quantize(blob, temp.Name(), ftype); err != nil {
-						err2 = err
-						return false
+						return err
 					}
 
-					layer, err = NewLayer(temp, layer.MediaType)
+					baseLayer.Layer, err = NewLayer(temp, baseLayer.Layer.MediaType)
 					if err != nil {
-						err2 = err
-						return false
+						return err
 					}
 				}
 
-				if ggml != nil {
-					config.ModelFormat = cmp.Or(config.ModelFormat, ggml.Name())
-					config.ModelFamily = cmp.Or(config.ModelFamily, ggml.KV().Architecture())
-					config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(ggml.KV().ParameterCount()))
-					config.FileType = cmp.Or(config.FileType, ggml.KV().FileType())
-					config.ModelFamilies = append(config.ModelFamilies, ggml.KV().Architecture())
+				if baseLayer.GGML != nil {
+					config.ModelFormat = cmp.Or(config.ModelFormat, baseLayer.GGML.Name())
+					config.ModelFamily = cmp.Or(config.ModelFamily, baseLayer.GGML.KV().Architecture())
+					config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(baseLayer.GGML.KV().ParameterCount()))
+					config.FileType = cmp.Or(config.FileType, baseLayer.GGML.KV().FileType())
+					config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture())
 				}
 
-				layers = append(layers, layer)
-				return true
-			})
-
-			for _, tempfile := range tempfiles {
-				defer tempfile.Close()
-				defer os.Remove(tempfile.Name())
-			}
-
-			if err2 != nil {
-				return err2
+				layers = append(layers, baseLayer.Layer)
 			}
 		case "license", "template", "system":
 			blob := strings.NewReader(c.Args)

+ 13 - 12
server/model.go

@@ -15,10 +15,14 @@ import (
 	"github.com/ollama/ollama/convert"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/types/model"
-	"github.com/ollama/ollama/types/ordered"
 )
 
-func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (*ordered.Map[*Layer, *llm.GGML], error) {
+type layerWithGGML struct {
+	*Layer
+	*llm.GGML
+}
+
+func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
 	modelpath := ParseModelPath(name.DisplayLongest())
 	manifest, _, err := GetManifest(modelpath)
 	switch {
@@ -36,7 +40,6 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
 		return nil, err
 	}
 
-	layers := ordered.NewMap[*Layer, *llm.GGML]()
 	for _, layer := range manifest.Layers {
 		layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
 		if err != nil {
@@ -62,9 +65,10 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
 			if err != nil {
 				return nil, err
 			}
-			layers.Add(layer, ggml)
+
+			layers = append(layers, &layerWithGGML{layer, ggml})
 		default:
-			layers.Add(layer, nil)
+			layers = append(layers, &layerWithGGML{layer, nil})
 		}
 
 	}
@@ -72,7 +76,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
 	return layers, nil
 }
 
-func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResponse)) (*ordered.Map[*Layer, *llm.GGML], error) {
+func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
 	stat, err := file.Stat()
 	if err != nil {
 		return nil, err
@@ -184,12 +188,11 @@ func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResp
 		return nil, err
 	}
 
-	layers := ordered.NewMap[*Layer, *llm.GGML]()
-	layers.Add(layer, ggml)
+	layers = append(layers, &layerWithGGML{layer, ggml})
 	return layers, nil
 }
 
-func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressResponse)) (*ordered.Map[*Layer, *llm.GGML], error) {
+func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
 	sr := io.NewSectionReader(file, 0, 512)
 	contentType, err := detectContentType(sr)
 	if err != nil {
@@ -205,8 +208,6 @@ func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressRespo
 		return nil, fmt.Errorf("unsupported content type: %s", contentType)
 	}
 
-	layers := ordered.NewMap[*Layer, *llm.GGML]()
-
 	stat, err := file.Stat()
 	if err != nil {
 		return nil, err
@@ -233,7 +234,7 @@ func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressRespo
 			return nil, err
 		}
 
-		layers.Add(layer, ggml)
+		layers = append(layers, &layerWithGGML{layer, ggml})
 		offset = n
 	}
 

+ 0 - 32
types/ordered/map.go

@@ -1,32 +0,0 @@
-package ordered
-
-type Map[K comparable, V any] struct {
-	s []K
-	m map[K]V
-}
-
-func NewMap[K comparable, V any]() *Map[K, V] {
-	return &Map[K, V]{
-		s: make([]K, 0),
-		m: make(map[K]V),
-	}
-}
-
-type iter_Seq2[K, V any] func(func(K, V) bool)
-
-func (m *Map[K, V]) Items() iter_Seq2[K, V] {
-	return func(yield func(K, V) bool) {
-		for _, k := range m.s {
-			if !yield(k, m.m[k]) {
-				return
-			}
-		}
-	}
-}
-
-func (m *Map[K, V]) Add(k K, v V) {
-	if _, ok := m.m[k]; !ok {
-		m.s = append(m.s, k)
-		m.m[k] = v
-	}
-}