Sfoglia il codice sorgente

fix create model when template detection errors

Michael Yang 10 mesi fa
parent
commit
030e765e76

+ 4 - 14
llm/gguf.go

@@ -618,22 +618,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
 		}
 	}
 
-	offset, err := ws.Seek(0, io.SeekCurrent)
-	if err != nil {
-		return err
-	}
-
 	var alignment int64 = 32
-	padding := llm.padding(offset, alignment)
-	if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding))); err != nil {
-		return err
-	}
-
 	for _, tensor := range tensors {
-		if _, err := tensor.WriteTo(ws); err != nil {
-			return err
-		}
-
 		offset, err := ws.Seek(0, io.SeekCurrent)
 		if err != nil {
 			return err
@@ -643,6 +629,10 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
 		if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding))); err != nil {
 			return err
 		}
+
+		if _, err := tensor.WriteTo(ws); err != nil {
+			return err
+		}
 	}
 
 	return nil

+ 10 - 11
server/images.go

@@ -437,18 +437,17 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 					config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture())
 
 					if s := baseLayer.GGML.KV().ChatTemplate(); s != "" {
-						t, err := templates.NamedTemplate(s)
-						if err != nil {
-							return err
+						if t, err := templates.NamedTemplate(s); err != nil {
+							slog.Debug("template detection", "error", err)
+						} else {
+							layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
+							if err != nil {
+								return err
+							}
+
+							layer.status = fmt.Sprintf("using autodetected template %s", t.Name)
+							layers = append(layers, layer)
 						}
-
-						layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
-						if err != nil {
-							return err
-						}
-
-						layer.status = fmt.Sprintf("using autodetected template %s", t.Name)
-						layers = append(layers, layer)
 					}
 				}
 

+ 56 - 24
server/routes_create_test.go

@@ -15,11 +15,12 @@ import (
 
 	"github.com/gin-gonic/gin"
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/llm"
 )
 
 var stream bool = false
 
-func createBinFile(t *testing.T) string {
+func createBinFile(t *testing.T, kv map[string]any, ti []llm.Tensor) string {
 	t.Helper()
 
 	f, err := os.CreateTemp(t.TempDir(), "")
@@ -28,19 +29,7 @@ func createBinFile(t *testing.T) string {
 	}
 	defer f.Close()
 
-	if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil {
-		t.Fatal(err)
-	}
-
-	if err := binary.Write(f, binary.LittleEndian, uint32(3)); err != nil {
-		t.Fatal(err)
-	}
-
-	if err := binary.Write(f, binary.LittleEndian, uint64(0)); err != nil {
-		t.Fatal(err)
-	}
-
-	if err := binary.Write(f, binary.LittleEndian, uint64(0)); err != nil {
+	if err := llm.NewGGUFV3(binary.LittleEndian).Encode(f, kv, ti); err != nil {
 		t.Fatal(err)
 	}
 
@@ -101,7 +90,7 @@ func TestCreateFromBin(t *testing.T) {
 	var s Server
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
 		Name:      "test",
-		Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
+		Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
 		Stream:    &stream,
 	})
 
@@ -126,7 +115,7 @@ func TestCreateFromModel(t *testing.T) {
 
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
 		Name:      "test",
-		Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
+		Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
 		Stream:    &stream,
 	})
 
@@ -166,7 +155,7 @@ func TestCreateRemovesLayers(t *testing.T) {
 
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
 		Name:      "test",
-		Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt }}", createBinFile(t)),
+		Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt }}", createBinFile(t, nil, nil)),
 		Stream:    &stream,
 	})
 
@@ -186,7 +175,7 @@ func TestCreateRemovesLayers(t *testing.T) {
 
 	w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
 		Name:      "test",
-		Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t)),
+		Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t, nil, nil)),
 		Stream:    &stream,
 	})
 
@@ -212,7 +201,7 @@ func TestCreateUnsetsSystem(t *testing.T) {
 
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
 		Name:      "test",
-		Modelfile: fmt.Sprintf("FROM %s\nSYSTEM Say hi!", createBinFile(t)),
+		Modelfile: fmt.Sprintf("FROM %s\nSYSTEM Say hi!", createBinFile(t, nil, nil)),
 		Stream:    &stream,
 	})
 
@@ -232,7 +221,7 @@ func TestCreateUnsetsSystem(t *testing.T) {
 
 	w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
 		Name:      "test",
-		Modelfile: fmt.Sprintf("FROM %s\nSYSTEM \"\"", createBinFile(t)),
+		Modelfile: fmt.Sprintf("FROM %s\nSYSTEM \"\"", createBinFile(t, nil, nil)),
 		Stream:    &stream,
 	})
 
@@ -267,7 +256,7 @@ func TestCreateMergeParameters(t *testing.T) {
 
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
 		Name:      "test",
-		Modelfile: fmt.Sprintf("FROM %s\nPARAMETER temperature 1\nPARAMETER top_k 10\nPARAMETER stop USER:\nPARAMETER stop ASSISTANT:", createBinFile(t)),
+		Modelfile: fmt.Sprintf("FROM %s\nPARAMETER temperature 1\nPARAMETER top_k 10\nPARAMETER stop USER:\nPARAMETER stop ASSISTANT:", createBinFile(t, nil, nil)),
 		Stream:    &stream,
 	})
 
@@ -369,7 +358,7 @@ func TestCreateReplacesMessages(t *testing.T) {
 
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
 		Name:      "test",
-		Modelfile: fmt.Sprintf("FROM %s\nMESSAGE assistant \"What is my purpose?\"\nMESSAGE user \"You run tests.\"\nMESSAGE assistant \"Oh, my god.\"", createBinFile(t)),
+		Modelfile: fmt.Sprintf("FROM %s\nMESSAGE assistant \"What is my purpose?\"\nMESSAGE user \"You run tests.\"\nMESSAGE assistant \"Oh, my god.\"", createBinFile(t, nil, nil)),
 		Stream:    &stream,
 	})
 
@@ -444,7 +433,7 @@ func TestCreateTemplateSystem(t *testing.T) {
 
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
 		Name:      "test",
-		Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt }}\nSYSTEM Say hello!\nTEMPLATE {{ .System }} {{ .Prompt }}\nSYSTEM Say bye!", createBinFile(t)),
+		Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt }}\nSYSTEM Say hello!\nTEMPLATE {{ .System }} {{ .Prompt }}\nSYSTEM Say bye!", createBinFile(t, nil, nil)),
 		Stream:    &stream,
 	})
 
@@ -489,7 +478,7 @@ func TestCreateLicenses(t *testing.T) {
 
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
 		Name:      "test",
-		Modelfile: fmt.Sprintf("FROM %s\nLICENSE MIT\nLICENSE Apache-2.0", createBinFile(t)),
+		Modelfile: fmt.Sprintf("FROM %s\nLICENSE MIT\nLICENSE Apache-2.0", createBinFile(t, nil, nil)),
 		Stream:    &stream,
 	})
 
@@ -526,3 +515,46 @@ func TestCreateLicenses(t *testing.T) {
 		t.Errorf("expected Apache-2.0, actual %s", apache)
 	}
 }
+
+func TestCreateDetectTemplate(t *testing.T) {
+	p := t.TempDir()
+	t.Setenv("OLLAMA_MODELS", p)
+	var s Server
+
+	t.Run("matched", func(t *testing.T) {
+		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+			Name: "test",
+			Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
+				"tokenizer.chat_template": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
+			}, nil)),
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Fatalf("expected status code 200, actual %d", w.Code)
+		}
+
+		checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
+			filepath.Join(p, "blobs", "sha256-06cd2687a518d624073f125f1db1c5c727f77c75e84a138fe745186dbbbb4cd7"),
+			filepath.Join(p, "blobs", "sha256-542b217f179c7825eeb5bca3c77d2b75ed05bafbd3451d9188891a60a85337c6"),
+			filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
+		})
+	})
+
+	t.Run("unmatched", func(t *testing.T) {
+		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+			Name:      "test",
+			Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
+			Stream:    &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Fatalf("expected status code 200, actual %d", w.Code)
+		}
+
+		checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
+			filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
+			filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
+		})
+	})
+}

+ 2 - 2
server/routes_delete_test.go

@@ -16,7 +16,7 @@ func TestDelete(t *testing.T) {
 
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
 		Name:      "test",
-		Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
+		Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
 	})
 
 	if w.Code != http.StatusOK {
@@ -25,7 +25,7 @@ func TestDelete(t *testing.T) {
 
 	w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
 		Name:      "test2",
-		Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t)),
+		Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t, nil, nil)),
 	})
 
 	if w.Code != http.StatusOK {

+ 1 - 1
server/routes_list_test.go

@@ -29,7 +29,7 @@ func TestList(t *testing.T) {
 	for _, n := range expectNames {
 		createRequest(t, s.CreateModelHandler, api.CreateRequest{
 			Name:      n,
-			Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
+			Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
 		})
 	}
 

+ 2 - 2
server/routes_test.go

@@ -261,7 +261,7 @@ func TestCase(t *testing.T) {
 		t.Run(tt, func(t *testing.T) {
 			w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
 				Name:      tt,
-				Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
+				Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
 				Stream:    &stream,
 			})
 
@@ -277,7 +277,7 @@ func TestCase(t *testing.T) {
 			t.Run("create", func(t *testing.T) {
 				w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
 					Name:      strings.ToUpper(tt),
-					Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
+					Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
 					Stream:    &stream,
 				})
 

+ 2 - 1
templates/template.go

@@ -30,7 +30,8 @@ var templatesOnce = sync.OnceValues(func() ([]*Template, error) {
 			return nil, err
 		}
 
-		t.Bytes = bts
+		// normalize line endings
+		t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
 	}
 
 	return templates, nil