浏览代码

refactor create model

Michael Yang 1 年之前
父节点
当前提交
b0d14ed51c
共有 2 个文件被更改,包括 152 次插入174 次删除
  1. 137 173
      server/images.go
  2. 15 1
      server/routes.go

+ 137 - 173
server/images.go

@@ -248,200 +248,172 @@ func filenameWithPath(path, f string) (string, error) {
 	return f, nil
 }
 
-func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error {
-	mp := ParseModelPath(name)
-
-	var manifest *ManifestV2
-	var err error
-	var noprune string
-
-	// build deleteMap to prune unused layers
-	deleteMap := make(map[string]bool)
-
-	if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
-		manifest, _, err = GetManifest(mp)
-		if err != nil && !errors.Is(err, os.ErrNotExist) {
-			return err
-		}
-
-		if manifest != nil {
-			for _, l := range manifest.Layers {
-				deleteMap[l.Digest] = true
-			}
-			deleteMap[manifest.Config.Digest] = true
-		}
+func realpath(p string) string {
+	abspath, err := filepath.Abs(p)
+	if err != nil {
+		return p
 	}
 
-	mf, err := os.Open(path)
+	home, err := os.UserHomeDir()
 	if err != nil {
-		fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)})
-		return fmt.Errorf("failed to open file: %w", err)
+		return abspath
 	}
-	defer mf.Close()
 
-	fn(api.ProgressResponse{Status: "parsing modelfile"})
-	commands, err := parser.Parse(mf)
-	if err != nil {
-		return err
+	if p == "~" {
+		return home
+	} else if strings.HasPrefix(p, "~/") {
+		return filepath.Join(home, p[2:])
 	}
 
+	return abspath
+}
+
+func CreateModel(ctx context.Context, name string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
 	config := ConfigV2{
-		Architecture: "amd64",
 		OS:           "linux",
+		Architecture: "amd64",
 	}
 
+	deleteMap := make(map[string]struct{})
+
 	var layers []*LayerReader
+
 	params := make(map[string][]string)
-	var sourceParams map[string]any
+	fromParams := make(map[string]any)
+
 	for _, c := range commands {
-		log.Printf("[%s] - %s\n", c.Name, c.Args)
+		log.Printf("[%s] - %s", c.Name, c.Args)
+		mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
+
 		switch c.Name {
 		case "model":
-			fn(api.ProgressResponse{Status: "looking for model"})
-
-			mp := ParseModelPath(c.Args)
-			mf, _, err := GetManifest(mp)
+			bin, err := os.Open(realpath(c.Args))
 			if err != nil {
-				modelFile, err := filenameWithPath(path, c.Args)
-				if err != nil {
-					return err
-				}
-				if _, err := os.Stat(modelFile); err != nil {
-					// the model file does not exist, try pulling it
-					if errors.Is(err, os.ErrNotExist) {
-						fn(api.ProgressResponse{Status: "pulling model file"})
-						if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
-							return err
-						}
-						mf, _, err = GetManifest(mp)
-						if err != nil {
-							return fmt.Errorf("failed to open file after pull: %v", err)
-						}
-					} else {
+				// not a file on disk so must be a model reference
+				modelpath := ParseModelPath(c.Args)
+				manifest, _, err := GetManifest(modelpath)
+				switch {
+				case errors.Is(err, os.ErrNotExist):
+					fn(api.ProgressResponse{Status: "pulling model"})
+					if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
 						return err
 					}
-				} else {
-					// create a model from this specified file
-					fn(api.ProgressResponse{Status: "creating model layer"})
-					file, err := os.Open(modelFile)
-					if err != nil {
-						return fmt.Errorf("failed to open file: %v", err)
-					}
-					defer file.Close()
 
-					ggml, err := llm.DecodeGGML(file)
+					manifest, _, err = GetManifest(modelpath)
 					if err != nil {
 						return err
 					}
-
-					config.ModelFormat = ggml.Name()
-					config.ModelFamily = ggml.ModelFamily()
-					config.ModelType = ggml.ModelType()
-					config.FileType = ggml.FileType()
-
-					// reset the file
-					file.Seek(0, io.SeekStart)
-
-					l, err := CreateLayer(file)
-					if err != nil {
-						return fmt.Errorf("failed to create layer: %v", err)
-					}
-					l.MediaType = "application/vnd.ollama.image.model"
-					layers = append(layers, l)
+				case err != nil:
+					return err
 				}
-			}
 
-			if mf != nil {
 				fn(api.ProgressResponse{Status: "reading model metadata"})
-				sourceBlobPath, err := GetBlobsPath(mf.Config.Digest)
+				fromConfigPath, err := GetBlobsPath(manifest.Config.Digest)
 				if err != nil {
 					return err
 				}
 
-				sourceBlob, err := os.Open(sourceBlobPath)
+				fromConfigFile, err := os.Open(fromConfigPath)
 				if err != nil {
 					return err
 				}
-				defer sourceBlob.Close()
+				defer fromConfigFile.Close()
 
-				var source ConfigV2
-				if err := json.NewDecoder(sourceBlob).Decode(&source); err != nil {
+				var fromConfig ConfigV2
+				if err := json.NewDecoder(fromConfigFile).Decode(&fromConfig); err != nil {
 					return err
 				}
 
-				// copy the model metadata
-				config.ModelFamily = source.ModelFamily
-				config.ModelType = source.ModelType
-				config.ModelFormat = source.ModelFormat
-				config.FileType = source.FileType
+				config.ModelFormat = fromConfig.ModelFormat
+				config.ModelFamily = fromConfig.ModelFamily
+				config.ModelType = fromConfig.ModelType
+				config.FileType = fromConfig.FileType
 
-				for _, l := range mf.Layers {
-					if l.MediaType == "application/vnd.ollama.image.params" {
-						sourceParamsBlobPath, err := GetBlobsPath(l.Digest)
+				for _, layer := range manifest.Layers {
+					deleteMap[layer.Digest] = struct{}{}
+					if layer.MediaType == "application/vnd.ollama.image.params" {
+						fromParamsPath, err := GetBlobsPath(layer.Digest)
 						if err != nil {
 							return err
 						}
 
-						sourceParamsBlob, err := os.Open(sourceParamsBlobPath)
+						fromParamsFile, err := os.Open(fromParamsPath)
 						if err != nil {
 							return err
 						}
-						defer sourceParamsBlob.Close()
+						defer fromParamsFile.Close()
 
-						if err := json.NewDecoder(sourceParamsBlob).Decode(&sourceParams); err != nil {
+						if err := json.NewDecoder(fromParamsFile).Decode(&fromParams); err != nil {
 							return err
 						}
 					}
 
-					newLayer, err := GetLayerWithBufferFromLayer(l)
+					layer, err := GetLayerWithBufferFromLayer(layer)
 					if err != nil {
 						return err
 					}
-					newLayer.From = mp.GetShortTagname()
-					layers = append(layers, newLayer)
+
+					layer.From = modelpath.GetShortTagname()
+					layers = append(layers, layer)
 				}
+
+				deleteMap[manifest.Config.Digest] = struct{}{}
+				continue
 			}
-		case "adapter":
-			fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
+			defer bin.Close()
 
-			fp, err := filenameWithPath(path, c.Args)
+			fn(api.ProgressResponse{Status: "creating model layer"})
+			ggml, err := llm.DecodeGGML(bin)
 			if err != nil {
 				return err
 			}
 
-			// create a model from this specified file
-			fn(api.ProgressResponse{Status: "creating model layer"})
+			config.ModelFormat = ggml.Name()
+			config.ModelFamily = ggml.ModelFamily()
+			config.ModelType = ggml.ModelType()
+			config.FileType = ggml.FileType()
 
-			file, err := os.Open(fp)
+			bin.Seek(0, io.SeekStart)
+			layer, err := CreateLayer(bin)
 			if err != nil {
-				return fmt.Errorf("failed to open file: %v", err)
+				return err
 			}
-			defer file.Close()
 
-			l, err := CreateLayer(file)
+			layer.MediaType = mediatype
+			layers = append(layers, layer)
+		case "adapter":
+			fn(api.ProgressResponse{Status: "creating adapter layer"})
+			bin, err := os.Open(realpath(c.Args))
 			if err != nil {
-				return fmt.Errorf("failed to create layer: %v", err)
+				return err
 			}
-			l.MediaType = "application/vnd.ollama.image.adapter"
-			layers = append(layers, l)
-		case "license":
-			fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
-			mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
+			defer bin.Close()
 
+			layer, err := CreateLayer(bin)
+			if err != nil {
+				return err
+			}
+
+			if layer.Size > 0 {
+				layer.MediaType = mediatype
+				layers = append(layers, layer)
+			}
+		case "license":
+			fn(api.ProgressResponse{Status: "creating license layer"})
 			layer, err := CreateLayer(strings.NewReader(c.Args))
 			if err != nil {
 				return err
 			}
 
 			if layer.Size > 0 {
-				layer.MediaType = mediaType
+				layer.MediaType = mediatype
 				layers = append(layers, layer)
 			}
-		case "template", "system", "prompt":
-			fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
-			// remove the layer if one exists
-			mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
-			layers = removeLayerFromLayers(layers, mediaType)
+		case "template", "system":
+			fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)})
+
+			// remove duplicates layers
+			layers = removeLayerFromLayers(layers, mediatype)
 
 			layer, err := CreateLayer(strings.NewReader(c.Args))
 			if err != nil {
@@ -449,48 +421,47 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 			}
 
 			if layer.Size > 0 {
-				layer.MediaType = mediaType
+				layer.MediaType = mediatype
 				layers = append(layers, layer)
 			}
 		default:
-			// runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop sequences)
 			params[c.Name] = append(params[c.Name], c.Args)
 		}
 	}
 
-	// Create a single layer for the parameters
 	if len(params) > 0 {
-		fn(api.ProgressResponse{Status: "creating parameter layer"})
+		fn(api.ProgressResponse{Status: "creating parameters layer"})
 
-		layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
 		formattedParams, err := formatParams(params)
 		if err != nil {
-			return fmt.Errorf("couldn't create params json: %v", err)
+			return err
 		}
 
-		for k, v := range sourceParams {
+		for k, v := range fromParams {
 			if _, ok := formattedParams[k]; !ok {
 				formattedParams[k] = v
 			}
 		}
 
 		if config.ModelType == "65B" {
-			if numGQA, ok := formattedParams["num_gqa"].(int); ok && numGQA == 8 {
+			if gqa, ok := formattedParams["gqa"].(int); ok && gqa == 8 {
 				config.ModelType = "70B"
 			}
 		}
 
-		bts, err := json.Marshal(formattedParams)
-		if err != nil {
+		var b bytes.Buffer
+		if err := json.NewEncoder(&b).Encode(formattedParams); err != nil {
 			return err
 		}
 
-		l, err := CreateLayer(bytes.NewReader(bts))
+		fn(api.ProgressResponse{Status: "creating config layer"})
+		layer, err := CreateLayer(bytes.NewReader(b.Bytes()))
 		if err != nil {
-			return fmt.Errorf("failed to create layer: %v", err)
+			return err
 		}
-		l.MediaType = "application/vnd.ollama.image.params"
-		layers = append(layers, l)
+
+		layer.MediaType = "application/vnd.ollama.image.params"
+		layers = append(layers, layer)
 	}
 
 	digests, err := getLayerDigests(layers)
@@ -498,36 +469,31 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 		return err
 	}
 
-	var manifestLayers []*Layer
-	for _, l := range layers {
-		manifestLayers = append(manifestLayers, &l.Layer)
-		delete(deleteMap, l.Layer.Digest)
-	}
-
-	// Create a layer for the config object
-	fn(api.ProgressResponse{Status: "creating config layer"})
-	cfg, err := createConfigLayer(config, digests)
+	configLayer, err := createConfigLayer(config, digests)
 	if err != nil {
 		return err
 	}
-	layers = append(layers, cfg)
-	delete(deleteMap, cfg.Layer.Digest)
+
+	layers = append(layers, configLayer)
+	delete(deleteMap, configLayer.Digest)
 
 	if err := SaveLayers(layers, fn, false); err != nil {
 		return err
 	}
 
-	// Create the manifest
+	var contentLayers []*Layer
+	for _, layer := range layers {
+		contentLayers = append(contentLayers, &layer.Layer)
+		delete(deleteMap, layer.Digest)
+	}
+
 	fn(api.ProgressResponse{Status: "writing manifest"})
-	err = CreateManifest(name, cfg, manifestLayers)
-	if err != nil {
+	if err := CreateManifest(name, configLayer, contentLayers); err != nil {
 		return err
 	}
 
-	if noprune == "" {
-		fn(api.ProgressResponse{Status: "removing any unused layers"})
-		err = deleteUnusedLayers(nil, deleteMap, false)
-		if err != nil {
+	if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
+		if err := deleteUnusedLayers(nil, deleteMap, false); err != nil {
 			return err
 		}
 	}
@@ -739,7 +705,7 @@ func CopyModel(src, dest string) error {
 	return nil
 }
 
-func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dryRun bool) error {
+func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}, dryRun bool) error {
 	fp, err := GetManifestPath()
 	if err != nil {
 		return err
@@ -779,21 +745,19 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dry
 	}
 
 	// only delete the files which are still in the deleteMap
-	for k, v := range deleteMap {
-		if v {
-			fp, err := GetBlobsPath(k)
-			if err != nil {
-				log.Printf("couldn't get file path for '%s': %v", k, err)
+	for k := range deleteMap {
+		fp, err := GetBlobsPath(k)
+		if err != nil {
+			log.Printf("couldn't get file path for '%s': %v", k, err)
+			continue
+		}
+		if !dryRun {
+			if err := os.Remove(fp); err != nil {
+				log.Printf("couldn't remove file '%s': %v", fp, err)
 				continue
 			}
-			if !dryRun {
-				if err := os.Remove(fp); err != nil {
-					log.Printf("couldn't remove file '%s': %v", fp, err)
-					continue
-				}
-			} else {
-				log.Printf("wanted to remove: %s", fp)
-			}
+		} else {
+			log.Printf("wanted to remove: %s", fp)
 		}
 	}
 
@@ -801,7 +765,7 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dry
 }
 
 func PruneLayers() error {
-	deleteMap := make(map[string]bool)
+	deleteMap := make(map[string]struct{})
 	p, err := GetBlobsPath("")
 	if err != nil {
 		return err
@@ -818,7 +782,7 @@ func PruneLayers() error {
 		if runtime.GOOS == "windows" {
 			name = strings.ReplaceAll(name, "-", ":")
 		}
-		deleteMap[name] = true
+		deleteMap[name] = struct{}{}
 	}
 
 	log.Printf("total blobs: %d", len(deleteMap))
@@ -873,11 +837,11 @@ func DeleteModel(name string) error {
 		return err
 	}
 
-	deleteMap := make(map[string]bool)
+	deleteMap := make(map[string]struct{})
 	for _, layer := range manifest.Layers {
-		deleteMap[layer.Digest] = true
+		deleteMap[layer.Digest] = struct{}{}
 	}
-	deleteMap[manifest.Config.Digest] = true
+	deleteMap[manifest.Config.Digest] = struct{}{}
 
 	err = deleteUnusedLayers(&mp, deleteMap, false)
 	if err != nil {
@@ -1013,7 +977,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 	var noprune string
 
 	// build deleteMap to prune unused layers
-	deleteMap := make(map[string]bool)
+	deleteMap := make(map[string]struct{})
 
 	if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
 		manifest, _, err = GetManifest(mp)
@@ -1023,9 +987,9 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 
 		if manifest != nil {
 			for _, l := range manifest.Layers {
-				deleteMap[l.Digest] = true
+				deleteMap[l.Digest] = struct{}{}
 			}
-			deleteMap[manifest.Config.Digest] = true
+			deleteMap[manifest.Config.Digest] = struct{}{}
 		}
 	}
 

+ 15 - 1
server/routes.go

@@ -26,6 +26,7 @@ import (
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/llm"
+	"github.com/jmorganca/ollama/parser"
 	"github.com/jmorganca/ollama/version"
 )
 
@@ -414,6 +415,19 @@ func CreateModelHandler(c *gin.Context) {
 		return
 	}
 
+	modelfile, err := os.Open(req.Path)
+	if err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+	defer modelfile.Close()
+
+	commands, err := parser.Parse(modelfile)
+	if err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
 	ch := make(chan any)
 	go func() {
 		defer close(ch)
@@ -424,7 +438,7 @@ func CreateModelHandler(c *gin.Context) {
 		ctx, cancel := context.WithCancel(c.Request.Context())
 		defer cancel()
 
-		if err := CreateModel(ctx, req.Name, req.Path, fn); err != nil {
+		if err := CreateModel(ctx, req.Name, commands, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()