浏览代码

pr comments

- default to embeddings enabled
- move embedding logic for loaded model to request
- allow embedding full directory
- close llm on reload
Bruce MacDonald 1 年之前
父节点
当前提交
21ddcaa1f1
共有 3 个文件被更改,包括 97 次插入82 次删除
  1. 1 0
      api/types.go
  2. 79 81
      server/images.go
  3. 17 1
      server/routes.go

+ 1 - 0
api/types.go

@@ -275,6 +275,7 @@ func DefaultOptions() Options {
 		UseMLock:           false,
 		RopeFrequencyBase:  10000.0,
 		RopeFrequencyScale: 1.0,
+		EmbeddingOnly:      true,
 
 		RepeatLastN:      64,
 		RepeatPenalty:    1.1,

+ 79 - 81
server/images.go

@@ -23,7 +23,6 @@ import (
 	"github.com/jmorganca/ollama/llama"
 	"github.com/jmorganca/ollama/parser"
 	"github.com/jmorganca/ollama/vector"
-	"gonum.org/v1/gonum/mat"
 )
 
 type RegistryOptions struct {
@@ -42,7 +41,7 @@ type Model struct {
 	Embeddings []vector.Embedding
 }
 
-func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
+func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
 	t := m.Template
 	if request.Template != "" {
 		t = request.Template
@@ -67,26 +66,12 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
 	vars.System = m.System
 	vars.Prompt = request.Prompt
 	vars.Context = request.Context
+	vars.Embed = embedding
 
 	if request.System != "" {
 		vars.System = request.System
 	}
 
-	if len(m.Embeddings) > 0 {
-		promptEmbed, err := loaded.llm.Embedding(request.Prompt)
-		if err != nil {
-			return "", fmt.Errorf("failed to get embedding for prompt: %v", err)
-		}
-		// TODO: set embed_top from specified parameters in modelfile
-		embed_top := 3
-		embed := vector.TopK(embed_top, mat.NewVecDense(len(promptEmbed), promptEmbed), loaded.Embeddings)
-		toEmbed := ""
-		for _, e := range embed {
-			toEmbed = fmt.Sprintf("%s %s", toEmbed, e.Embedding.Data)
-		}
-		vars.Embed = toEmbed
-	}
-
 	var sb strings.Builder
 	if err := tmpl.Execute(&sb, vars); err != nil {
 		return "", err
@@ -432,85 +417,98 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
 			return nil, fmt.Errorf("load model to generate embeddings: %v", err)
 		}
 
-		for _, filePath := range e.files {
-			// TODO: check if txt file type
-			f, err := os.Open(filePath)
+		addedFiles := make(map[string]bool) // keep track of files that have already been added
+		for _, filePattern := range e.files {
+			matchingFiles, err := filepath.Glob(filePattern)
 			if err != nil {
-				return nil, fmt.Errorf("could not open embed file: %w", err)
+				return nil, fmt.Errorf("could not find files with pattern %s: %w", filePattern, err)
 			}
-			scanner := bufio.NewScanner(f)
-			scanner.Split(bufio.ScanLines)
 
-			data := []string{}
-			for scanner.Scan() {
-				data = append(data, scanner.Text())
-			}
-			f.Close()
-
-			// the digest of the file is set here so that the client knows a new operation is in progress
-			fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
-
-			embeddings := []vector.Embedding{}
-			for i, d := range data {
-				if strings.TrimSpace(d) == "" {
+			for _, filePath := range matchingFiles {
+				if addedFiles[filePath] {
 					continue
 				}
-				e.fn(api.ProgressResponse{
-					Status:    fmt.Sprintf("creating embeddings for file %s", filePath),
-					Digest:    fileDigest,
-					Total:     len(data) - 1,
-					Completed: i,
-				})
-				retry := 0
-			generate:
-				if retry > 3 {
-					log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
-					continue
-				}
-				embed, err := llm.Embedding(d)
+				addedFiles[filePath] = true
+				// TODO: check file type
+				f, err := os.Open(filePath)
 				if err != nil {
-					log.Printf("retrying embedding generation for '%s' line %d: %v", filePath, i+1, err)
-					retry++
-					goto generate
+					return nil, fmt.Errorf("could not open embed file: %w", err)
 				}
-				// Check for NaN and Inf in the embedding, which can't be stored
-				for _, value := range embed {
-					if math.IsNaN(value) || math.IsInf(value, 0) {
-						log.Printf("reloading model, embedding contains NaN or Inf")
-						// reload the model to get a new embedding
-						llm, err = llama.New(model.ModelPath, e.opts)
-						if err != nil {
-							return nil, fmt.Errorf("load model to generate embeddings: %v", err)
-						}
+				scanner := bufio.NewScanner(f)
+				scanner.Split(bufio.ScanLines)
+
+				data := []string{}
+				for scanner.Scan() {
+					data = append(data, scanner.Text())
+				}
+				f.Close()
+
+				// the digest of the file is set here so that the client knows a new operation is in progress
+				fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
+
+				embeddings := []vector.Embedding{}
+				for i, d := range data {
+					if strings.TrimSpace(d) == "" {
+						continue
+					}
+					e.fn(api.ProgressResponse{
+						Status:    fmt.Sprintf("creating embeddings for file %s", filePath),
+						Digest:    fileDigest,
+						Total:     len(data) - 1,
+						Completed: i,
+					})
+					retry := 0
+				generate:
+					if retry > 3 {
+						log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
+						continue
+					}
+					embed, err := llm.Embedding(d)
+					if err != nil {
+						log.Printf("retrying embedding generation for '%s' line %d: %v", filePath, i+1, err)
 						retry++
 						goto generate
 					}
+					// Check for NaN and Inf in the embedding, which can't be stored
+					for _, value := range embed {
+						if math.IsNaN(value) || math.IsInf(value, 0) {
+							log.Printf("reloading model, embedding contains NaN or Inf")
+							// reload the model to get a new embedding, the seed can effect these outputs and reloading changes it
+							llm.Close()
+							llm, err = llama.New(model.ModelPath, e.opts)
+							if err != nil {
+								return nil, fmt.Errorf("load model to generate embeddings: %v", err)
+							}
+							retry++
+							goto generate
+						}
+					}
+					embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
 				}
-				embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
-			}
 
-			b, err := json.Marshal(embeddings)
-			if err != nil {
-				return nil, fmt.Errorf("failed to encode embeddings: %w", err)
-			}
-			r := bytes.NewReader(b)
+				b, err := json.Marshal(embeddings)
+				if err != nil {
+					return nil, fmt.Errorf("failed to encode embeddings: %w", err)
+				}
+				r := bytes.NewReader(b)
 
-			digest, size := GetSHA256Digest(r)
-			// Reset the position of the reader after calculating the digest
-			if _, err := r.Seek(0, 0); err != nil {
-				return nil, fmt.Errorf("could not reset embed reader: %w", err)
-			}
+				digest, size := GetSHA256Digest(r)
+				// Reset the position of the reader after calculating the digest
+				if _, err := r.Seek(0, io.SeekStart); err != nil {
+					return nil, fmt.Errorf("could not reset embed reader: %w", err)
+				}
 
-			layer := &LayerReader{
-				Layer: Layer{
-					MediaType: "application/vnd.ollama.image.embed",
-					Digest:    digest,
-					Size:      size,
-				},
-				Reader: r,
-			}
+				layer := &LayerReader{
+					Layer: Layer{
+						MediaType: "application/vnd.ollama.image.embed",
+						Digest:    digest,
+						Size:      size,
+					},
+					Reader: r,
+				}
 
-			layers = append(layers, layer)
+				layers = append(layers, layer)
+			}
 		}
 	}
 	return layers, nil

+ 17 - 1
server/routes.go

@@ -17,6 +17,7 @@ import (
 
 	"github.com/gin-contrib/cors"
 	"github.com/gin-gonic/gin"
+	"gonum.org/v1/gonum/mat"
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/llama"
@@ -114,7 +115,22 @@ func GenerateHandler(c *gin.Context) {
 
 	checkpointLoaded := time.Now()
 
-	prompt, err := model.Prompt(req)
+	embedding := ""
+	if model.Embeddings != nil && len(model.Embeddings) > 0 {
+		promptEmbed, err := loaded.llm.Embedding(req.Prompt)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+		// TODO: set embed_top from specified parameters in modelfile
+		embed_top := 3
+		topK := vector.TopK(embed_top, mat.NewVecDense(len(promptEmbed), promptEmbed), loaded.Embeddings)
+		for _, e := range topK {
+			embedding = fmt.Sprintf("%s %s", embedding, e.Embedding.Data)
+		}
+	}
+
+	prompt, err := model.Prompt(req, embedding)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return