
Merge pull request #3682 from ollama/mxyng/quantize-all-the-things

quantize any fp16/fp32 model
Michael Yang 1 year ago
parent
commit
1e0a669f75
14 changed files with 624 additions and 589 deletions
  1. convert/convert.go (+2 -1)
  2. convert/gemma.go (+2 -13)
  3. convert/llama.go (+2 -16)
  4. convert/mistral.go (+2 -13)
  5. convert/mixtral.go (+3 -14)
  6. integration/utils_test.go (+1 -1)
  7. llm/filetype.go (+140 -0)
  8. llm/ggml.go (+18 -77)
  9. llm/llm.go (+4 -52)
  10. server/images.go (+139 -328)
  11. server/layer.go (+29 -44)
  12. server/model.go (+261 -0)
  13. server/routes.go (+1 -6)
  14. server/routes_test.go (+20 -24)

+ 2 - 1
convert/convert.go

@@ -5,6 +5,7 @@ import (
 	"encoding/binary"
 	"encoding/json"
 	"fmt"
+	"io"
 	"log/slog"
 	"os"
 	"path/filepath"
@@ -47,7 +48,7 @@ type ByteOrder interface {
 type ModelArch interface {
 	GetTensors() error
 	LoadVocab() error
-	WriteGGUF() (string, error)
+	WriteGGUF(io.WriteSeeker) error
 }
 
 type ModelFormat interface {

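This interface change inverts ownership of the output file: each converter used to create its own temp file and return its path, and now the caller passes in any io.WriteSeeker. Below is a minimal sketch of the caller side; the convertTo helper is illustrative, not part of the PR, and assumes a ModelArch obtained via convert.GetModelFormat and friends.

```go
package demo

import (
	"io"
	"os"
)

// ModelArch mirrors the updated interface in convert/convert.go.
type ModelArch interface {
	GetTensors() error
	LoadVocab() error
	WriteGGUF(io.WriteSeeker) error
}

// convertTo drives a converter against a caller-owned file. Creating,
// seeking, and removing the file is now the caller's responsibility;
// the converter only encodes into the writer it is handed.
func convertTo(m ModelArch, path string) error {
	f, err := os.Create(path)
	if err != nil {
		return err
	}
	defer f.Close()

	if err := m.GetTensors(); err != nil {
		return err
	}
	if err := m.LoadVocab(); err != nil {
		return err
	}
	return m.WriteGGUF(f) // *os.File satisfies io.WriteSeeker
}
```

In the PR itself this caller is parseFromZipFile in server/model.go, which writes into a temp file and then turns it into a blob layer. The same signature change repeats across gemma.go, llama.go, mistral.go, and mixtral.go below.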
+ 2 - 13
convert/gemma.go

@@ -94,7 +94,7 @@ func (m *GemmaModel) LoadVocab() error {
 	return nil
 }
 
-func (m *GemmaModel) WriteGGUF() (string, error) {
+func (m *GemmaModel) WriteGGUF(ws io.WriteSeeker) error {
 	kv := llm.KV{
 		"general.architecture":                   "gemma",
 		"general.name":                           m.Name,
@@ -122,16 +122,5 @@ func (m *GemmaModel) WriteGGUF() (string, error) {
 		"tokenizer.ggml.add_eos_token":    false,
 	}
 
-	f, err := os.CreateTemp("", "ollama-gguf")
-	if err != nil {
-		return "", err
-	}
-	defer f.Close()
-
-	mod := llm.NewGGUFV3(m.Params.ByteOrder)
-	if err := mod.Encode(f, kv, m.Tensors); err != nil {
-		return "", err
-	}
-
-	return f.Name(), nil
+	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }

+ 2 - 16
convert/llama.go

@@ -5,7 +5,6 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
-	"os"
 	"regexp"
 	"strings"
 
@@ -132,7 +131,7 @@ func (m *LlamaModel) LoadVocab() error {
 	return nil
 }
 
-func (m *LlamaModel) WriteGGUF() (string, error) {
+func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
 	kv := llm.KV{
 		"general.architecture":                   "llama",
 		"general.name":                           m.Name,
@@ -159,18 +158,5 @@ func (m *LlamaModel) WriteGGUF() (string, error) {
 		"tokenizer.ggml.add_eos_token":    false,
 	}
 
-	f, err := os.CreateTemp("", "ollama-gguf")
-	if err != nil {
-		return "", err
-	}
-	defer f.Close()
-
-	mod := llm.NewGGUFV3(m.Params.ByteOrder)
-	if err := mod.Encode(f, kv, m.Tensors); err != nil {
-		return "", err
-	}
-
-	slog.Debug(fmt.Sprintf("gguf file = %s", f.Name()))
-
-	return f.Name(), nil
+	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }

+ 2 - 13
convert/mistral.go

@@ -132,7 +132,7 @@ func (m *MistralModel) LoadVocab() error {
 	return nil
 }
 
-func (m *MistralModel) WriteGGUF() (string, error) {
+func (m *MistralModel) WriteGGUF(ws io.WriteSeeker) error {
 	kv := llm.KV{
 		"general.architecture":                   "llama",
 		"general.name":                           m.Name,
@@ -158,16 +158,5 @@ func (m *MistralModel) WriteGGUF() (string, error) {
 		"tokenizer.ggml.unknown_token_id": uint32(0),
 	}
 
-	f, err := os.CreateTemp("", "ollama-gguf")
-	if err != nil {
-		return "", err
-	}
-	defer f.Close()
-
-	mod := llm.NewGGUFV3(m.Params.ByteOrder)
-	if err := mod.Encode(f, kv, m.Tensors); err != nil {
-		return "", err
-	}
-
-	return f.Name(), nil
+	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }

+ 3 - 14
convert/mixtral.go

@@ -1,7 +1,7 @@
 package convert
 
 import (
-	"os"
+	"io"
 	"regexp"
 
 	"github.com/ollama/ollama/llm"
@@ -47,7 +47,7 @@ func (m *MixtralModel) LoadVocab() error {
 	return nil
 }
 
-func (m *MixtralModel) WriteGGUF() (string, error) {
+func (m *MixtralModel) WriteGGUF(ws io.WriteSeeker) error {
 	kv := llm.KV{
 		"general.architecture":          "llama",
 		"general.name":                  m.Name,
@@ -81,16 +81,5 @@ func (m *MixtralModel) WriteGGUF() (string, error) {
 		"tokenizer.ggml.add_eos_token":    false,
 	}
 
-	f, err := os.CreateTemp("", "ollama-gguf")
-	if err != nil {
-		return "", err
-	}
-	defer f.Close()
-
-	mod := llm.NewGGUFV3(m.Params.ByteOrder)
-	if err := mod.Encode(f, kv, m.Tensors); err != nil {
-		return "", err
-	}
-
-	return f.Name(), nil
+	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }

+ 1 - 1
integration/utils_test.go

@@ -107,7 +107,7 @@ func startServer(ctx context.Context, ollamaHost string) error {
 
 	if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost {
 		slog.Info("setting env", "OLLAMA_HOST", ollamaHost)
-		os.Setenv("OLLAMA_HOST", ollamaHost)
+		t.Setenv("OLLAMA_HOST", ollamaHost)
 	}
 
 	slog.Info("starting server", "url", ollamaHost)

+ 140 - 0
llm/filetype.go

@@ -0,0 +1,140 @@
+package llm
+
+import "fmt"
+
+type fileType uint32
+
+const (
+	fileTypeF32 fileType = iota
+	fileTypeF16
+	fileTypeQ4_0
+	fileTypeQ4_1
+	fileTypeQ4_1_F16
+	fileTypeQ4_2 // unused
+	fileTypeQ4_3 // unused
+	fileTypeQ8_0
+	fileTypeQ5_0
+	fileTypeQ5_1
+	fileTypeQ2_K
+	fileTypeQ3_K_S
+	fileTypeQ3_K_M
+	fileTypeQ3_K_L
+	fileTypeQ4_K_S
+	fileTypeQ4_K_M
+	fileTypeQ5_K_S
+	fileTypeQ5_K_M
+	fileTypeQ6_K
+	fileTypeIQ2_XXS
+	fileTypeIQ2_XS
+	fileTypeQ2_K_S
+	fileTypeQ3_K_XS
+	fileTypeIQ3_XXS
+
+	fileTypeUnknown
+)
+
+func ParseFileType(s string) (fileType, error) {
+	switch s {
+	case "F32":
+		return fileTypeF32, nil
+	case "F16":
+		return fileTypeF16, nil
+	case "Q4_0":
+		return fileTypeQ4_0, nil
+	case "Q4_1":
+		return fileTypeQ4_1, nil
+	case "Q4_1_F16":
+		return fileTypeQ4_1_F16, nil
+	case "Q8_0":
+		return fileTypeQ8_0, nil
+	case "Q5_0":
+		return fileTypeQ5_0, nil
+	case "Q5_1":
+		return fileTypeQ5_1, nil
+	case "Q2_K":
+		return fileTypeQ2_K, nil
+	case "Q3_K_S":
+		return fileTypeQ3_K_S, nil
+	case "Q3_K_M":
+		return fileTypeQ3_K_M, nil
+	case "Q3_K_L":
+		return fileTypeQ3_K_L, nil
+	case "Q4_K_S":
+		return fileTypeQ4_K_S, nil
+	case "Q4_K_M":
+		return fileTypeQ4_K_M, nil
+	case "Q5_K_S":
+		return fileTypeQ5_K_S, nil
+	case "Q5_K_M":
+		return fileTypeQ5_K_M, nil
+	case "Q6_K":
+		return fileTypeQ6_K, nil
+	case "IQ2_XXS":
+		return fileTypeIQ2_XXS, nil
+	case "IQ2_XS":
+		return fileTypeIQ2_XS, nil
+	case "Q2_K_S":
+		return fileTypeQ2_K_S, nil
+	case "Q3_K_XS":
+		return fileTypeQ3_K_XS, nil
+	case "IQ3_XXS":
+		return fileTypeIQ3_XXS, nil
+	default:
+		return fileTypeUnknown, fmt.Errorf("unknown fileType: %s", s)
+	}
+}
+
+func (t fileType) String() string {
+	switch t {
+	case fileTypeF32:
+		return "F32"
+	case fileTypeF16:
+		return "F16"
+	case fileTypeQ4_0:
+		return "Q4_0"
+	case fileTypeQ4_1:
+		return "Q4_1"
+	case fileTypeQ4_1_F16:
+		return "Q4_1_F16"
+	case fileTypeQ8_0:
+		return "Q8_0"
+	case fileTypeQ5_0:
+		return "Q5_0"
+	case fileTypeQ5_1:
+		return "Q5_1"
+	case fileTypeQ2_K:
+		return "Q2_K"
+	case fileTypeQ3_K_S:
+		return "Q3_K_S"
+	case fileTypeQ3_K_M:
+		return "Q3_K_M"
+	case fileTypeQ3_K_L:
+		return "Q3_K_L"
+	case fileTypeQ4_K_S:
+		return "Q4_K_S"
+	case fileTypeQ4_K_M:
+		return "Q4_K_M"
+	case fileTypeQ5_K_S:
+		return "Q5_K_S"
+	case fileTypeQ5_K_M:
+		return "Q5_K_M"
+	case fileTypeQ6_K:
+		return "Q6_K"
+	case fileTypeIQ2_XXS:
+		return "IQ2_XXS"
+	case fileTypeIQ2_XS:
+		return "IQ2_XS"
+	case fileTypeQ2_K_S:
+		return "Q2_K_S"
+	case fileTypeQ3_K_XS:
+		return "Q3_K_XS"
+	case fileTypeIQ3_XXS:
+		return "IQ3_XXS"
+	default:
+		return "unknown"
+	}
+}
+
+func (t fileType) Value() uint32 {
+	return uint32(t)
+}

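Note the two placeholder constants: Q4_2 and Q4_3 were retired in llama.cpp, but they keep their slots in the iota sequence so every later value still lines up with llama.cpp's LLAMA_FTYPE enum (the old llm/ggml.go encoded the same gap with `fileTypeQ8_0 uint32 = iota + 2`). A sketch of a pinning test, not part of the PR:

```go
package llm

import "testing"

// TestFileTypeValues pins the numeric values that must stay aligned
// with llama.cpp's LLAMA_FTYPE enum. The unused Q4_2/Q4_3 entries
// exist only to hold their slots in the iota sequence.
func TestFileTypeValues(t *testing.T) {
	cases := map[fileType]uint32{
		fileTypeF32:  0,
		fileTypeF16:  1,
		fileTypeQ8_0: 7, // would be 5 without the Q4_2/Q4_3 placeholders
		fileTypeQ6_K: 18,
	}
	for ft, want := range cases {
		if got := ft.Value(); got != want {
			t.Errorf("%s.Value() = %d, want %d", ft, got, want)
		}
	}

	// ParseFileType and String round-trip for every supported name.
	for _, s := range []string{"F16", "Q4_K_M", "IQ2_XXS"} {
		ft, err := ParseFileType(s)
		if err != nil {
			t.Fatal(err)
		}
		if ft.String() != s {
			t.Errorf("round trip: got %s, want %s", ft.String(), s)
		}
	}
}
```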
+ 18 - 77
llm/ggml.go

@@ -13,82 +13,6 @@ type GGML struct {
 	model
 }
 
-const (
-	fileTypeF32 uint32 = iota
-	fileTypeF16
-	fileTypeQ4_0
-	fileTypeQ4_1
-	fileTypeQ4_1_F16
-	fileTypeQ8_0 uint32 = iota + 2
-	fileTypeQ5_0
-	fileTypeQ5_1
-	fileTypeQ2_K
-	fileTypeQ3_K_S
-	fileTypeQ3_K_M
-	fileTypeQ3_K_L
-	fileTypeQ4_K_S
-	fileTypeQ4_K_M
-	fileTypeQ5_K_S
-	fileTypeQ5_K_M
-	fileTypeQ6_K
-	fileTypeIQ2_XXS
-	fileTypeIQ2_XS
-	fileTypeQ2_K_S
-	fileTypeQ3_K_XS
-	fileTypeIQ3_XXS
-)
-
-func fileType(fileType uint32) string {
-	switch fileType {
-	case fileTypeF32:
-		return "F32"
-	case fileTypeF16:
-		return "F16"
-	case fileTypeQ4_0:
-		return "Q4_0"
-	case fileTypeQ4_1:
-		return "Q4_1"
-	case fileTypeQ4_1_F16:
-		return "Q4_1_F16"
-	case fileTypeQ8_0:
-		return "Q8_0"
-	case fileTypeQ5_0:
-		return "Q5_0"
-	case fileTypeQ5_1:
-		return "Q5_1"
-	case fileTypeQ2_K:
-		return "Q2_K"
-	case fileTypeQ3_K_S:
-		return "Q3_K_S"
-	case fileTypeQ3_K_M:
-		return "Q3_K_M"
-	case fileTypeQ3_K_L:
-		return "Q3_K_L"
-	case fileTypeQ4_K_S:
-		return "Q4_K_S"
-	case fileTypeQ4_K_M:
-		return "Q4_K_M"
-	case fileTypeQ5_K_S:
-		return "Q5_K_S"
-	case fileTypeQ5_K_M:
-		return "Q5_K_M"
-	case fileTypeQ6_K:
-		return "Q6_K"
-	case fileTypeIQ2_XXS:
-		return "IQ2_XXS"
-	case fileTypeIQ2_XS:
-		return "IQ2_XS"
-	case fileTypeQ2_K_S:
-		return "Q2_K_S"
-	case fileTypeQ3_K_XS:
-		return "Q3_K_XS"
-	case fileTypeIQ3_XXS:
-		return "IQ3_XXS"
-	default:
-		return "unknown"
-	}
-}
-
 type model interface {
 	KV() KV
 	Tensors() Tensors
@@ -123,7 +47,7 @@ func (kv KV) ParameterCount() uint64 {
 
 func (kv KV) FileType() string {
 	if u64 := kv.u64("general.file_type"); u64 > 0 {
-		return fileType(uint32(u64))
+		return fileType(uint32(u64)).String()
 	}
 
 	return "unknown"
@@ -286,6 +210,23 @@ const (
 
 var ErrUnsupportedFormat = errors.New("unsupported model format")
 
+func DetectGGMLType(b []byte) string {
+	switch binary.LittleEndian.Uint32(b[:4]) {
+	case FILE_MAGIC_GGML:
+		return "ggml"
+	case FILE_MAGIC_GGMF:
+		return "ggmf"
+	case FILE_MAGIC_GGJT:
+		return "ggjt"
+	case FILE_MAGIC_GGLA:
+		return "ggla"
+	case FILE_MAGIC_GGUF_LE, FILE_MAGIC_GGUF_BE:
+		return "gguf"
+	default:
+		return ""
+	}
+}
+
 func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
 	var magic uint32
 	if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {

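DetectGGMLType only needs the first four magic bytes, which lets server code sniff a file without decoding it. A hedged example, assuming FILE_MAGIC_GGUF_LE is the little-endian reading of the ASCII bytes "GGUF" as in llama.cpp:

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/llm"
)

func main() {
	// "GGUF" read as a little-endian uint32 is the GGUF magic, so a
	// real GGUF file starts with exactly these four bytes.
	fmt.Println(llm.DetectGGMLType([]byte("GGUF"))) // gguf

	// Anything unrecognized returns "", which callers treat as
	// "not a GGML-family file" and fall through to other sniffers.
	fmt.Println(llm.DetectGGMLType([]byte{'P', 'K', 3, 4}) == "") // true
}
```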
+ 4 - 52
llm/llm.go

@@ -20,7 +20,7 @@ func SystemInfo() string {
 	return C.GoString(C.llama_print_system_info())
 }
 
-func Quantize(infile, outfile, filetype string) error {
+func Quantize(infile, outfile string, ftype fileType) error {
 	cinfile := C.CString(infile)
 	defer C.free(unsafe.Pointer(cinfile))
 
@@ -29,58 +29,10 @@ func Quantize(infile, outfile, filetype string) error {
 
 	params := C.llama_model_quantize_default_params()
 	params.nthread = -1
+	params.ftype = ftype.Value()
 
-	switch filetype {
-	case "F32":
-		params.ftype = fileTypeF32
-	case "F16":
-		params.ftype = fileTypeF16
-	case "Q4_0":
-		params.ftype = fileTypeQ4_0
-	case "Q4_1":
-		params.ftype = fileTypeQ4_1
-	case "Q4_1_F16":
-		params.ftype = fileTypeQ4_1_F16
-	case "Q8_0":
-		params.ftype = fileTypeQ8_0
-	case "Q5_0":
-		params.ftype = fileTypeQ5_0
-	case "Q5_1":
-		params.ftype = fileTypeQ5_1
-	case "Q2_K":
-		params.ftype = fileTypeQ2_K
-	case "Q3_K_S":
-		params.ftype = fileTypeQ3_K_S
-	case "Q3_K_M":
-		params.ftype = fileTypeQ3_K_M
-	case "Q3_K_L":
-		params.ftype = fileTypeQ3_K_L
-	case "Q4_K_S":
-		params.ftype = fileTypeQ4_K_S
-	case "Q4_K_M":
-		params.ftype = fileTypeQ4_K_M
-	case "Q5_K_S":
-		params.ftype = fileTypeQ5_K_S
-	case "Q5_K_M":
-		params.ftype = fileTypeQ5_K_M
-	case "Q6_K":
-		params.ftype = fileTypeQ6_K
-	case "IQ2_XXS":
-		params.ftype = fileTypeIQ2_XXS
-	case "IQ2_XS":
-		params.ftype = fileTypeIQ2_XS
-	case "Q2_K_S":
-		params.ftype = fileTypeQ2_K_S
-	case "Q3_K_XS":
-		params.ftype = fileTypeQ3_K_XS
-	case "IQ3_XXS":
-		params.ftype = fileTypeIQ3_XXS
-	default:
-		return fmt.Errorf("unknown filetype: %s", filetype)
-	}
-
-	if retval := C.llama_model_quantize(cinfile, coutfile, &params); retval != 0 {
-		return fmt.Errorf("llama_model_quantize: %d", retval)
+	if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
+		return fmt.Errorf("llama_model_quantize: %d", rc)
 	}
 
 	return nil

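Quantize now takes a parsed fileType instead of a raw string, so an invalid quantization name fails at ParseFileType before any cgo call is made (routes.go upper-cases the user's input first, per the change further down). A usage sketch with illustrative paths:

```go
package main

import (
	"log"

	"github.com/ollama/ollama/llm"
)

func main() {
	// Validation happens up front: a bad name errors here, not deep
	// inside llama_model_quantize.
	ftype, err := llm.ParseFileType("Q4_K_M")
	if err != nil {
		log.Fatal(err)
	}

	// Paths are illustrative. Quantize reads the F16/F32 GGUF at the
	// first path and writes the quantized model to the second via
	// llama.cpp's quantizer.
	if err := llm.Quantize("model-f16.gguf", "model-q4_k_m.gguf", ftype); err != nil {
		log.Fatal(err)
	}
}
```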
+ 139 - 328
server/images.go

@@ -1,8 +1,8 @@
 package server
 
 import (
-	"archive/zip"
 	"bytes"
+	"cmp"
 	"context"
 	"crypto/sha256"
 	"encoding/base64"
@@ -11,7 +11,6 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"io/fs"
 	"log"
 	"log/slog"
 	"net/http"
@@ -26,7 +25,6 @@ import (
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/auth"
-	"github.com/ollama/ollama/convert"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/server/envconfig"
@@ -158,36 +156,6 @@ type ConfigV2 struct {
 	RootFS       RootFS `json:"rootfs"`
 }
 
-func (c *ConfigV2) SetModelFormat(format string) {
-	if c.ModelFormat == "" {
-		c.ModelFormat = format
-	}
-}
-
-func (c *ConfigV2) SetModelFamily(families ...string) {
-	for _, family := range families {
-		if c.ModelFamily == "" {
-			c.ModelFamily = family
-		}
-
-		if !slices.Contains(c.ModelFamilies, family) {
-			c.ModelFamilies = append(c.ModelFamilies, family)
-		}
-	}
-}
-
-func (c *ConfigV2) SetModelType(modelType string) {
-	if c.ModelType == "" {
-		c.ModelType = modelType
-	}
-}
-
-func (c *ConfigV2) SetFileType(fileType string) {
-	if c.FileType == "" {
-		c.FileType = fileType
-	}
-}
-
 type RootFS struct {
 	Type    string   `json:"type"`
 	DiffIDs []string `json:"diff_ids"`
@@ -332,7 +300,7 @@ func GetModel(name string) (*Model, error) {
 	return model, nil
 }
 
-func realpath(mfDir, from string) string {
+func realpath(rel, from string) string {
 	abspath, err := filepath.Abs(from)
 	if err != nil {
 		return from
@@ -349,22 +317,15 @@ func realpath(mfDir, from string) string {
 		return filepath.Join(home, from[2:])
 	}
 
-	if _, err := os.Stat(filepath.Join(mfDir, from)); err == nil {
+	if _, err := os.Stat(filepath.Join(rel, from)); err == nil {
 		// this is a file relative to the Modelfile
-		return filepath.Join(mfDir, from)
+		return filepath.Join(rel, from)
 	}
 
 	return abspath
 }
 
-func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) error {
-	deleteMap := make(map[string]struct{})
-	if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
-		for _, layer := range append(manifest.Layers, manifest.Config) {
-			deleteMap[layer.Digest] = struct{}{}
-		}
-	}
-
+func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) (err error) {
 	config := ConfigV2{
 		OS:           "linux",
 		Architecture: "amd64",
@@ -373,250 +334,181 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
 		},
 	}
 
-	var layers Layers
-	messages := []string{}
-
-	params := make(map[string][]string)
-	fromParams := make(map[string]any)
+	var messages []*api.Message
+	parameters := make(map[string]any)
 
+	var layers []*Layer
 	for _, c := range modelfile.Commands {
 		mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
 
 		switch c.Name {
-		case "model":
-			if strings.HasPrefix(c.Args, "@") {
-				blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
+		case "model", "adapter":
+			var baseLayers []*layerWithGGML
+			if name := model.ParseName(c.Args); name.IsValid() {
+				baseLayers, err = parseFromModel(ctx, name, fn)
+				if err != nil {
+					return err
+				}
+			} else if strings.HasPrefix(c.Args, "@") {
+				blobpath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
 				if err != nil {
 					return err
 				}
 
-				c.Args = blobPath
-			}
+				blob, err := os.Open(blobpath)
+				if err != nil {
+					return err
+				}
+				defer blob.Close()
 
-			pathName := realpath(modelFileDir, c.Args)
+				baseLayers, err = parseFromFile(ctx, blob, fn)
+				if err != nil {
+					return err
+				}
+			} else if file, err := os.Open(realpath(modelFileDir, c.Args)); err == nil {
+				defer file.Close()
 
-			ggufName, err := convertModel(name, pathName, fn)
-			if err != nil {
-				var pathErr *fs.PathError
-				switch {
-				case errors.Is(err, zip.ErrFormat):
-					// it's not a safetensor archive
-				case errors.As(err, &pathErr):
-					// it's not a file on disk, could be a model reference
-				default:
+				baseLayers, err = parseFromFile(ctx, file, fn)
+				if err != nil {
 					return err
 				}
+			} else {
+				return fmt.Errorf("invalid model reference: %s", c.Args)
 			}
 
-			if ggufName != "" {
-				pathName = ggufName
-				defer os.RemoveAll(ggufName)
-
-				if quantization != "" {
-					quantization = strings.ToUpper(quantization)
-					fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", "F16", quantization)})
-					tempfile, err := os.CreateTemp(filepath.Dir(ggufName), quantization)
+			for _, baseLayer := range baseLayers {
+				if quantization != "" &&
+					baseLayer.MediaType == "application/vnd.ollama.image.model" &&
+					baseLayer.GGML != nil &&
+					baseLayer.GGML.Name() == "gguf" {
+					ftype, err := llm.ParseFileType(quantization)
 					if err != nil {
 						return err
 					}
-					defer os.RemoveAll(tempfile.Name())
 
-					if err := llm.Quantize(ggufName, tempfile.Name(), quantization); err != nil {
-						return err
+					filetype := baseLayer.GGML.KV().FileType()
+					if !slices.Contains([]string{"F16", "F32"}, filetype) {
+						return errors.New("quantization is only supported for F16 and F32 models")
 					}
 
-					if err := tempfile.Close(); err != nil {
-						return err
-					}
-
-					pathName = tempfile.Name()
-				}
-			}
+					fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", filetype, quantization)})
 
-			bin, err := os.Open(pathName)
-			if err != nil {
-				// not a file on disk so must be a model reference
-				modelpath := ParseModelPath(c.Args)
-				manifest, _, err := GetManifest(modelpath)
-				switch {
-				case errors.Is(err, os.ErrNotExist):
-					fn(api.ProgressResponse{Status: "pulling model"})
-					if err := PullModel(ctx, c.Args, &registryOptions{}, fn); err != nil {
+					blob, err := GetBlobsPath(baseLayer.Digest)
+					if err != nil {
 						return err
 					}
 
-					manifest, _, err = GetManifest(modelpath)
+					temp, err := os.CreateTemp(filepath.Dir(blob), quantization)
 					if err != nil {
 						return err
 					}
-				case err != nil:
-					return err
-				}
-
-				fn(api.ProgressResponse{Status: "reading model metadata"})
-				fromConfigPath, err := GetBlobsPath(manifest.Config.Digest)
-				if err != nil {
-					return err
-				}
-
-				fromConfigFile, err := os.Open(fromConfigPath)
-				if err != nil {
-					return err
-				}
-				defer fromConfigFile.Close()
-
-				var fromConfig ConfigV2
-				if err := json.NewDecoder(fromConfigFile).Decode(&fromConfig); err != nil {
-					return err
-				}
-
-				// if the model is still not in gguf format, error out
-				if fromConfig.ModelFormat != "gguf" {
-					return fmt.Errorf("%s is not in gguf format, this base model is not compatible with this version of ollama", c.Args)
-				}
+					defer temp.Close()
+					defer os.Remove(temp.Name())
 
-				config.SetModelFormat(fromConfig.ModelFormat)
-				config.SetModelFamily(append(fromConfig.ModelFamilies, fromConfig.ModelFamily)...)
-				config.SetModelType(fromConfig.ModelType)
-				config.SetFileType(fromConfig.FileType)
-
-				for _, layer := range manifest.Layers {
-					deleteMap[layer.Digest] = struct{}{}
-					if layer.MediaType == "application/vnd.ollama.image.params" {
-						fromParamsPath, err := GetBlobsPath(layer.Digest)
-						if err != nil {
-							return err
-						}
-
-						fromParamsFile, err := os.Open(fromParamsPath)
-						if err != nil {
-							return err
-						}
-						defer fromParamsFile.Close()
-
-						if err := json.NewDecoder(fromParamsFile).Decode(&fromParams); err != nil {
-							return err
-						}
+					if err := llm.Quantize(blob, temp.Name(), ftype); err != nil {
+						return err
 					}
 
-					layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
+					baseLayer.Layer, err = NewLayer(temp, baseLayer.Layer.MediaType)
 					if err != nil {
 						return err
 					}
-
-					layers.Add(layer)
 				}
 
-				deleteMap[manifest.Config.Digest] = struct{}{}
-				continue
-			}
-			defer bin.Close()
-
-			var offset int64
-			for {
-				fn(api.ProgressResponse{Status: "creating model layer"})
-				if _, err := bin.Seek(offset, io.SeekStart); err != nil {
-					return err
-				}
-
-				ggml, size, err := llm.DecodeGGML(bin)
-				if errors.Is(err, io.EOF) {
-					break
-				} else if errors.Is(err, llm.ErrUnsupportedFormat) {
-					return fmt.Errorf("model binary specified in FROM field is not a valid gguf format model, %w", err)
-				} else if err != nil {
-					return err
-				}
-
-				config.SetModelFormat(ggml.Name())
-				config.SetModelFamily(ggml.KV().Architecture())
-				config.SetModelType(format.HumanNumber(ggml.KV().ParameterCount()))
-				config.SetFileType(ggml.KV().FileType())
-
-				mediatype := mediatype
-				if ggml.KV().Architecture() == "clip" {
-					mediatype = "application/vnd.ollama.image.projector"
+				if baseLayer.GGML != nil {
+					config.ModelFormat = cmp.Or(config.ModelFormat, baseLayer.GGML.Name())
+					config.ModelFamily = cmp.Or(config.ModelFamily, baseLayer.GGML.KV().Architecture())
+					config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(baseLayer.GGML.KV().ParameterCount()))
+					config.FileType = cmp.Or(config.FileType, baseLayer.GGML.KV().FileType())
+					config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture())
 				}
 
-				sr := io.NewSectionReader(bin, offset, size)
-				layer, err := NewLayer(sr, mediatype)
-				if err != nil {
-					return err
-				}
-
-				layers.Add(layer)
-
-				offset += size
+				layers = append(layers, baseLayer.Layer)
+			}
+		case "license", "template", "system":
+			blob := strings.NewReader(c.Args)
+			layer, err := NewLayer(blob, mediatype)
+			if err != nil {
+				return err
 			}
-		case "adapter":
-			if strings.HasPrefix(c.Args, "@") {
-				blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
-				if err != nil {
-					return err
-				}
 
-				c.Args = blobPath
+			if c.Name != "license" {
+				// replace
+				layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
+					return layer.MediaType == mediatype
+				})
 			}
 
-			fn(api.ProgressResponse{Status: "creating adapter layer"})
-			bin, err := os.Open(realpath(modelFileDir, c.Args))
-			if err != nil {
-				return err
+			layers = append(layers, layer)
+		case "message":
+			role, content, ok := strings.Cut(c.Args, ": ")
+			if !ok {
+				return fmt.Errorf("invalid message: %s", c.Args)
 			}
-			defer bin.Close()
 
-			_, size, err := llm.DecodeGGML(bin)
+			messages = append(messages, &api.Message{Role: role, Content: content})
+		default:
+			ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}})
 			if err != nil {
 				return err
 			}
 
-			sr := io.NewSectionReader(bin, 0, size)
-			layer, err := NewLayer(sr, mediatype)
-			if err != nil {
-				return err
+			for k, v := range ps {
+				if ks, ok := parameters[k].([]string); ok {
+					parameters[k] = append(ks, v.([]string)...)
+				} else if vs, ok := v.([]string); ok {
+					parameters[k] = vs
+				} else {
+					parameters[k] = v
+				}
 			}
+		}
+	}
 
-			layers.Add(layer)
-		case "license":
-			fn(api.ProgressResponse{Status: "creating license layer"})
+	var err2 error
+	layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
+		switch layer.MediaType {
+		case "application/vnd.ollama.image.message":
+			// if there are new messages, remove the inherited ones
+			if len(messages) > 0 {
+				return true
+			}
 
-			bin := strings.NewReader(c.Args)
-			layer, err := NewLayer(bin, mediatype)
+			return false
+		case "application/vnd.ollama.image.params":
+			// merge inherited parameters with new ones
+			r, err := layer.Open()
 			if err != nil {
-				return err
+				err2 = err
+				return false
 			}
+			defer r.Close()
 
-			layers.Add(layer)
-		case "template", "system":
-			fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)})
+			var ps map[string]any
+			if err := json.NewDecoder(r).Decode(&ps); err != nil {
+				err2 = err
+				return false
+			}
 
-			bin := strings.NewReader(c.Args)
-			layer, err := NewLayer(bin, mediatype)
-			if err != nil {
-				return err
+			for k, v := range ps {
+				if _, ok := parameters[k]; !ok {
+					parameters[k] = v
+				}
 			}
 
-			layers.Replace(layer)
-		case "message":
-			messages = append(messages, c.Args)
+			return true
 		default:
-			params[c.Name] = append(params[c.Name], c.Args)
+			return false
 		}
+	})
+
+	if err2 != nil {
+		return err2
 	}
 
 	if len(messages) > 0 {
-		fn(api.ProgressResponse{Status: "creating parameters layer"})
-
-		msgs := make([]api.Message, 0)
-
-		for _, m := range messages {
-			// todo: handle images
-			msg := strings.SplitN(m, ": ", 2)
-			msgs = append(msgs, api.Message{Role: msg[0], Content: msg[1]})
-		}
-
 		var b bytes.Buffer
-		if err := json.NewEncoder(&b).Encode(msgs); err != nil {
+		if err := json.NewEncoder(&b).Encode(messages); err != nil {
 			return err
 		}
 
@@ -625,39 +517,25 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
 			return err
 		}
 
-		layers.Replace(layer)
+		layers = append(layers, layer)
 	}
 
-	if len(params) > 0 {
-		fn(api.ProgressResponse{Status: "creating parameters layer"})
-
-		formattedParams, err := api.FormatParams(params)
-		if err != nil {
-			return err
-		}
-
-		for k, v := range fromParams {
-			if _, ok := formattedParams[k]; !ok {
-				formattedParams[k] = v
-			}
-		}
-
+	if len(parameters) > 0 {
 		var b bytes.Buffer
-		if err := json.NewEncoder(&b).Encode(formattedParams); err != nil {
+		if err := json.NewEncoder(&b).Encode(parameters); err != nil {
 			return err
 		}
 
-		fn(api.ProgressResponse{Status: "creating config layer"})
 		layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
 		if err != nil {
 			return err
 		}
 
-		layers.Replace(layer)
+		layers = append(layers, layer)
 	}
 
-	digests := make([]string, len(layers.items))
-	for i, layer := range layers.items {
+	digests := make([]string, len(layers))
+	for i, layer := range layers {
 		digests[i] = layer.Digest
 	}
 
@@ -668,36 +546,37 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
 		return err
 	}
 
-	configLayer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
+	layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
 	if err != nil {
 		return err
 	}
 
-	delete(deleteMap, configLayer.Digest)
-
-	for _, layer := range append(layers.items, configLayer) {
-		committed, err := layer.Commit()
-		if err != nil {
-			return err
+	for _, layer := range append(layers, layer) {
+		if layer.status != "" {
+			fn(api.ProgressResponse{Status: layer.status})
 		}
+	}
 
-		status := "writing layer"
-		if !committed {
-			status = "using already created layer"
+	unref := make(map[string]struct{})
+	if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
+		for _, layer := range manifest.Layers {
+			if !slices.Contains(digests, layer.Digest) {
+				unref[layer.Digest] = struct{}{}
+			}
 		}
 
-		fn(api.ProgressResponse{Status: fmt.Sprintf("%s %s", status, layer.Digest)})
-
-		delete(deleteMap, layer.Digest)
+		if manifest.Config.Digest != layer.Digest {
+			unref[manifest.Config.Digest] = struct{}{}
+		}
 	}
 
 	fn(api.ProgressResponse{Status: "writing manifest"})
-	if err := WriteManifest(name, configLayer, layers.items); err != nil {
+	if err := WriteManifest(name, layer, layers); err != nil {
 		return err
 	}
 
 	if !envconfig.NoPrune {
-		if err := deleteUnusedLayers(nil, deleteMap, false); err != nil {
+		if err := deleteUnusedLayers(nil, unref, false); err != nil {
 			return err
 		}
 	}
@@ -706,74 +585,6 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
 	return nil
 }
 
-func convertModel(name, path string, fn func(resp api.ProgressResponse)) (string, error) {
-	r, err := zip.OpenReader(path)
-	if err != nil {
-		return "", err
-	}
-	defer r.Close()
-
-	tempDir, err := os.MkdirTemp("", "ollama-convert")
-	if err != nil {
-		return "", err
-	}
-	defer os.RemoveAll(tempDir)
-
-	fn(api.ProgressResponse{Status: "unpacking model metadata"})
-	for _, f := range r.File {
-		fpath := filepath.Join(tempDir, f.Name)
-		outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
-		if err != nil {
-			return "", err
-		}
-
-		rc, err := f.Open()
-		if err != nil {
-			return "", err
-		}
-
-		_, err = io.Copy(outFile, rc)
-		if err != nil {
-			return "", err
-		}
-
-		outFile.Close()
-		rc.Close()
-	}
-
-	mf, err := convert.GetModelFormat(tempDir)
-	if err != nil {
-		return "", err
-	}
-
-	params, err := mf.GetParams(tempDir)
-	if err != nil {
-		return "", err
-	}
-
-	mArch, err := mf.GetModelArch(name, tempDir, params)
-	if err != nil {
-		return "", err
-	}
-
-	fn(api.ProgressResponse{Status: "processing tensors"})
-	if err := mArch.GetTensors(); err != nil {
-		return "", err
-	}
-
-	if err := mArch.LoadVocab(); err != nil {
-		return "", err
-	}
-
-	fn(api.ProgressResponse{Status: "converting model"})
-	path, err = mArch.WriteGGUF()
-	if err != nil {
-		return "", err
-	}
-
-	return path, nil
-}
-
 func CopyModel(src, dst model.Name) error {
 	if !dst.IsFullyQualified() {
 		return model.Unqualified(dst)

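The rewritten CreateModel drops the SetModelFormat/SetModelFamily/SetModelType/SetFileType helpers in favor of cmp.Or from the Go 1.22 standard library, which returns its first non-zero argument and so gives the same "only set if still empty" behavior in one line:

```go
package main

import (
	"cmp"
	"fmt"
)

func main() {
	var format string
	format = cmp.Or(format, "gguf") // format was empty: set to "gguf"
	format = cmp.Or(format, "ggla") // already set: stays "gguf"
	fmt.Println(format)
}
```

Quantization is also gated here: only a base layer whose GGUF metadata reports an F16 or F32 file type is eligible, so requantizing an already-quantized model is rejected early with a clear error.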
+ 29 - 44
server/layers.go → server/layer.go

@@ -5,39 +5,14 @@ import (
 	"fmt"
 	"io"
 	"os"
-	"strings"
-
-	"golang.org/x/exp/slices"
 )
 
-type Layers struct {
-	items []*Layer
-}
-
-func (ls *Layers) Add(layer *Layer) {
-	if layer.Size > 0 {
-		ls.items = append(ls.items, layer)
-	}
-}
-
-func (ls *Layers) Replace(layer *Layer) {
-	if layer.Size > 0 {
-		mediatype := layer.MediaType
-		layers := slices.DeleteFunc(ls.items, func(l *Layer) bool {
-			return l.MediaType == mediatype
-		})
-
-		ls.items = append(layers, layer)
-	}
-}
-
 type Layer struct {
 	MediaType string `json:"mediaType"`
 	Digest    string `json:"digest"`
 	Size      int64  `json:"size"`
 	From      string `json:"from,omitempty"`
-
-	tempFileName string
+	status    string
 }
 
 func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
@@ -46,14 +21,12 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
 		return nil, err
 	}
 
-	const delimiter = "-"
-
-	pattern := strings.Join([]string{"sha256", "*-partial"}, delimiter)
-	temp, err := os.CreateTemp(blobs, pattern)
+	temp, err := os.CreateTemp(blobs, "sha256-")
 	if err != nil {
 		return nil, err
 	}
 	defer temp.Close()
+	defer os.Remove(temp.Name())
 
 	sha256sum := sha256.New()
 	n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
@@ -61,11 +34,29 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
 		return nil, err
 	}
 
+	if err := temp.Close(); err != nil {
+		return nil, err
+	}
+
+	digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
+	blob, err := GetBlobsPath(digest)
+	if err != nil {
+		return nil, err
+	}
+
+	status := "using existing layer"
+	if _, err := os.Stat(blob); err != nil {
+		status = "creating new layer"
+		if err := os.Rename(temp.Name(), blob); err != nil {
+			return nil, err
+		}
+	}
+
 	return &Layer{
-		MediaType:    mediatype,
-		Digest:       fmt.Sprintf("sha256:%x", sha256sum.Sum(nil)),
-		Size:         n,
-		tempFileName: temp.Name(),
+		MediaType: mediatype,
+		Digest:    digest,
+		Size:      n,
+		status:    fmt.Sprintf("%s %s", status, digest),
 	}, nil
 }
 
@@ -85,21 +76,15 @@ func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
 		Digest:    digest,
 		Size:      fi.Size(),
 		From:      from,
+		status:    fmt.Sprintf("using existing layer %s", digest),
 	}, nil
 }
 
-func (l *Layer) Commit() (bool, error) {
-	// always remove temp
-	defer os.Remove(l.tempFileName)
-
+func (l *Layer) Open() (io.ReadCloser, error) {
 	blob, err := GetBlobsPath(l.Digest)
 	if err != nil {
-		return false, err
-	}
-
-	if _, err := os.Stat(blob); err != nil {
-		return true, os.Rename(l.tempFileName, blob)
+		return nil, err
 	}
 
-	return false, nil
+	return os.Open(blob)
 }

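NewLayer now hashes the stream and renames the temp file into the content-addressed blob store in one step, recording a human-readable status on the layer; the separate Commit phase (and the chance to forget it, as CreateBlobHandler previously could) is gone, and Open gives read access by digest. A round-trip sketch, written as if inside package server and assuming a configured models directory:

```go
package server

import (
	"io"
	"strings"
)

// layerRoundTrip sketches the new lifecycle: NewLayer stores the blob
// immediately under its sha256 digest, and Open streams it back.
func layerRoundTrip() (string, error) {
	layer, err := NewLayer(strings.NewReader("{{ .Prompt }}"), "application/vnd.ollama.image.template")
	if err != nil {
		return "", err
	}

	r, err := layer.Open()
	if err != nil {
		return "", err
	}
	defer r.Close()

	b, err := io.ReadAll(r)
	return string(b), err
}
```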
+ 261 - 0
server/model.go

@@ -0,0 +1,261 @@
+package server
+
+import (
+	"archive/zip"
+	"bytes"
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"os"
+	"path/filepath"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/convert"
+	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/types/model"
+)
+
+type layerWithGGML struct {
+	*Layer
+	*llm.GGML
+}
+
+func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
+	modelpath := ParseModelPath(name.String())
+	manifest, _, err := GetManifest(modelpath)
+	switch {
+	case errors.Is(err, os.ErrNotExist):
+		if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
+			return nil, err
+		}
+
+		modelpath = ParseModelPath(name.String())
+		manifest, _, err = GetManifest(modelpath)
+		if err != nil {
+			return nil, err
+		}
+	case err != nil:
+		return nil, err
+	}
+
+	for _, layer := range manifest.Layers {
+		layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
+		if err != nil {
+			return nil, err
+		}
+
+		switch layer.MediaType {
+		case "application/vnd.ollama.image.model",
+			"application/vnd.ollama.image.projector",
+			"application/vnd.ollama.image.adapter":
+			blobpath, err := GetBlobsPath(layer.Digest)
+			if err != nil {
+				return nil, err
+			}
+
+			blob, err := os.Open(blobpath)
+			if err != nil {
+				return nil, err
+			}
+			defer blob.Close()
+
+			ggml, _, err := llm.DecodeGGML(blob)
+			if err != nil {
+				return nil, err
+			}
+
+			layers = append(layers, &layerWithGGML{layer, ggml})
+		default:
+			layers = append(layers, &layerWithGGML{layer, nil})
+		}
+
+	}
+
+	return layers, nil
+}
+
+func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
+	stat, err := file.Stat()
+	if err != nil {
+		return nil, err
+	}
+
+	r, err := zip.NewReader(file, stat.Size())
+	if err != nil {
+		return nil, err
+	}
+
+	tempdir, err := os.MkdirTemp(filepath.Dir(file.Name()), "")
+	if err != nil {
+		return nil, err
+	}
+	defer os.RemoveAll(tempdir)
+
+	fn(api.ProgressResponse{Status: "unpacking model metadata"})
+	for _, f := range r.File {
+		// TODO(mxyng): this should not write out all files to disk
+		outfile, err := os.Create(filepath.Join(tempdir, f.Name))
+		if err != nil {
+			return nil, err
+		}
+		defer outfile.Close()
+
+		infile, err := f.Open()
+		if err != nil {
+			return nil, err
+		}
+		defer infile.Close()
+
+		if _, err = io.Copy(outfile, infile); err != nil {
+			return nil, err
+		}
+
+		if err := outfile.Close(); err != nil {
+			return nil, err
+		}
+
+		if err := infile.Close(); err != nil {
+			return nil, err
+		}
+	}
+
+	mf, err := convert.GetModelFormat(tempdir)
+	if err != nil {
+		return nil, err
+	}
+
+	params, err := mf.GetParams(tempdir)
+	if err != nil {
+		return nil, err
+	}
+
+	mArch, err := mf.GetModelArch("", tempdir, params)
+	if err != nil {
+		return nil, err
+	}
+
+	fn(api.ProgressResponse{Status: "processing tensors"})
+	if err := mArch.GetTensors(); err != nil {
+		return nil, err
+	}
+
+	if err := mArch.LoadVocab(); err != nil {
+		return nil, err
+	}
+
+	fn(api.ProgressResponse{Status: "converting model"})
+
+	// TODO(mxyng): this should write directly into a layer
+	// e.g. NewLayer(arch.Reader(), "application/vnd.ollama.image.model")
+	temp, err := os.CreateTemp(tempdir, "fp16")
+	if err != nil {
+		return nil, err
+	}
+	defer temp.Close()
+	defer os.Remove(temp.Name())
+
+	if err = mArch.WriteGGUF(temp); err != nil {
+		return nil, err
+	}
+
+	if _, err := temp.Seek(0, io.SeekStart); err != nil {
+		return nil, err
+	}
+
+	layer, err := NewLayer(temp, "application/vnd.ollama.image.model")
+	if err != nil {
+		return nil, fmt.Errorf("aaa: %w", err)
+	}
+
+	blobpath, err := GetBlobsPath(layer.Digest)
+	if err != nil {
+		return nil, err
+	}
+
+	bin, err := os.Open(blobpath)
+	if err != nil {
+		return nil, err
+	}
+	defer bin.Close()
+
+	ggml, _, err := llm.DecodeGGML(bin)
+	if err != nil {
+		return nil, err
+	}
+
+	layer, err = NewLayerFromLayer(layer.Digest, layer.MediaType, "")
+	if err != nil {
+		return nil, err
+	}
+
+	layers = append(layers, &layerWithGGML{layer, ggml})
+	return layers, nil
+}
+
+func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
+	sr := io.NewSectionReader(file, 0, 512)
+	contentType, err := detectContentType(sr)
+	if err != nil {
+		return nil, err
+	}
+
+	switch contentType {
+	case "gguf", "ggla":
+		// noop
+	case "application/zip":
+		return parseFromZipFile(ctx, file, fn)
+	default:
+		return nil, fmt.Errorf("unsupported content type: %s", contentType)
+	}
+
+	stat, err := file.Stat()
+	if err != nil {
+		return nil, err
+	}
+
+	var offset int64
+	for offset < stat.Size() {
+		ggml, n, err := llm.DecodeGGML(file)
+		if errors.Is(err, io.EOF) {
+			break
+		} else if err != nil {
+			return nil, err
+		}
+
+		mediatype := "application/vnd.ollama.image.model"
+		if ggml.Name() == "ggla" {
+			mediatype = "application/vnd.ollama.image.adapter"
+		} else if ggml.KV().Architecture() == "clip" {
+			mediatype = "application/vnd.ollama.image.projector"
+		}
+
+		layer, err := NewLayer(io.NewSectionReader(file, offset, n), mediatype)
+		if err != nil {
+			return nil, err
+		}
+
+		layers = append(layers, &layerWithGGML{layer, ggml})
+		offset = n
+	}
+
+	return layers, nil
+}
+
+func detectContentType(r io.Reader) (string, error) {
+	var b bytes.Buffer
+	if _, err := io.Copy(&b, r); err != nil {
+		return "", err
+	}
+
+	if contentType := llm.DetectGGMLType(b.Bytes()); contentType != "" {
+		return contentType, nil
+	}
+
+	if contentType := http.DetectContentType(b.Bytes()); contentType != "application/octet-stream" {
+		return contentType, nil
+	}
+
+	return "unknown", nil
+}

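detectContentType tries the GGML-family magics first and only then falls back to http.DetectContentType, which is what routes a zipped safetensors archive to parseFromZipFile. The fallback relies on stdlib content sniffing:

```go
package main

import (
	"fmt"
	"net/http"
)

func main() {
	// A zip archive starts with the "PK\x03\x04" local-file-header
	// magic; stdlib sniffing maps that to application/zip, which
	// parseFromFile dispatches to parseFromZipFile.
	fmt.Println(http.DetectContentType([]byte{'P', 'K', 0x03, 0x04}))
}
```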
+ 1 - 6
server/routes.go

@@ -560,7 +560,7 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
 		ctx, cancel := context.WithCancel(c.Request.Context())
 		defer cancel()
 
-		if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), req.Quantization, modelfile, fn); err != nil {
+		if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), strings.ToUpper(req.Quantization), modelfile, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
@@ -852,11 +852,6 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
 		return
 	}
 
-	if _, err := layer.Commit(); err != nil {
-		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
 	c.Status(http.StatusCreated)
 }
 

+ 20 - 24
server/routes_test.go

@@ -124,14 +124,12 @@ func Test_Routes(t *testing.T) {
 			Method: http.MethodPost,
 			Path:   "/api/create",
 			Setup: func(t *testing.T, req *http.Request) {
-				f, err := os.CreateTemp(t.TempDir(), "ollama-model")
-				assert.Nil(t, err)
-				defer f.Close()
+				fname := createTestFile(t, "ollama-model")
 
 				stream := false
 				createReq := api.CreateRequest{
 					Name:      "t-bone",
-					Modelfile: fmt.Sprintf("FROM %s", f.Name()),
+					Modelfile: fmt.Sprintf("FROM %s", fname),
 					Stream:    &stream,
 				}
 				jsonData, err := json.Marshal(createReq)
@@ -216,27 +214,25 @@ func Test_Routes(t *testing.T) {
 	httpSrv := httptest.NewServer(router)
 	t.Cleanup(httpSrv.Close)
 
-	workDir, err := os.MkdirTemp("", "ollama-test")
-	assert.Nil(t, err)
-	defer os.RemoveAll(workDir)
-	os.Setenv("OLLAMA_MODELS", workDir)
+	t.Setenv("OLLAMA_MODELS", t.TempDir())
 
 	for _, tc := range testCases {
-		t.Logf("Running Test: [%s]", tc.Name)
-		u := httpSrv.URL + tc.Path
-		req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
-		assert.Nil(t, err)
-
-		if tc.Setup != nil {
-			tc.Setup(t, req)
-		}
-
-		resp, err := httpSrv.Client().Do(req)
-		assert.Nil(t, err)
-		defer resp.Body.Close()
-
-		if tc.Expected != nil {
-			tc.Expected(t, resp)
-		}
+		t.Run(tc.Name, func(t *testing.T) {
+			u := httpSrv.URL + tc.Path
+			req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
+			assert.Nil(t, err)
+
+			if tc.Setup != nil {
+				tc.Setup(t, req)
+			}
+
+			resp, err := httpSrv.Client().Do(req)
+			assert.Nil(t, err)
+			defer resp.Body.Close()
+
+			if tc.Expected != nil {
+				tc.Expected(t, resp)
+			}
+		})
 	}
 }
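The test refactor leans on testing helpers that clean up after themselves: t.Run gives each case a named, individually reportable subtest, while t.Setenv and t.TempDir undo their changes automatically, replacing the manual os.Setenv and deferred os.RemoveAll bookkeeping. A minimal sketch of the pattern:

```go
package server

import "testing"

func TestIsolatedModels(t *testing.T) {
	// Restored and removed automatically when the test finishes; no
	// defer os.RemoveAll or manual env restoration needed.
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	t.Run("create", func(t *testing.T) {
		// each case runs with the isolated OLLAMA_MODELS value
	})
}
```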