
embed text document in modelfile

Bruce MacDonald · 1 year ago · commit 7a5f3616fd
10 changed files with 371 additions and 52 deletions
  1. api/types.go (+1 -0)
  2. cmd/cmd.go (+12 -6)
  3. docs/modelfile.md (+11 -1)
  4. go.mod (+1 -0)
  5. go.sum (+2 -0)
  6. llama/llama.go (+37 -0)
  7. parser/parser.go (+1 -1)
  8. server/images.go (+212 -41)
  9. server/routes.go (+25 -3)
  10. vector/store.go (+69 -0)

+ 1 - 0
api/types.go

@@ -276,6 +276,7 @@ func DefaultOptions() Options {
 		UseMLock:           false,
 		RopeFrequencyBase:  10000.0,
 		RopeFrequencyScale: 1.0,
+		EmbeddingOnly:      true,
 
 		RepeatLastN:      64,
 		RepeatPenalty:    1.1,

+ 12 - 6
cmd/cmd.go

@@ -48,12 +48,18 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 				spinner.Stop()
 			}
 			currentDigest = resp.Digest
-			bar = progressbar.DefaultBytes(
-				int64(resp.Total),
-				fmt.Sprintf("pulling %s...", resp.Digest[7:19]),
-			)
-
-			bar.Set(resp.Completed)
+			switch {
+			case strings.Contains(resp.Status, "embeddings"):
+				bar = progressbar.Default(int64(resp.Total), resp.Status)
+				bar.Set(resp.Completed)
+			default:
+				// pulling
+				bar = progressbar.DefaultBytes(
+					int64(resp.Total),
+					resp.Status,
+				)
+				bar.Set(resp.Completed)
+			}
 		} else if resp.Digest == currentDigest && resp.Digest != "" {
 			bar.Set(resp.Completed)
 		} else {

+ 11 - 1
docs/modelfile.md

@@ -12,6 +12,7 @@ A model file is the blueprint to create and share models with Ollama.
   - [FROM (Required)](#from-required)
     - [Build from llama2](#build-from-llama2)
     - [Build from a bin file](#build-from-a-bin-file)
+  - [EMBED](#embed)
   - [PARAMETER](#parameter)
     - [Valid Parameters and Values](#valid-parameters-and-values)
   - [TEMPLATE](#template)
@@ -88,6 +89,15 @@ FROM ./ollama-model.bin
 
 This bin file location should be specified as an absolute path or relative to the Modelfile location.
 
+### EMBED
+
+The EMBED instruction is used to add embeddings of files to a model. This is useful for adding custom data that the model can reference when generating an answer.
+
+```
+FROM <model name>:<tag>
+EMBED <file path>
+```
+
 ### PARAMETER
 
 The `PARAMETER` instruction defines a parameter that can be set when the model is run.
@@ -163,4 +173,4 @@ LICENSE """
 ## Notes
 
 - the **modelfile is not case sensitive**. In the examples, we use uppercase for instructions to make it easier to distinguish it from arguments.
-- Instructions can be in any order. In the examples, we start with FROM instruction to keep it easily readable.
+- Instructions can be in any order. In the examples, we start with FROM instruction to keep it easily readable.
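
For reference, here is what a Modelfile using the new instruction could look like. The model name and file paths are illustrative; since the server expands each EMBED argument with filepath.Glob, wildcard patterns also work, and paths may be absolute, relative to the Modelfile, or prefixed with ~/:

```
FROM llama2
EMBED ./data/facts.txt
EMBED ~/notes/*.md
```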

+ 1 - 0
go.mod

@@ -42,6 +42,7 @@ require (
 	golang.org/x/sys v0.10.0 // indirect
 	golang.org/x/term v0.10.0
 	golang.org/x/text v0.10.0 // indirect
+	gonum.org/v1/gonum v0.13.0
 	google.golang.org/protobuf v1.30.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )

+ 2 - 0
go.sum

@@ -139,6 +139,8 @@ golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58=
 golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+gonum.org/v1/gonum v0.13.0 h1:a0T3bh+7fhRyqeNbiC3qVHYmkiQgit3wnNan/2c0HMM=
+gonum.org/v1/gonum v0.13.0/go.mod h1:/WPYRckkfWrhWefxyYTfrTtQR0KH4iyHNuzxqXAKyAU=
 google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
 google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
 google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=

+ 37 - 0
llama/llama.go

@@ -85,6 +85,7 @@ llama_token llama_sample(
 }
 */
 import "C"
+
 import (
 	"bytes"
 	"embed"
@@ -408,3 +409,39 @@ func (llm *LLM) next() (C.llama_token, error) {
 
 	return token, nil
 }
+
+func (llm *LLM) Embedding(input string) ([]float64, error) {
+	if !llm.EmbeddingOnly {
+		return nil, errors.New("llama: embedding not enabled")
+	}
+
+	tokens := llm.tokenize(input)
+	if tokens == nil {
+		return nil, errors.New("llama: tokenize embedding")
+	}
+
+	retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread))
+	if retval != 0 {
+		return nil, errors.New("llama: eval")
+	}
+
+	n := int(C.llama_n_embd(llm.ctx))
+	if n <= 0 {
+		return nil, errors.New("llama: no embeddings generated")
+	}
+
+	embedPtr := C.llama_get_embeddings(llm.ctx)
+	if embedPtr == nil {
+		return nil, errors.New("llama: embedding retrieval failed")
+	}
+
+	// llama_get_embeddings returns n C floats; copy them into a fresh Go
+	// float64 slice rather than aliasing memory owned by llama.cpp
+	embedSlice := unsafe.Slice(embedPtr, n)
+	embeddings := make([]float64, n)
+	for i, v := range embedSlice {
+		embeddings[i] = float64(v)
+	}
+
+	return embeddings, nil
+}
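
A minimal sketch of calling the new method directly; the model path and the surrounding program are assumptions for illustration, not part of this change:

```
package main

import (
	"log"

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/llama"
)

func main() {
	opts := api.DefaultOptions()
	opts.EmbeddingOnly = true // Embedding returns an error unless this is set

	llm, err := llama.New("/path/to/model.bin", opts) // hypothetical model path
	if err != nil {
		log.Fatal(err)
	}
	defer llm.Close()

	vec, err := llm.Embedding("the quick brown fox")
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("embedding dimension: %d", len(vec))
}
```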

+ 1 - 1
parser/parser.go

@@ -40,7 +40,7 @@ func Parse(reader io.Reader) ([]Command, error) {
 			command.Args = string(fields[1])
 			// copy command for validation
 			modelCommand = command
-		case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
+		case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT", "EMBED":
 			command.Name = string(bytes.ToLower(fields[0]))
 			command.Args = string(fields[1])
 		case "PARAMETER":

+ 212 - 41
server/images.go

@@ -1,6 +1,7 @@
 package server
 
 import (
+	"bufio"
 	"bytes"
 	"crypto/sha256"
 	"encoding/json"
@@ -9,6 +10,7 @@ import (
 	"html/template"
 	"io"
 	"log"
+	"math"
 	"net/http"
 	"os"
 	"path"
@@ -18,7 +20,9 @@ import (
 	"strings"
 
 	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/llama"
 	"github.com/jmorganca/ollama/parser"
+	"github.com/jmorganca/ollama/vector"
 )
 
 type RegistryOptions struct {
@@ -28,15 +32,16 @@ type RegistryOptions struct {
 }
 
 type Model struct {
-	Name      string `json:"name"`
-	ModelPath string
-	Template  string
-	System    string
-	Digest    string
-	Options   map[string]interface{}
+	Name       string `json:"name"`
+	ModelPath  string
+	Template   string
+	System     string
+	Digest     string
+	Options    map[string]interface{}
+	Embeddings []vector.Embedding
 }
 
-func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
+func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
 	t := m.Template
 	if request.Template != "" {
 		t = request.Template
@@ -51,6 +56,7 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
 		First  bool
 		System string
 		Prompt string
+		Embed  string
 
 		// deprecated: versions <= 0.0.7 used this to omit the system prompt
 		Context []int
@@ -60,6 +66,7 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
 	vars.System = m.System
 	vars.Prompt = request.Prompt
 	vars.Context = request.Context
+	vars.Embed = embedding
 
 	if request.System != "" {
 		vars.System = request.System
@@ -157,6 +164,16 @@ func GetModel(name string) (*Model, error) {
 		switch layer.MediaType {
 		case "application/vnd.ollama.image.model":
 			model.ModelPath = filename
+		case "application/vnd.ollama.image.embed":
+			file, err := os.Open(filename)
+			if err != nil {
+				return nil, fmt.Errorf("failed to open file: %s", filename)
+			}
+			defer file.Close()
+
+			if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil {
+				return nil, err
+			}
 		case "application/vnd.ollama.image.template":
 			bts, err := os.ReadFile(filename)
 			if err != nil {
@@ -195,6 +212,26 @@ func GetModel(name string) (*Model, error) {
 	return model, nil
 }
 
+func filenameWithPath(path, f string) (string, error) {
+	// if filePath starts with ~/, replace it with the user's home directory.
+	if strings.HasPrefix(f, "~/") {
+		parts := strings.Split(f, "/")
+		home, err := os.UserHomeDir()
+		if err != nil {
+			return "", fmt.Errorf("failed to open file: %v", err)
+		}
+
+		f = filepath.Join(home, filepath.Join(parts[1:]...))
+	}
+
+	// if filePath is not an absolute path, make it relative to the modelfile path
+	if !filepath.IsAbs(f) {
+		f = filepath.Join(filepath.Dir(path), f)
+	}
+
+	return f, nil
+}
+
 func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) error {
 	mf, err := os.Open(path)
 	if err != nil {
@@ -211,33 +248,20 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
 
 	var layers []*LayerReader
 	params := make(map[string][]string)
-
+	embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()}
 	for _, c := range commands {
 		log.Printf("[%s] - %s\n", c.Name, c.Args)
 		switch c.Name {
 		case "model":
 			fn(api.ProgressResponse{Status: "looking for model"})
+			embed.model = c.Args
 			mf, err := GetManifest(ParseModelPath(c.Args))
 			if err != nil {
-				fp := c.Args
-
-				// If filePath starts with ~/, replace it with the user's home directory.
-				if strings.HasPrefix(fp, "~/") {
-					parts := strings.Split(fp, "/")
-					home, err := os.UserHomeDir()
-					if err != nil {
-						return fmt.Errorf("failed to open file: %v", err)
-					}
-
-					fp = filepath.Join(home, filepath.Join(parts[1:]...))
-				}
-
-				// If filePath is not an absolute path, make it relative to the modelfile path
-				if !filepath.IsAbs(fp) {
-					fp = filepath.Join(filepath.Dir(path), fp)
+				modelFile, err := filenameWithPath(path, c.Args)
+				if err != nil {
+					return err
 				}
-
-				if _, err := os.Stat(fp); err != nil {
+				if _, err := os.Stat(modelFile); err != nil {
 					// the model file does not exist, try pulling it
 					if errors.Is(err, os.ErrNotExist) {
 						fn(api.ProgressResponse{Status: "pulling model file"})
@@ -248,15 +272,13 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
 						if err != nil {
 							return fmt.Errorf("failed to open file after pull: %v", err)
 						}
-
 					} else {
 						return err
 					}
 				} else {
 					// create a model from this specified file
 					fn(api.ProgressResponse{Status: "creating model layer"})
-
-					file, err := os.Open(fp)
+					file, err := os.Open(modelFile)
 					if err != nil {
 						return fmt.Errorf("failed to open file: %v", err)
 					}
@@ -280,9 +302,14 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
 					layers = append(layers, newLayer)
 				}
 			}
+		case "embed":
+			embedFilePath, err := filenameWithPath(path, c.Args)
+			if err != nil {
+				return err
+			}
+			embed.files = append(embed.files, embedFilePath)
 		case "license":
 			fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
-			// remove the prompt layer if one exists
 			mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
 
 			layer, err := CreateLayer(strings.NewReader(c.Args))
@@ -315,17 +342,34 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
 	if len(params) > 0 {
 		fn(api.ProgressResponse{Status: "creating parameter layer"})
 		layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
-		paramData, err := paramsToReader(params)
+		formattedParams, err := formatParams(params)
 		if err != nil {
 			return fmt.Errorf("couldn't create params json: %v", err)
 		}
-		l, err := CreateLayer(paramData)
+
+		bts, err := json.Marshal(formattedParams)
+		if err != nil {
+			return err
+		}
+
+		l, err := CreateLayer(bytes.NewReader(bts))
 		if err != nil {
 			return fmt.Errorf("failed to create layer: %v", err)
 		}
 		l.MediaType = "application/vnd.ollama.image.params"
 		layers = append(layers, l)
+
+		// apply these parameters to the embedding options, in case embeddings need to be generated using this model
+		embed.opts = api.DefaultOptions()
+		embed.opts.FromMap(formattedParams)
+	}
+
+	// generate the embedding layers
+	embeddingLayers, err := embeddingLayers(embed)
+	if err != nil {
+		return err
 	}
+	layers = append(layers, embeddingLayers...)
 
 	digests, err := getLayerDigests(layers)
 	if err != nil {
@@ -361,6 +405,138 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
 	return nil
 }
 
+type EmbeddingParams struct {
+	model string
+	opts  api.Options
+	files []string // paths to files to embed
+	fn    func(resp api.ProgressResponse)
+}
+
+// embeddingLayers loads the model and generates an embedding layer from each input file
+func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
+	layers := []*LayerReader{}
+	if len(e.files) > 0 {
+		if _, err := os.Stat(e.model); err != nil {
+			if os.IsNotExist(err) {
+				// this is a model name rather than the file
+				model, err := GetModel(e.model)
+				if err != nil {
+					return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err)
+				}
+				e.model = model.ModelPath
+			} else {
+				return nil, fmt.Errorf("failed to get model file to generate embeddings: %v", err)
+			}
+		}
+
+		e.opts.EmbeddingOnly = true
+		llm, err := llama.New(e.model, e.opts)
+		if err != nil {
+			return nil, fmt.Errorf("load model to generate embeddings: %v", err)
+		}
+		defer func() {
+			if llm != nil {
+				llm.Close()
+			}
+		}()
+
+		addedFiles := make(map[string]bool) // keep track of files that have already been added
+		for _, filePattern := range e.files {
+			matchingFiles, err := filepath.Glob(filePattern)
+			if err != nil {
+				return nil, fmt.Errorf("could not find files with pattern %s: %w", filePattern, err)
+			}
+
+			for _, filePath := range matchingFiles {
+				if addedFiles[filePath] {
+					continue
+				}
+				addedFiles[filePath] = true
+				// TODO: check file type
+				f, err := os.Open(filePath)
+				if err != nil {
+					return nil, fmt.Errorf("could not open embed file: %w", err)
+				}
+				scanner := bufio.NewScanner(f)
+				scanner.Split(bufio.ScanLines)
+
+				data := []string{}
+				for scanner.Scan() {
+					data = append(data, scanner.Text())
+				}
+				f.Close()
+
+				// the digest of the file is set here so that the client knows a new operation is in progress
+				fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
+
+				embeddings := []vector.Embedding{}
+				for i, d := range data {
+					if strings.TrimSpace(d) == "" {
+						continue
+					}
+					e.fn(api.ProgressResponse{
+						Status:    fmt.Sprintf("creating embeddings for file %s", filePath),
+						Digest:    fileDigest,
+						Total:     len(data) - 1,
+						Completed: i,
+					})
+					retry := 0
+				generate:
+					if retry > 3 {
+						log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
+						continue
+					}
+					embed, err := llm.Embedding(d)
+					if err != nil {
+						log.Printf("retrying embedding generation for '%s' line %d: %v", filePath, i+1, err)
+						retry++
+						goto generate
+					}
+					// Check for NaN and Inf in the embedding, which can't be stored
+					for _, value := range embed {
+						if math.IsNaN(value) || math.IsInf(value, 0) {
+							log.Printf("reloading model, embedding contains NaN or Inf")
+							// reload the model to get a new embedding; the seed can affect these outputs and reloading changes it
+							llm.Close()
+							llm, err = llama.New(e.model, e.opts)
+							if err != nil {
+								return nil, fmt.Errorf("load model to generate embeddings: %v", err)
+							}
+							retry++
+							goto generate
+						}
+					}
+					embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
+				}
+
+				b, err := json.Marshal(embeddings)
+				if err != nil {
+					return nil, fmt.Errorf("failed to encode embeddings: %w", err)
+				}
+				r := bytes.NewReader(b)
+
+				digest, size := GetSHA256Digest(r)
+				// Reset the position of the reader after calculating the digest
+				if _, err := r.Seek(0, io.SeekStart); err != nil {
+					return nil, fmt.Errorf("could not reset embed reader: %w", err)
+				}
+
+				layer := &LayerReader{
+					Layer: Layer{
+						MediaType: "application/vnd.ollama.image.embed",
+						Digest:    digest,
+						Size:      size,
+					},
+					Reader: r,
+				}
+
+				layers = append(layers, layer)
+			}
+		}
+	}
+	return layers, nil
+}
+
 func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
 	j := 0
 	for _, l := range layers {
@@ -449,8 +625,8 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
 	return newLayer, nil
 }
 
-// paramsToReader converts specified parameter options to their correct types, and returns a reader for the json
-func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
+// formatParams converts specified parameter options to their correct types
+func formatParams(params map[string][]string) (map[string]interface{}, error) {
 	opts := api.Options{}
 	valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
 	typeOpts := reflect.TypeOf(opts)           // types of the fields in the options struct
@@ -504,12 +680,7 @@ func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
 		}
 	}
 
-	bts, err := json.Marshal(out)
-	if err != nil {
-		return nil, err
-	}
-
-	return bytes.NewReader(bts), nil
+	return out, nil
 }
 
 func getLayerDigests(layers []*LayerReader) ([]string, error) {
@@ -1042,7 +1213,7 @@ func downloadBlob(mp ModelPath, digest string, regOpts *RegistryOptions, fn func
 
 	for {
 		fn(api.ProgressResponse{
-			Status:    fmt.Sprintf("downloading %s", digest),
+			Status:    fmt.Sprintf("pulling %s...", digest[7:19]),
 			Digest:    digest,
 			Total:     int(total),
 			Completed: int(completed),
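
Each embed layer produced above is the JSON encoding of a []vector.Embedding, one entry per non-blank line of the source file. A sketch of the payload stored under application/vnd.ollama.image.embed (the vector values are made up and far shorter than a real model's embedding dimension):

```
[
  {"Vector": [0.12, -0.48, 0.07], "Data": "first line of the embedded file"},
  {"Vector": [0.33, 0.01, -0.29], "Data": "second line of the embedded file"}
]
```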

+ 25 - 3
server/routes.go

@@ -17,15 +17,18 @@ import (
 
 	"github.com/gin-contrib/cors"
 	"github.com/gin-gonic/gin"
+	"gonum.org/v1/gonum/mat"
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/llama"
+	"github.com/jmorganca/ollama/vector"
 )
 
 var loaded struct {
 	mu sync.Mutex
 
-	llm *llama.LLM
+	llm        *llama.LLM
+	Embeddings []vector.Embedding
 
 	expireAt    time.Time
 	expireTimer *time.Timer
@@ -72,6 +75,11 @@ func GenerateHandler(c *gin.Context) {
 			loaded.digest = ""
 		}
 
+		if len(model.Embeddings) > 0 {
+			opts.EmbeddingOnly = true // this is required to generate embeddings; completions will still work
+			loaded.Embeddings = model.Embeddings
+		}
+
 		llm, err := llama.New(model.ModelPath, opts)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -101,7 +109,6 @@ func GenerateHandler(c *gin.Context) {
 		loaded.digest = model.Digest
 		loaded.options = opts
 	}
-
 	sessionDuration := 5 * time.Minute
 
 	loaded.expireAt = time.Now().Add(sessionDuration)
@@ -127,7 +134,22 @@ func GenerateHandler(c *gin.Context) {
 
 	checkpointLoaded := time.Now()
 
-	prompt, err := model.Prompt(req)
+	embedding := ""
+	if len(model.Embeddings) > 0 {
+		promptEmbed, err := loaded.llm.Embedding(req.Prompt)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+		// TODO: set embedTop from specified parameters in modelfile
+		embedTop := 3
+		topK := vector.TopK(embedTop, mat.NewVecDense(len(promptEmbed), promptEmbed), loaded.Embeddings)
+		for _, e := range topK {
+			embedding = fmt.Sprintf("%s %s", embedding, e.Embedding.Data)
+		}
+	}
+
+	prompt, err := model.Prompt(req, embedding)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
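
The snippets selected above only reach the model if the prompt template references the new Embed variable. A sketch of such a TEMPLATE in a Modelfile (the wording around the variable is illustrative):

```
TEMPLATE """
{{- if .Embed }}
Use this data as context for the request: {{ .Embed }}
{{- end }}

{{ .Prompt }}
"""
```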

+ 69 - 0
vector/store.go

@@ -0,0 +1,69 @@
+package vector
+
+import (
+	"container/heap"
+	"sort"
+
+	"gonum.org/v1/gonum/mat"
+)
+
+type Embedding struct {
+	Vector []float64 // the embedding vector
+	Data   string    // the data represented by the embedding
+}
+
+type EmbeddingSimilarity struct {
+	Embedding  Embedding // the embedding that was used to calculate the similarity
+	Similarity float64   // the similarity between the embedding and the query
+}
+
+type Heap []EmbeddingSimilarity
+
+func (h Heap) Len() int           { return len(h) }
+func (h Heap) Less(i, j int) bool { return h[i].Similarity < h[j].Similarity }
+func (h Heap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
+func (h *Heap) Push(e any) {
+	*h = append(*h, e.(EmbeddingSimilarity))
+}
+
+func (h *Heap) Pop() interface{} {
+	old := *h
+	n := len(old)
+	x := old[n-1]
+	*h = old[0 : n-1]
+	return x
+}
+
+// cosineSimilarity is a measure that calculates the cosine of the angle between two vectors.
+// This value will range from -1 to 1, where 1 means the vectors are identical.
+func cosineSimilarity(vec1, vec2 *mat.VecDense) float64 {
+	dotProduct := mat.Dot(vec1, vec2)
+	norms := mat.Norm(vec1, 2) * mat.Norm(vec2, 2)
+
+	if norms == 0 {
+		return 0
+	}
+	return dotProduct / norms
+}
+
+func TopK(k int, query *mat.VecDense, embeddings []Embedding) []EmbeddingSimilarity {
+	h := &Heap{}
+	heap.Init(h)
+	for _, emb := range embeddings {
+		similarity := cosineSimilarity(query, mat.NewVecDense(len(emb.Vector), emb.Vector))
+		heap.Push(h, EmbeddingSimilarity{Embedding: emb, Similarity: similarity})
+		if h.Len() > k {
+			heap.Pop(h)
+		}
+	}
+
+	topK := make([]EmbeddingSimilarity, 0, h.Len())
+	for h.Len() > 0 {
+		topK = append(topK, heap.Pop(h).(EmbeddingSimilarity))
+	}
+	sort.Slice(topK, func(i, j int) bool {
+		return topK[i].Similarity > topK[j].Similarity
+	})
+
+	return topK
+}
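
A self-contained sketch of the new store in use, with toy three-dimensional vectors standing in for real model embeddings:

```
package main

import (
	"fmt"

	"gonum.org/v1/gonum/mat"

	"github.com/jmorganca/ollama/vector"
)

func main() {
	store := []vector.Embedding{
		{Data: "cats are mammals", Vector: []float64{0.9, 0.1, 0.0}},
		{Data: "go is a compiled language", Vector: []float64{0.0, 0.2, 0.9}},
		{Data: "dogs are mammals", Vector: []float64{0.8, 0.2, 0.1}},
	}

	// a query vector close to the two mammal entries
	query := mat.NewVecDense(3, []float64{0.85, 0.15, 0.05})

	// TopK keeps a size-k min-heap while scanning, so the scan is O(n log k);
	// results come back sorted by descending cosine similarity
	for _, hit := range vector.TopK(2, query, store) {
		fmt.Printf("%.3f  %s\n", hit.Similarity, hit.Embedding.Data)
	}
}
```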