Просмотр исходного кода

allow embedding from model binary

Bruce MacDonald 1 год назад
Родитель
Коммит
884d78ceb3
1 изменённый файл: 16 добавлено и 6 удалено
  1. 16 6
      server/images.go

+ 16 - 6
server/images.go

@@ -268,7 +268,7 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
 						if err := PullModel(c.Args, &RegistryOptions{}, fn); err != nil {
 							return err
 						}
-						mf, err = GetManifest(ParseModelPath(modelFile))
+						mf, err = GetManifest(ParseModelPath(c.Args))
 						if err != nil {
 							return fmt.Errorf("failed to open file after pull: %v", err)
 						}
@@ -354,6 +354,8 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
 		embed.opts.FromMap(formattedParams)
 	}
 
+	fmt.Println(embed.model)
+
 	// generate the embedding layers
 	embeddingLayers, err := embeddingLayers(embed)
 	if err != nil {
@@ -406,13 +408,21 @@ type EmbeddingParams struct {
 func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
 	layers := []*LayerReader{}
 	if len(e.files) > 0 {
-		model, err := GetModel(e.model)
-		if err != nil {
-			return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err)
+		if _, err := os.Stat(e.model); err != nil {
+			if os.IsNotExist(err) {
+				// this is a model name rather than the file
+				model, err := GetModel(e.model)
+				if err != nil {
+					return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err)
+				}
+				e.model = model.ModelPath
+			} else {
+				return nil, fmt.Errorf("failed to get model file to generate embeddings: %v", err)
+			}
 		}
 
 		e.opts.EmbeddingOnly = true
-		llm, err := llama.New(model.ModelPath, e.opts)
+		llm, err := llama.New(e.model, e.opts)
 		if err != nil {
 			return nil, fmt.Errorf("load model to generate embeddings: %v", err)
 		}
@@ -475,7 +485,7 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
 							log.Printf("reloading model, embedding contains NaN or Inf")
 							// reload the model to get a new embedding, the seed can affect these outputs and reloading changes it
 							llm.Close()
-							llm, err = llama.New(model.ModelPath, e.opts)
+							llm, err = llama.New(e.model, e.opts)
 							if err != nil {
 								return nil, fmt.Errorf("load model to generate embeddings: %v", err)
 							}