
use loaded llm for embeddings

Bruce MacDonald, 1 year ago
parent commit 326de48930
2 files changed, 17 additions and 25 deletions
  1. server/images.go (+14, -24)
  2. server/routes.go (+3, -1)

server/images.go (+14, -24)

@@ -263,7 +263,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 
 	var layers []*LayerReader
 	params := make(map[string][]string)
-	embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()}
+	embed := EmbeddingParams{fn: fn}
 	for _, c := range commands {
 		log.Printf("[%s] - %s\n", c.Name, c.Args)
 		switch c.Name {
@@ -291,6 +291,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 						return err
 					}
 				} else {
+					embed.model = modelFile
 					// create a model from this specified file
 					fn(api.ProgressResponse{Status: "creating model layer"})
 					file, err := os.Open(modelFile)
@@ -422,8 +423,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 		layers = append(layers, l)
 
 		// apply these parameters to the embedding options, in case embeddings need to be generated using this model
-		embed.opts = api.DefaultOptions()
-		embed.opts.FromMap(formattedParams)
+		embed.opts = formattedParams
 	}
 
 	// generate the embedding layers
@@ -469,7 +469,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 
 type EmbeddingParams struct {
 	model string
-	opts  api.Options
+	opts  map[string]interface{}
 	files []string // paths to files to embed
 	fn    func(resp api.ProgressResponse)
 }
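
With this change EmbeddingParams carries the raw parameter map instead of a resolved api.Options; resolution now happens inside load, which (per the routes.go context below) starts from api.DefaultOptions() and overlays the request map, much as the removed FromMap call did here. A toy sketch of that overlay pattern, using a hypothetical trimmed options struct and made-up option names rather than the real api.Options:

package main

import "fmt"

// options is a toy stand-in for api.Options.
type options struct {
	NumCtx      int
	Temperature float64
}

func defaultOptions() options {
	return options{NumCtx: 2048, Temperature: 0.8}
}

// fromMap overlays user-supplied values onto the defaults, loosely mirroring
// what load does with api.DefaultOptions() and FromMap(reqOpts).
func (o *options) fromMap(m map[string]interface{}) {
	if v, ok := m["num_ctx"].(int); ok {
		o.NumCtx = v
	}
	if v, ok := m["temperature"].(float64); ok {
		o.Temperature = v
	}
}

func main() {
	opts := defaultOptions()
	opts.fromMap(map[string]interface{}{"temperature": 0.2})
	fmt.Printf("%+v\n", opts) // {NumCtx:2048 Temperature:0.2}
}
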
@@ -478,32 +478,22 @@ type EmbeddingParams struct {
 func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
 	layers := []*LayerReader{}
 	if len(e.files) > 0 {
-		if _, err := os.Stat(e.model); err != nil {
-			if os.IsNotExist(err) {
-				// this is a model name rather than the file
-				model, err := GetModel(e.model)
-				if err != nil {
-					return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err)
-				}
-				e.model = model.ModelPath
-			} else {
-				return nil, fmt.Errorf("failed to get model file to generate embeddings: %v", err)
+		// check if the model is a file path or a model name
+		model, err := GetModel(e.model)
+		if err != nil {
+			if !strings.Contains(err.Error(), "couldn't open file") {
+				return nil, fmt.Errorf("unexpected error opening model to generate embeddings: %v", err)
 			}
+			// the model may be a file path, create a model from this file
+			model = &Model{ModelPath: e.model}
 		}
 
-		e.opts.EmbeddingOnly = true
-		llmModel, err := llm.New(e.model, []string{}, e.opts)
-		if err != nil {
+		if err := load(model, e.opts, defaultSessionDuration); err != nil {
 			return nil, fmt.Errorf("load model to generate embeddings: %v", err)
 		}
-		defer func() {
-			if llmModel != nil {
-				llmModel.Close()
-			}
-		}()
 
 		// this will be used to check if we already have embeddings for a file
-		modelInfo, err := os.Stat(e.model)
+		modelInfo, err := os.Stat(model.ModelPath)
 		if err != nil {
 			return nil, fmt.Errorf("failed to get model file info: %v", err)
 		}
@@ -561,7 +551,7 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
 					embeddings = append(embeddings, vector.Embedding{Data: d, Vector: existing[d]})
 					continue
 				}
-				embed, err := llmModel.Embedding(d)
+				embed, err := loaded.llm.Embedding(d)
 				if err != nil {
 					log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
 					continue
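
Taken together, the images.go changes mean embeddingLayers no longer builds a private llm.New instance: it resolves the model by name via GetModel, falls back to treating the argument as a file path, pushes it through the shared load path, and then calls loaded.llm.Embedding for each line. A minimal, self-contained sketch of that flow, assuming trimmed stand-ins for the server's Model, loaded state, GetModel, and load (none of this is the repo's actual code):

package main

import (
	"errors"
	"fmt"
	"time"
)

// Model and loaded are trimmed, hypothetical versions of the server's types;
// only what is needed to show the flow is kept.
type Model struct{ ModelPath string }

type embedder interface {
	Embedding(input string) ([]float64, error)
}

var loaded struct {
	llm embedder
}

var defaultSessionDuration = 5 * time.Minute

// stubLLM stands in for the runner that load would normally start.
type stubLLM struct{}

func (stubLLM) Embedding(input string) ([]float64, error) {
	return []float64{float64(len(input))}, nil // placeholder vector
}

// getModel and load are placeholders for GetModel and load in the repo.
func getModel(name string) (*Model, error) {
	return nil, errors.New("couldn't open file") // pretend the name is unknown
}

func load(model *Model, opts map[string]interface{}, d time.Duration) error {
	loaded.llm = stubLLM{} // the real load reuses an already-running llm when possible
	return nil
}

// embedLines mirrors the new shape of embeddingLayers: resolve the model by
// name, fall back to treating it as a file path, load it through the shared
// path, then reuse loaded.llm for every embedding.
func embedLines(nameOrPath string, opts map[string]interface{}, lines []string) ([][]float64, error) {
	model, err := getModel(nameOrPath)
	if err != nil {
		// not a known model name; assume it is a path to a model file
		model = &Model{ModelPath: nameOrPath}
	}
	if err := load(model, opts, defaultSessionDuration); err != nil {
		return nil, fmt.Errorf("load model to generate embeddings: %v", err)
	}
	vectors := make([][]float64, 0, len(lines))
	for _, l := range lines {
		v, err := loaded.llm.Embedding(l)
		if err != nil {
			continue // the real code logs and skips lines that fail
		}
		vectors = append(vectors, v)
	}
	return vectors, nil
}

func main() {
	vs, _ := embedLines("./model.bin", nil, []string{"hello", "world"})
	fmt.Println(len(vs), "embeddings generated")
}
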

server/routes.go (+3, -1)

@@ -38,6 +38,8 @@ var loaded struct {
 	options api.Options
 }
 
+var defaultSessionDuration = 5 * time.Minute
+
 // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
 func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error {
 	opts := api.DefaultOptions()
@@ -134,7 +136,7 @@ func GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	sessionDuration := 5 * time.Minute
+	sessionDuration := defaultSessionDuration // TODO: set this duration from the request if specified
 	if err := load(model, req.Options, sessionDuration); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
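
On the routes.go side, the hard-coded five-minute keep-alive is hoisted into a package-level defaultSessionDuration so the new embedding path can reuse it, and the TODO notes that the request may eventually override it. A hypothetical sketch of what that override could look like, assuming a KeepAlive request field that this commit does not add:

package main

import (
	"fmt"
	"time"
)

var defaultSessionDuration = 5 * time.Minute

// generateRequest is a trimmed, hypothetical request shape; the KeepAlive
// field is the assumed per-request override hinted at by the TODO, not
// something this commit introduces.
type generateRequest struct {
	KeepAlive time.Duration
}

// sessionDurationFor falls back to the shared default when the request does
// not ask for a specific keep-alive window.
func sessionDurationFor(req generateRequest) time.Duration {
	if req.KeepAlive > 0 {
		return req.KeepAlive
	}
	return defaultSessionDuration
}

func main() {
	fmt.Println(sessionDurationFor(generateRequest{}))                       // 5m0s
	fmt.Println(sessionDurationFor(generateRequest{KeepAlive: time.Minute})) // 1m0s
}
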