
fix: multiple templates when creating from model

multiple templates may appear in a model if a model is created from
another model that 1) has an autodetected template and 2) defines a
custom template
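
To illustrate the scenario described above, here is a minimal Modelfile sketch (the model name and template body are hypothetical): the base model was previously imported from a GGUF and picked up an autodetected template during that import, while the derived model also declares its own TEMPLATE, so prior to this fix the resulting model carried both template layers.

	# hypothetical Modelfile; "mymodel" was imported from a GGUF,
	# so it already carries an autodetected template layer
	FROM mymodel
	TEMPLATE """{{ .Prompt }}"""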
Michael Yang 10 months ago
parent
commit
c16f8af911
3 changed files with 33 additions and 26 deletions
  1. +1 -16 server/images.go
  2. +31 -9 server/model.go
  3. +1 -1 server/routes_create_test.go

+1 -16 server/images.go

@@ -28,7 +28,6 @@ import (
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/parser"
-	"github.com/ollama/ollama/templates"
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
@@ -333,7 +332,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 
 		switch c.Name {
 		case "model", "adapter":
-			var baseLayers []*layerWithGGML
+			var baseLayers []*layerGGML
 			if name := model.ParseName(c.Args); name.IsValid() {
 				baseLayers, err = parseFromModel(ctx, name, fn)
 				if err != nil {
@@ -435,20 +434,6 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 					config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(baseLayer.GGML.KV().ParameterCount()))
 					config.FileType = cmp.Or(config.FileType, baseLayer.GGML.KV().FileType().String())
 					config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture())
-
-					if s := baseLayer.GGML.KV().ChatTemplate(); s != "" {
-						if t, err := templates.NamedTemplate(s); err != nil {
-							slog.Debug("template detection", "error", err)
-						} else {
-							layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
-							if err != nil {
-								return err
-							}
-
-							layer.status = fmt.Sprintf("using autodetected template %s", t.Name)
-							layers = append(layers, layer)
-						}
-					}
 				}
 
 				layers = append(layers, baseLayer.Layer)
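
The autodetection removed here ran for every GGML base layer during create, including layers pulled in with FROM an existing model, which is how an autodetected template could land next to a custom TEMPLATE. The same logic moves to server/model.go below as detectChatTemplate and now runs only on the raw import paths, parseFromFile and parseFromZipFile, not on parseFromModel.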

+31 -9 server/model.go

@@ -7,6 +7,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"log/slog"
 	"net/http"
 	"os"
 	"path/filepath"
@@ -14,17 +15,18 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/convert"
 	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/templates"
 	"github.com/ollama/ollama/types/model"
 )
 
 var intermediateBlobs map[string]string = make(map[string]string)
 
-type layerWithGGML struct {
+type layerGGML struct {
 	*Layer
 	*llm.GGML
 }
 
-func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
+func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
 	m, err := ParseNamedManifest(name)
 	switch {
 	case errors.Is(err, os.ErrNotExist):
@@ -66,16 +68,16 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
 				return nil, err
 			}
 
-			layers = append(layers, &layerWithGGML{layer, ggml})
+			layers = append(layers, &layerGGML{layer, ggml})
 		default:
-			layers = append(layers, &layerWithGGML{layer, nil})
+			layers = append(layers, &layerGGML{layer, nil})
 		}
 	}
 
 	return layers, nil
 }
 
-func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
+func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
 	stat, err := file.Stat()
 	if err != nil {
 		return nil, err
@@ -179,13 +181,13 @@ func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(a
 		return nil, err
 	}
 
-	layers = append(layers, &layerWithGGML{layer, ggml})
+	layers = append(layers, &layerGGML{layer, ggml})
 
 	intermediateBlobs[digest] = layer.Digest
-	return layers, nil
+	return detectChatTemplate(layers)
 }
 
-func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
+func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
 	sr := io.NewSectionReader(file, 0, 512)
 	contentType, err := detectContentType(sr)
 	if err != nil {
@@ -227,10 +229,30 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
 			return nil, err
 		}
 
-		layers = append(layers, &layerWithGGML{layer, ggml})
+		layers = append(layers, &layerGGML{layer, ggml})
 		offset = n
 	}
 
+	return detectChatTemplate(layers)
+}
+
+func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
+	for _, layer := range layers {
+		if s := layer.GGML.KV().ChatTemplate(); s != "" {
+			if t, err := templates.NamedTemplate(s); err != nil {
+				slog.Debug("template detection", "error", err)
+			} else {
+				tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
+				if err != nil {
+					return nil, err
+				}
+
+				tmpl.status = fmt.Sprintf("using autodetected template %s", t.Name)
+				layers = append(layers, &layerGGML{tmpl, nil})
+			}
+		}
+	}
+
 	return layers, nil
 }
 

+1 -1 server/routes_create_test.go

@@ -535,7 +535,7 @@ func TestCreateDetectTemplate(t *testing.T) {
 		}
 
 		checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
-			filepath.Join(p, "blobs", "sha256-06cd2687a518d624073f125f1db1c5c727f77c75e84a138fe745186dbbbb4cd7"),
+			filepath.Join(p, "blobs", "sha256-2f8e594e6f34b1b4d36a246628eeb3365ce442303d656f1fcc69e821722acea0"),
 			filepath.Join(p, "blobs", "sha256-542b217f179c7825eeb5bca3c77d2b75ed05bafbd3451d9188891a60a85337c6"),
 			filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
 		})