Browse Source

embed text document in modelfile

Bruce MacDonald 1 year ago
parent
commit
a6f6d18f83
8 changed files with 330 additions and 59 deletions
  1. 12 6
      cmd/cmd.go
  2. 1 0
      go.mod
  3. 2 0
      go.sum
  4. 37 0
      llama/llama.go
  5. 1 1
      parser/parser.go
  6. 200 50
      server/images.go
  7. 8 2
      server/routes.go
  8. 69 0
      vector/store.go

+ 12 - 6
cmd/cmd.go

@@ -48,12 +48,18 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 				spinner.Stop()
 			}
 			currentDigest = resp.Digest
-			bar = progressbar.DefaultBytes(
-				int64(resp.Total),
-				fmt.Sprintf("pulling %s...", resp.Digest[7:19]),
-			)
-
-			bar.Set(resp.Completed)
+			switch {
+			case strings.Contains(resp.Status, "embeddings"):
+				bar = progressbar.Default(int64(resp.Total), resp.Status)
+				bar.Set(resp.Completed)
+			default:
+				// pulling
+				bar = progressbar.DefaultBytes(
+					int64(resp.Total),
+					resp.Status,
+				)
+				bar.Set(resp.Completed)
+			}
 		} else if resp.Digest == currentDigest && resp.Digest != "" {
 			bar.Set(resp.Completed)
 		} else {

+ 1 - 0
go.mod

@@ -42,6 +42,7 @@ require (
 	golang.org/x/sys v0.10.0 // indirect
 	golang.org/x/term v0.10.0
 	golang.org/x/text v0.10.0 // indirect
+	gonum.org/v1/gonum v0.13.0
 	google.golang.org/protobuf v1.30.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )

+ 2 - 0
go.sum

@@ -139,6 +139,8 @@ golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58=
 golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+gonum.org/v1/gonum v0.13.0 h1:a0T3bh+7fhRyqeNbiC3qVHYmkiQgit3wnNan/2c0HMM=
+gonum.org/v1/gonum v0.13.0/go.mod h1:/WPYRckkfWrhWefxyYTfrTtQR0KH4iyHNuzxqXAKyAU=
 google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
 google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
 google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=

+ 37 - 0
llama/llama.go

@@ -85,6 +85,7 @@ llama_token llama_sample(
 }
 */
 import "C"
+
 import (
 	"bytes"
 	"embed"
@@ -93,6 +94,7 @@ import (
 	"io"
 	"log"
 	"os"
+	"reflect"
 	"strings"
 	"sync"
 	"unicode/utf8"
@@ -414,3 +416,38 @@ func (llm *LLM) next() (C.llama_token, error) {
 
 	return token, nil
 }
+
+func (llm *LLM) Embedding(input string) ([]float64, error) {
+	if !llm.EmbeddingOnly {
+		return nil, errors.New("llama: embedding not enabled")
+	}
+
+	tokens := llm.tokenize(input)
+	if tokens == nil {
+		return nil, errors.New("llama: tokenize embedding")
+	}
+
+	retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread))
+	if retval != 0 {
+		return nil, errors.New("llama: eval")
+	}
+
+	n := int(C.llama_n_embd(llm.ctx))
+	if n <= 0 {
+		return nil, errors.New("llama: no embeddings generated")
+	}
+
+	embedPtr := C.llama_get_embeddings(llm.ctx)
+	if embedPtr == nil {
+		return nil, errors.New("llama: embedding retrieval failed")
+	}
+
+	header := reflect.SliceHeader{
+		Data: uintptr(unsafe.Pointer(embedPtr)),
+		Len:  n,
+		Cap:  n,
+	}
+	embedSlice := *(*[]float64)(unsafe.Pointer(&header))
+
+	return embedSlice, nil
+}

+ 1 - 1
parser/parser.go

@@ -40,7 +40,7 @@ func Parse(reader io.Reader) ([]Command, error) {
 			command.Args = string(fields[1])
 			// copy command for validation
 			modelCommand = command
-		case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
+		case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT", "EMBED":
 			command.Name = string(bytes.ToLower(fields[0]))
 			command.Args = string(fields[1])
 		case "PARAMETER":

+ 200 - 50
server/images.go

@@ -1,6 +1,7 @@
 package server
 
 import (
+	"bufio"
 	"bytes"
 	"crypto/sha256"
 	"encoding/json"
@@ -9,6 +10,7 @@ import (
 	"html/template"
 	"io"
 	"log"
+	"math"
 	"net/http"
 	"os"
 	"path"
@@ -18,7 +20,10 @@ import (
 	"strings"
 
 	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/llama"
 	"github.com/jmorganca/ollama/parser"
+	"github.com/jmorganca/ollama/vector"
+	"gonum.org/v1/gonum/mat"
 )
 
 type RegistryOptions struct {
@@ -28,12 +33,13 @@ type RegistryOptions struct {
 }
 
 type Model struct {
-	Name      string `json:"name"`
-	ModelPath string
-	Template  string
-	System    string
-	Digest    string
-	Options   map[string]interface{}
+	Name       string `json:"name"`
+	ModelPath  string
+	Template   string
+	System     string
+	Digest     string
+	Options    map[string]interface{}
+	Embeddings []vector.Embedding
 }
 
 func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
@@ -51,6 +57,7 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
 		First  bool
 		System string
 		Prompt string
+		Embed  string
 
 		// deprecated: versions <= 0.0.7 used this to omit the system prompt
 		Context []int
@@ -65,6 +72,21 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
 		vars.System = request.System
 	}
 
+	if len(m.Embeddings) > 0 {
+		promptEmbed, err := loaded.llm.Embedding(request.Prompt)
+		if err != nil {
+			return "", fmt.Errorf("failed to get embedding for prompt: %v", err)
+		}
+		// TODO: set embed_top from specified parameters in modelfile
+		embed_top := 3
+		embed := vector.TopK(embed_top, mat.NewVecDense(len(promptEmbed), promptEmbed), loaded.Embeddings)
+		toEmbed := ""
+		for _, e := range embed {
+			toEmbed = fmt.Sprintf("%s %s", toEmbed, e.Embedding.Data)
+		}
+		vars.Embed = toEmbed
+	}
+
 	var sb strings.Builder
 	if err := tmpl.Execute(&sb, vars); err != nil {
 		return "", err
@@ -157,6 +179,16 @@ func GetModel(name string) (*Model, error) {
 		switch layer.MediaType {
 		case "application/vnd.ollama.image.model":
 			model.ModelPath = filename
+		case "application/vnd.ollama.image.embed":
+			file, err := os.Open(filename)
+			if err != nil {
+				return nil, fmt.Errorf("failed to open file: %s", filename)
+			}
+			defer file.Close()
+
+			if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil {
+				return nil, err
+			}
 		case "application/vnd.ollama.image.template":
 			bts, err := os.ReadFile(filename)
 			if err != nil {
@@ -195,6 +227,26 @@ func GetModel(name string) (*Model, error) {
 	return model, nil
 }
 
+func filenameWithPath(path, f string) (string, error) {
+	// if filePath starts with ~/, replace it with the user's home directory.
+	if strings.HasPrefix(f, "~/") {
+		parts := strings.Split(f, "/")
+		home, err := os.UserHomeDir()
+		if err != nil {
+			return "", fmt.Errorf("failed to open file: %v", err)
+		}
+
+		f = filepath.Join(home, filepath.Join(parts[1:]...))
+	}
+
+	// if filePath is not an absolute path, make it relative to the modelfile path
+	if !filepath.IsAbs(f) {
+		f = filepath.Join(filepath.Dir(path), f)
+	}
+
+	return f, nil
+}
+
 func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) error {
 	mf, err := os.Open(path)
 	if err != nil {
@@ -211,52 +263,37 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
 
 	var layers []*LayerReader
 	params := make(map[string][]string)
-
+	embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()}
 	for _, c := range commands {
 		log.Printf("[%s] - %s\n", c.Name, c.Args)
 		switch c.Name {
 		case "model":
 			fn(api.ProgressResponse{Status: "looking for model"})
+			embed.model = c.Args
 			mf, err := GetManifest(ParseModelPath(c.Args))
 			if err != nil {
-				fp := c.Args
-
-				// If filePath starts with ~/, replace it with the user's home directory.
-				if strings.HasPrefix(fp, "~/") {
-					parts := strings.Split(fp, "/")
-					home, err := os.UserHomeDir()
-					if err != nil {
-						return fmt.Errorf("failed to open file: %v", err)
-					}
-
-					fp = filepath.Join(home, filepath.Join(parts[1:]...))
-				}
-
-				// If filePath is not an absolute path, make it relative to the modelfile path
-				if !filepath.IsAbs(fp) {
-					fp = filepath.Join(filepath.Dir(path), fp)
+				modelFile, err := filenameWithPath(path, c.Args)
+				if err != nil {
+					return err
 				}
-
-				if _, err := os.Stat(fp); err != nil {
+				if _, err := os.Stat(modelFile); err != nil {
 					// the model file does not exist, try pulling it
 					if errors.Is(err, os.ErrNotExist) {
 						fn(api.ProgressResponse{Status: "pulling model file"})
 						if err := PullModel(c.Args, &RegistryOptions{}, fn); err != nil {
 							return err
 						}
-						mf, err = GetManifest(ParseModelPath(c.Args))
+						mf, err = GetManifest(ParseModelPath(modelFile))
 						if err != nil {
 							return fmt.Errorf("failed to open file after pull: %v", err)
 						}
-
 					} else {
 						return err
 					}
 				} else {
 					// create a model from this specified file
 					fn(api.ProgressResponse{Status: "creating model layer"})
-
-					file, err := os.Open(fp)
+					file, err := os.Open(modelFile)
 					if err != nil {
 						return fmt.Errorf("failed to open file: %v", err)
 					}
@@ -280,19 +317,14 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
 					layers = append(layers, newLayer)
 				}
 			}
-		case "license":
-			fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
-			// remove the prompt layer if one exists
-			mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
-
-			layer, err := CreateLayer(strings.NewReader(c.Args))
+		case "embed":
+			// TODO: support entire directories here
+			embedFilePath, err := filenameWithPath(path, c.Args)
 			if err != nil {
 				return err
 			}
-
-			layer.MediaType = mediaType
-			layers = append(layers, layer)
-		case "template", "system", "prompt":
+			embed.files = append(embed.files, embedFilePath)
+		case "license", "template", "system", "prompt":
 			fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
 			// remove the prompt layer if one exists
 			mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
@@ -315,18 +347,35 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
 	if len(params) > 0 {
 		fn(api.ProgressResponse{Status: "creating parameter layer"})
 		layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
-		paramData, err := paramsToReader(params)
+		formattedParams, err := formatParams(params)
 		if err != nil {
 			return fmt.Errorf("couldn't create params json: %v", err)
 		}
-		l, err := CreateLayer(paramData)
+
+		bts, err := json.Marshal(formattedParams)
+		if err != nil {
+			return err
+		}
+
+		l, err := CreateLayer(bytes.NewReader(bts))
 		if err != nil {
 			return fmt.Errorf("failed to create layer: %v", err)
 		}
 		l.MediaType = "application/vnd.ollama.image.params"
 		layers = append(layers, l)
+
+		// apply these parameters to the embedding options, in case embeddings need to be generated using this model
+		embed.opts = api.DefaultOptions()
+		embed.opts.FromMap(formattedParams)
 	}
 
+	// generate the embedding layers
+	embeddingLayers, err := embeddingLayers(embed)
+	if err != nil {
+		return err
+	}
+	layers = append(layers, embeddingLayers...)
+
 	digests, err := getLayerDigests(layers)
 	if err != nil {
 		return err
@@ -361,6 +410,112 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
 	return nil
 }
 
+type EmbeddingParams struct {
+	model string
+	opts  api.Options
+	files []string // paths to files to embed
+	fn    func(resp api.ProgressResponse)
+}
+
+// embeddingLayers loads the associated LLM and generates the embeddings to be stored from an input file
+func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
+	layers := []*LayerReader{}
+	if len(e.files) > 0 {
+		model, err := GetModel(e.model)
+		if err != nil {
+			return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err)
+		}
+
+		e.opts.EmbeddingOnly = true
+		llm, err := llama.New(model.ModelPath, e.opts)
+		if err != nil {
+			return nil, fmt.Errorf("load model to generate embeddings: %v", err)
+		}
+
+		for _, filePath := range e.files {
+			// TODO: check if txt file type
+			f, err := os.Open(filePath)
+			if err != nil {
+				return nil, fmt.Errorf("could not open embed file: %w", err)
+			}
+			scanner := bufio.NewScanner(f)
+			scanner.Split(bufio.ScanLines)
+
+			data := []string{}
+			for scanner.Scan() {
+				data = append(data, scanner.Text())
+			}
+			f.Close()
+
+			// the digest of the file is set here so that the client knows a new operation is in progress
+			fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
+
+			embeddings := []vector.Embedding{}
+			for i, d := range data {
+				if strings.TrimSpace(d) == "" {
+					continue
+				}
+				e.fn(api.ProgressResponse{
+					Status:    fmt.Sprintf("creating embeddings for file %s", filePath),
+					Digest:    fileDigest,
+					Total:     len(data) - 1,
+					Completed: i,
+				})
+				retry := 0
+			generate:
+				if retry > 3 {
+					log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
+					continue
+				}
+				embed, err := llm.Embedding(d)
+				if err != nil {
+					log.Printf("retrying embedding generation for '%s' line %d: %v", filePath, i+1, err)
+					retry++
+					goto generate
+				}
+				// Check for NaN and Inf in the embedding, which can't be stored
+				for _, value := range embed {
+					if math.IsNaN(value) || math.IsInf(value, 0) {
+						log.Printf("reloading model, embedding contains NaN or Inf")
+						// reload the model to get a new embedding
+						llm, err = llama.New(model.ModelPath, e.opts)
+						if err != nil {
+							return nil, fmt.Errorf("load model to generate embeddings: %v", err)
+						}
+						retry++
+						goto generate
+					}
+				}
+				embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
+			}
+
+			b, err := json.Marshal(embeddings)
+			if err != nil {
+				return nil, fmt.Errorf("failed to encode embeddings: %w", err)
+			}
+			r := bytes.NewReader(b)
+
+			digest, size := GetSHA256Digest(r)
+			// Reset the position of the reader after calculating the digest
+			if _, err := r.Seek(0, 0); err != nil {
+				return nil, fmt.Errorf("could not reset embed reader: %w", err)
+			}
+
+			layer := &LayerReader{
+				Layer: Layer{
+					MediaType: "application/vnd.ollama.image.embed",
+					Digest:    digest,
+					Size:      size,
+				},
+				Reader: r,
+			}
+
+			layers = append(layers, layer)
+		}
+	}
+	return layers, nil
+}
+
 func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
 	j := 0
 	for _, l := range layers {
@@ -449,8 +604,8 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
 	return newLayer, nil
 }
 
-// paramsToReader converts specified parameter options to their correct types, and returns a reader for the json
-func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
+// formatParams converts specified parameter options to their correct types
+func formatParams(params map[string][]string) (map[string]interface{}, error) {
 	opts := api.Options{}
 	valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
 	typeOpts := reflect.TypeOf(opts)           // types of the fields in the options struct
@@ -504,12 +659,7 @@ func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
 		}
 	}
 
-	bts, err := json.Marshal(out)
-	if err != nil {
-		return nil, err
-	}
-
-	return bytes.NewReader(bts), nil
+	return out, nil
 }
 
 func getLayerDigests(layers []*LayerReader) ([]string, error) {
@@ -1042,7 +1192,7 @@ func downloadBlob(mp ModelPath, digest string, regOpts *RegistryOptions, fn func
 
 	for {
 		fn(api.ProgressResponse{
-			Status:    fmt.Sprintf("downloading %s", digest),
+			Status:    fmt.Sprintf("pulling %s...", digest[7:19]),
 			Digest:    digest,
 			Total:     int(total),
 			Completed: int(completed),

+ 8 - 2
server/routes.go

@@ -20,12 +20,14 @@ import (
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/llama"
+	"github.com/jmorganca/ollama/vector"
 )
 
 var loaded struct {
 	mu sync.Mutex
 
-	llm *llama.LLM
+	llm        *llama.LLM
+	Embeddings []vector.Embedding
 
 	expireAt    time.Time
 	expireTimer *time.Timer
@@ -72,6 +74,11 @@ func GenerateHandler(c *gin.Context) {
 			loaded.digest = ""
 		}
 
+		if model.Embeddings != nil && len(model.Embeddings) > 0 {
+			opts.EmbeddingOnly = true // this is requried to generate embeddings, completions will still work
+			loaded.Embeddings = model.Embeddings
+		}
+
 		llm, err := llama.New(model.ModelPath, opts)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -82,7 +89,6 @@ func GenerateHandler(c *gin.Context) {
 		loaded.digest = model.Digest
 		loaded.options = opts
 	}
-
 	sessionDuration := 5 * time.Minute
 
 	loaded.expireAt = time.Now().Add(sessionDuration)

+ 69 - 0
vector/store.go

@@ -0,0 +1,69 @@
+package vector
+
+import (
+	"container/heap"
+	"sort"
+
+	"gonum.org/v1/gonum/mat"
+)
+
+type Embedding struct {
+	Vector []float64 // the embedding vector
+	Data   string    // the data represted by the embedding
+}
+
+type EmbeddingSimilarity struct {
+	Embedding  Embedding // the embedding that was used to calculate the similarity
+	Similarity float64   // the similarity between the embedding and the query
+}
+
+type Heap []EmbeddingSimilarity
+
+func (h Heap) Len() int           { return len(h) }
+func (h Heap) Less(i, j int) bool { return h[i].Similarity < h[j].Similarity }
+func (h Heap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
+func (h *Heap) Push(e any) {
+	*h = append(*h, e.(EmbeddingSimilarity))
+}
+
+func (h *Heap) Pop() interface{} {
+	old := *h
+	n := len(old)
+	x := old[n-1]
+	*h = old[0 : n-1]
+	return x
+}
+
+// cosineSimilarity is a measure that calculates the cosine of the angle between two vectors.
+// This value will range from -1 to 1, where 1 means the vectors are identical.
+func cosineSimilarity(vec1, vec2 *mat.VecDense) float64 {
+	dotProduct := mat.Dot(vec1, vec2)
+	norms := mat.Norm(vec1, 2) * mat.Norm(vec2, 2)
+
+	if norms == 0 {
+		return 0
+	}
+	return dotProduct / norms
+}
+
+func TopK(k int, query *mat.VecDense, embeddings []Embedding) []EmbeddingSimilarity {
+	h := &Heap{}
+	heap.Init(h)
+	for _, emb := range embeddings {
+		similarity := cosineSimilarity(query, mat.NewVecDense(len(emb.Vector), emb.Vector))
+		heap.Push(h, EmbeddingSimilarity{Embedding: emb, Similarity: similarity})
+		if h.Len() > k {
+			heap.Pop(h)
+		}
+	}
+
+	topK := make([]EmbeddingSimilarity, 0, h.Len())
+	for h.Len() > 0 {
+		topK = append(topK, heap.Pop(h).(EmbeddingSimilarity))
+	}
+	sort.Slice(topK, func(i, j int) bool {
+		return topK[i].Similarity > topK[j].Similarity
+	})
+
+	return topK
+}