
embed text document in modelfile

Bruce MacDonald, 1 year ago
commit 7a5f3616fd
10 changed files with 371 additions and 52 deletions
  1. api/types.go (+1 -0)
  2. cmd/cmd.go (+12 -6)
  3. docs/modelfile.md (+11 -1)
  4. go.mod (+1 -0)
  5. go.sum (+2 -0)
  6. llama/llama.go (+37 -0)
  7. parser/parser.go (+1 -1)
  8. server/images.go (+212 -41)
  9. server/routes.go (+25 -3)
  10. vector/store.go (+69 -0)

+ 1 - 0
api/types.go

@@ -276,6 +276,7 @@ func DefaultOptions() Options {
 		UseMLock:           false,
 		RopeFrequencyBase:  10000.0,
 		RopeFrequencyScale: 1.0,
+		EmbeddingOnly:      true,

 		RepeatLastN:      64,
 		RepeatPenalty:    1.1,

+ 12 - 6
cmd/cmd.go

@@ -48,12 +48,18 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 				spinner.Stop()
 			}
 			currentDigest = resp.Digest
-			bar = progressbar.DefaultBytes(
-				int64(resp.Total),
-				fmt.Sprintf("pulling %s...", resp.Digest[7:19]),
-			)
-
-			bar.Set(resp.Completed)
+			switch {
+			case strings.Contains(resp.Status, "embeddings"):
+				bar = progressbar.Default(int64(resp.Total), resp.Status)
+				bar.Set(resp.Completed)
+			default:
+				// pulling
+				bar = progressbar.DefaultBytes(
+					int64(resp.Total),
+					resp.Status,
+				)
+				bar.Set(resp.Completed)
+			}
 		} else if resp.Digest == currentDigest && resp.Digest != "" {
 			bar.Set(resp.Completed)
 		} else {

+ 11 - 1
docs/modelfile.md

@@ -12,6 +12,7 @@ A model file is the blueprint to create and share models with Ollama.
   - [FROM (Required)](#from-required)
     - [Build from llama2](#build-from-llama2)
     - [Build from a bin file](#build-from-a-bin-file)
+  - [EMBED](#embed)
   - [PARAMETER](#parameter)
     - [Valid Parameters and Values](#valid-parameters-and-values)
   - [TEMPLATE](#template)
@@ -88,6 +89,15 @@ FROM ./ollama-model.bin

 This bin file location should be specified as an absolute path or relative to the Modelfile location.

+### EMBED
+
+The EMBED instruction is used to add embeddings of files to a model. This is useful for adding custom data that the model can reference when generating an answer.
+
+```
+FROM <model name>:<tag>
+EMBED <file path>
+```
+
 ### PARAMETER

 The `PARAMETER` instruction defines a parameter that can be set when the model is run.
@@ -163,4 +173,4 @@ LICENSE """
 ## Notes

 - the **modelfile is not case sensitive**. In the examples, we use uppercase for instructions to make it easier to distinguish it from arguments.
-- Instructions can be in any order. In the examples, we start with FROM instruction to keep it easily readable.
+- Instructions can be in any order. In the examples, we start with FROM instruction to keep it easily readable.
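Putting the new instruction together with the existing `FROM` forms, a complete Modelfile might look like this (a hypothetical example; `wiki.txt` stands in for any newline-delimited text file):

```
FROM llama2
EMBED ./wiki.txt
```

Per the `embeddingLayers` implementation in server/images.go below, each non-empty line of the embedded file is vectorized when the model is created and stored with it as an embed layer; glob patterns such as `EMBED ./data/*.txt` are matched as well.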

+ 1 - 0
go.mod

@@ -42,6 +42,7 @@ require (
 	golang.org/x/sys v0.10.0 // indirect
 	golang.org/x/term v0.10.0
 	golang.org/x/text v0.10.0 // indirect
+	gonum.org/v1/gonum v0.13.0
 	google.golang.org/protobuf v1.30.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )

+ 2 - 0
go.sum

@@ -139,6 +139,8 @@ golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58=
 golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+gonum.org/v1/gonum v0.13.0 h1:a0T3bh+7fhRyqeNbiC3qVHYmkiQgit3wnNan/2c0HMM=
+gonum.org/v1/gonum v0.13.0/go.mod h1:/WPYRckkfWrhWefxyYTfrTtQR0KH4iyHNuzxqXAKyAU=
 google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
 google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
 google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=

+ 37 - 0
llama/llama.go

@@ -85,6 +85,7 @@ llama_token llama_sample(
 }
 */
 import "C"
+
 import (
 	"bytes"
 	"embed"
@@ -93,6 +94,7 @@ import (
 	"io"
 	"log"
 	"os"
+	"reflect"
 	"strings"
 	"sync"
 	"unicode/utf8"
@@ -408,3 +410,38 @@ func (llm *LLM) next() (C.llama_token, error) {

 	return token, nil
 }
+
+func (llm *LLM) Embedding(input string) ([]float64, error) {
+	if !llm.EmbeddingOnly {
+		return nil, errors.New("llama: embedding not enabled")
+	}
+
+	tokens := llm.tokenize(input)
+	if tokens == nil {
+		return nil, errors.New("llama: tokenize embedding")
+	}
+
+	retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread))
+	if retval != 0 {
+		return nil, errors.New("llama: eval")
+	}
+
+	n := int(C.llama_n_embd(llm.ctx))
+	if n <= 0 {
+		return nil, errors.New("llama: no embeddings generated")
+	}
+
+	embedPtr := C.llama_get_embeddings(llm.ctx)
+	if embedPtr == nil {
+		return nil, errors.New("llama: embedding retrieval failed")
+	}
+
+	// llama_get_embeddings returns a C float array (32-bit values); view it
+	// as []C.float and widen each element to float64. Reinterpreting the
+	// memory directly as []float64 would read past the buffer and garble
+	// every value.
+	header := reflect.SliceHeader{
+		Data: uintptr(unsafe.Pointer(embedPtr)),
+		Len:  n,
+		Cap:  n,
+	}
+	cEmbedding := *(*[]C.float)(unsafe.Pointer(&header))
+
+	embedding := make([]float64, n)
+	for i, v := range cEmbedding {
+		embedding[i] = float64(v)
+	}
+
+	return embedding, nil
+}
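A minimal sketch of how the new `Embedding` method is driven (the model path is a placeholder; note that `EmbeddingOnly` must be set, matching the guard at the top of the function):

```go
package main

import (
	"fmt"
	"log"

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/llama"
)

func main() {
	opts := api.DefaultOptions()
	opts.EmbeddingOnly = true // Embedding returns an error without this

	// "./model.bin" is a hypothetical local model file.
	llm, err := llama.New("./model.bin", opts)
	if err != nil {
		log.Fatal(err)
	}
	defer llm.Close()

	vec, err := llm.Embedding("the quick brown fox")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("embedding dimension: %d\n", len(vec)) // reported by llama_n_embd
}
```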

+ 1 - 1
parser/parser.go

@@ -40,7 +40,7 @@ func Parse(reader io.Reader) ([]Command, error) {
 			command.Args = string(fields[1])
 			// copy command for validation
 			modelCommand = command
-		case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
+		case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT", "EMBED":
 			command.Name = string(bytes.ToLower(fields[0]))
 			command.Args = string(fields[1])
 		case "PARAMETER":

+ 212 - 41
server/images.go

@@ -1,6 +1,7 @@
 package server

 import (
+	"bufio"
 	"bytes"
 	"crypto/sha256"
 	"encoding/json"
@@ -9,6 +10,7 @@ import (
 	"html/template"
 	"io"
 	"log"
+	"math"
 	"net/http"
 	"os"
 	"path"
@@ -18,7 +20,9 @@ import (
 	"strings"
 	"strings"
 
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/llama"
 	"github.com/jmorganca/ollama/parser"
 	"github.com/jmorganca/ollama/parser"
+	"github.com/jmorganca/ollama/vector"
 )
 )
 
 
 type RegistryOptions struct {
 type RegistryOptions struct {
@@ -28,15 +32,16 @@ type RegistryOptions struct {
 }

 type Model struct {
-	Name      string `json:"name"`
-	ModelPath string
-	Template  string
-	System    string
-	Digest    string
-	Options   map[string]interface{}
+	Name       string `json:"name"`
+	ModelPath  string
+	Template   string
+	System     string
+	Digest     string
+	Options    map[string]interface{}
+	Embeddings []vector.Embedding
 }

-func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
+func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
 	t := m.Template
 	if request.Template != "" {
 		t = request.Template
@@ -51,6 +56,7 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
 		First  bool
 		System string
 		Prompt string
+		Embed  string

 		// deprecated: versions <= 0.0.7 used this to omit the system prompt
 		Context []int
@@ -60,6 +66,7 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
 	vars.System = m.System
 	vars.Prompt = request.Prompt
 	vars.Context = request.Context
+	vars.Embed = embedding

 	if request.System != "" {
 		vars.System = request.System
@@ -157,6 +164,16 @@ func GetModel(name string) (*Model, error) {
 		switch layer.MediaType {
 		case "application/vnd.ollama.image.model":
 			model.ModelPath = filename
+		case "application/vnd.ollama.image.embed":
+			file, err := os.Open(filename)
+			if err != nil {
+				return nil, fmt.Errorf("failed to open file: %s", filename)
+			}
+			defer file.Close()
+
+			if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil {
+				return nil, err
+			}
 		case "application/vnd.ollama.image.template":
 			bts, err := os.ReadFile(filename)
 			if err != nil {
@@ -195,6 +212,26 @@ func GetModel(name string) (*Model, error) {
 	return model, nil
 }

+func filenameWithPath(path, f string) (string, error) {
+	// if filePath starts with ~/, replace it with the user's home directory.
+	if strings.HasPrefix(f, "~/") {
+		parts := strings.Split(f, "/")
+		home, err := os.UserHomeDir()
+		if err != nil {
+			return "", fmt.Errorf("failed to open file: %v", err)
+		}
+
+		f = filepath.Join(home, filepath.Join(parts[1:]...))
+	}
+
+	// if filePath is not an absolute path, make it relative to the modelfile path
+	if !filepath.IsAbs(f) {
+		f = filepath.Join(filepath.Dir(path), f)
+	}
+
+	return f, nil
+}
+
 func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) error {
 	mf, err := os.Open(path)
 	if err != nil {
@@ -211,33 +248,20 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e

 	var layers []*LayerReader
 	params := make(map[string][]string)
-
+	embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()}
 	for _, c := range commands {
 		log.Printf("[%s] - %s\n", c.Name, c.Args)
 		switch c.Name {
 		case "model":
 			fn(api.ProgressResponse{Status: "looking for model"})
+			embed.model = c.Args
 			mf, err := GetManifest(ParseModelPath(c.Args))
 			if err != nil {
-				fp := c.Args
-
-				// If filePath starts with ~/, replace it with the user's home directory.
-				if strings.HasPrefix(fp, "~/") {
-					parts := strings.Split(fp, "/")
-					home, err := os.UserHomeDir()
-					if err != nil {
-						return fmt.Errorf("failed to open file: %v", err)
-					}
-
-					fp = filepath.Join(home, filepath.Join(parts[1:]...))
-				}
-
-				// If filePath is not an absolute path, make it relative to the modelfile path
-				if !filepath.IsAbs(fp) {
-					fp = filepath.Join(filepath.Dir(path), fp)
+				modelFile, err := filenameWithPath(path, c.Args)
+				if err != nil {
+					return err
 				}
-
-				if _, err := os.Stat(fp); err != nil {
+				if _, err := os.Stat(modelFile); err != nil {
 					// the model file does not exist, try pulling it
 					if errors.Is(err, os.ErrNotExist) {
 						fn(api.ProgressResponse{Status: "pulling model file"})
@@ -248,15 +272,13 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
 						if err != nil {
 							return fmt.Errorf("failed to open file after pull: %v", err)
 						}
-
 					} else {
 						return err
 					}
 				} else {
 					// create a model from this specified file
 					fn(api.ProgressResponse{Status: "creating model layer"})
-
-					file, err := os.Open(fp)
+					file, err := os.Open(modelFile)
 					if err != nil {
 						return fmt.Errorf("failed to open file: %v", err)
 					}
@@ -280,9 +302,14 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
 					layers = append(layers, newLayer)
 				}
 			}
+		case "embed":
+			embedFilePath, err := filenameWithPath(path, c.Args)
+			if err != nil {
+				return err
+			}
+			embed.files = append(embed.files, embedFilePath)
 		case "license":
 			fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
-			// remove the prompt layer if one exists
 			mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)

 			layer, err := CreateLayer(strings.NewReader(c.Args))
@@ -315,17 +342,34 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
 	if len(params) > 0 {
 		fn(api.ProgressResponse{Status: "creating parameter layer"})
 		layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
-		paramData, err := paramsToReader(params)
+		formattedParams, err := formatParams(params)
 		if err != nil {
 			return fmt.Errorf("couldn't create params json: %v", err)
 		}
-		l, err := CreateLayer(paramData)
+
+		bts, err := json.Marshal(formattedParams)
+		if err != nil {
+			return err
+		}
+
+		l, err := CreateLayer(bytes.NewReader(bts))
 		if err != nil {
 			return fmt.Errorf("failed to create layer: %v", err)
 		}
 		l.MediaType = "application/vnd.ollama.image.params"
 		layers = append(layers, l)
+
+		// apply these parameters to the embedding options, in case embeddings need to be generated using this model
+		embed.opts = api.DefaultOptions()
+		embed.opts.FromMap(formattedParams)
+	}
+
+	// generate the embedding layers
+	embeddingLayers, err := embeddingLayers(embed)
+	if err != nil {
+		return err
 	}
+	layers = append(layers, embeddingLayers...)

 	digests, err := getLayerDigests(layers)
 	if err != nil {
@@ -361,6 +405,138 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
 	return nil
 }

+type EmbeddingParams struct {
+	model string
+	opts  api.Options
+	files []string // paths to files to embed
+	fn    func(resp api.ProgressResponse)
+}
+
+// embeddingLayers loads the model and generates one embed layer per matching input file
+func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
+	layers := []*LayerReader{}
+	if len(e.files) > 0 {
+		if _, err := os.Stat(e.model); err != nil {
+			if os.IsNotExist(err) {
+				// this is a model name rather than the file
+				model, err := GetModel(e.model)
+				if err != nil {
+					return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err)
+				}
+				e.model = model.ModelPath
+			} else {
+				return nil, fmt.Errorf("failed to get model file to generate embeddings: %v", err)
+			}
+		}
+
+		e.opts.EmbeddingOnly = true
+		llm, err := llama.New(e.model, e.opts)
+		if err != nil {
+			return nil, fmt.Errorf("load model to generate embeddings: %v", err)
+		}
+		defer func() {
+			if llm != nil {
+				llm.Close()
+			}
+		}()
+
+		addedFiles := make(map[string]bool) // keep track of files that have already been added
+		for _, filePattern := range e.files {
+			matchingFiles, err := filepath.Glob(filePattern)
+			if err != nil {
+				return nil, fmt.Errorf("could not find files with pattern %s: %w", filePattern, err)
+			}
+
+			for _, filePath := range matchingFiles {
+				if addedFiles[filePath] {
+					continue
+				}
+				addedFiles[filePath] = true
+				// TODO: check file type
+				f, err := os.Open(filePath)
+				if err != nil {
+					return nil, fmt.Errorf("could not open embed file: %w", err)
+				}
+				scanner := bufio.NewScanner(f)
+				scanner.Split(bufio.ScanLines)
+
+				data := []string{}
+				for scanner.Scan() {
+					data = append(data, scanner.Text())
+				}
+				f.Close()
+
+				// the digest of the file is set here so that the client knows a new operation is in progress
+				fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
+
+				embeddings := []vector.Embedding{}
+				for i, d := range data {
+					if strings.TrimSpace(d) == "" {
+						continue
+					}
+					e.fn(api.ProgressResponse{
+						Status:    fmt.Sprintf("creating embeddings for file %s", filePath),
+						Digest:    fileDigest,
+						Total:     len(data) - 1,
+						Completed: i,
+					})
+					retry := 0
+				generate:
+					if retry > 3 {
+						log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
+						continue
+					}
+					embed, err := llm.Embedding(d)
+					if err != nil {
+						log.Printf("retrying embedding generation for '%s' line %d: %v", filePath, i+1, err)
+						retry++
+						goto generate
+					}
+					// Check for NaN and Inf in the embedding, which can't be stored
+					for _, value := range embed {
+						if math.IsNaN(value) || math.IsInf(value, 0) {
+							log.Printf("reloading model, embedding contains NaN or Inf")
+							// reload the model to get a new embedding; the seed can affect these outputs and reloading changes it
+							llm.Close()
+							llm, err = llama.New(e.model, e.opts)
+							if err != nil {
+								return nil, fmt.Errorf("load model to generate embeddings: %v", err)
+							}
+							retry++
+							goto generate
+						}
+					}
+					embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
+				}
+
+				b, err := json.Marshal(embeddings)
+				if err != nil {
+					return nil, fmt.Errorf("failed to encode embeddings: %w", err)
+				}
+				r := bytes.NewReader(b)
+
+				digest, size := GetSHA256Digest(r)
+				// Reset the position of the reader after calculating the digest
+				if _, err := r.Seek(0, io.SeekStart); err != nil {
+					return nil, fmt.Errorf("could not reset embed reader: %w", err)
+				}
+
+				layer := &LayerReader{
+					Layer: Layer{
+						MediaType: "application/vnd.ollama.image.embed",
+						Digest:    digest,
+						Size:      size,
+					},
+					Reader: r,
+				}
+
+				layers = append(layers, layer)
+			}
+		}
+	}
+	return layers, nil
+}
+
 func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
 	j := 0
 	for _, l := range layers {
@@ -449,8 +625,8 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
 	return newLayer, nil
 }

-// paramsToReader converts specified parameter options to their correct types, and returns a reader for the json
-func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
+// formatParams converts specified parameter options to their correct types
+func formatParams(params map[string][]string) (map[string]interface{}, error) {
 	opts := api.Options{}
 	valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
 	typeOpts := reflect.TypeOf(opts)           // types of the fields in the options struct
@@ -504,12 +680,7 @@ func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
 		}
 	}

-	bts, err := json.Marshal(out)
-	if err != nil {
-		return nil, err
-	}
-
-	return bytes.NewReader(bts), nil
+	return out, nil
 }

 func getLayerDigests(layers []*LayerReader) ([]string, error) {
@@ -1042,7 +1213,7 @@ func downloadBlob(mp ModelPath, digest string, regOpts *RegistryOptions, fn func

 	for {
 		fn(api.ProgressResponse{
-			Status:    fmt.Sprintf("downloading %s", digest),
+			Status:    fmt.Sprintf("pulling %s...", digest[7:19]),
 			Digest:    digest,
 			Total:     int(total),
 			Completed: int(completed),
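For reference, the embed layer produced by `embeddingLayers` above is plain JSON: the `[]vector.Embedding` slice marshalled as-is, which is exactly what `GetModel` decodes back out of the `application/vnd.ollama.image.embed` layer. A small sketch of that round trip (toy vector values):

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/jmorganca/ollama/vector"
)

func main() {
	// Toy values; real vectors have the model's embedding dimension.
	stored := []vector.Embedding{
		{Data: "Ollama supports an EMBED instruction.", Vector: []float64{0.12, -0.34, 0.56}},
	}

	// What embeddingLayers writes into the embed layer.
	b, _ := json.Marshal(stored)
	fmt.Println(string(b))

	// What GetModel does when it encounters that media type.
	var loaded []vector.Embedding
	_ = json.Unmarshal(b, &loaded)
	fmt.Println(loaded[0].Data)
}
```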

+ 25 - 3
server/routes.go

@@ -17,15 +17,18 @@ import (

 	"github.com/gin-contrib/cors"
 	"github.com/gin-gonic/gin"
+	"gonum.org/v1/gonum/mat"

 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/llama"
+	"github.com/jmorganca/ollama/vector"
 )

 var loaded struct {
 	mu sync.Mutex

-	llm *llama.LLM
+	llm        *llama.LLM
+	Embeddings []vector.Embedding

 	expireAt    time.Time
 	expireTimer *time.Timer
@@ -72,6 +75,11 @@ func GenerateHandler(c *gin.Context) {
 			loaded.digest = ""
 		}

+		if model.Embeddings != nil && len(model.Embeddings) > 0 {
+			opts.EmbeddingOnly = true // this is required to generate embeddings; completions will still work
+			loaded.Embeddings = model.Embeddings
+		}
+
 		llm, err := llama.New(model.ModelPath, opts)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -101,7 +109,6 @@ func GenerateHandler(c *gin.Context) {
 		loaded.digest = model.Digest
 		loaded.options = opts
 	}
-
 	sessionDuration := 5 * time.Minute

 	loaded.expireAt = time.Now().Add(sessionDuration)
@@ -127,7 +134,22 @@ func GenerateHandler(c *gin.Context) {

 	checkpointLoaded := time.Now()

-	prompt, err := model.Prompt(req)
+	embedding := ""
+	if model.Embeddings != nil && len(model.Embeddings) > 0 {
+		promptEmbed, err := loaded.llm.Embedding(req.Prompt)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+		// TODO: set embedTop from specified parameters in modelfile
+		embedTop := 3
+		topK := vector.TopK(embedTop, mat.NewVecDense(len(promptEmbed), promptEmbed), loaded.Embeddings)
+		for _, e := range topK {
+			embedding = fmt.Sprintf("%s %s", embedding, e.Embedding.Data)
+		}
+	}
+
+	prompt, err := model.Prompt(req, embedding)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
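Nothing in the handler injects the retrieved text on its own; the model's template decides where it goes via the new `.Embed` variable. A hypothetical `TEMPLATE` that splices the retrieved lines ahead of the question:

```
TEMPLATE """
{{- if .Embed }}
Context: {{ .Embed }}
{{- end }}
{{ .System }}
User: {{ .Prompt }}
"""
```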

+ 69 - 0
vector/store.go

@@ -0,0 +1,69 @@
+package vector
+
+import (
+	"container/heap"
+	"sort"
+
+	"gonum.org/v1/gonum/mat"
+)
+
+type Embedding struct {
+	Vector []float64 // the embedding vector
+	Data   string    // the data represented by the embedding
+}
+
+type EmbeddingSimilarity struct {
+	Embedding  Embedding // the embedding that was used to calculate the similarity
+	Similarity float64   // the similarity between the embedding and the query
+}
+
+// Heap is a min-heap of EmbeddingSimilarity values, ordered by similarity.
+type Heap []EmbeddingSimilarity
+
+func (h Heap) Len() int           { return len(h) }
+func (h Heap) Less(i, j int) bool { return h[i].Similarity < h[j].Similarity }
+func (h Heap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
+func (h *Heap) Push(e any) {
+	*h = append(*h, e.(EmbeddingSimilarity))
+}
+
+func (h *Heap) Pop() any {
+	old := *h
+	n := len(old)
+	x := old[n-1]
+	*h = old[0 : n-1]
+	return x
+}
+
+// cosineSimilarity is a measure that calculates the cosine of the angle between two vectors.
+// This value will range from -1 to 1, where 1 means the vectors are identical.
+func cosineSimilarity(vec1, vec2 *mat.VecDense) float64 {
+	dotProduct := mat.Dot(vec1, vec2)
+	norms := mat.Norm(vec1, 2) * mat.Norm(vec2, 2)
+
+	if norms == 0 {
+		return 0
+	}
+	return dotProduct / norms
+}
+
+// TopK returns the k embeddings most similar to the query vector,
+// ordered from most to least similar.
+func TopK(k int, query *mat.VecDense, embeddings []Embedding) []EmbeddingSimilarity {
+	h := &Heap{}
+	heap.Init(h)
+	for _, emb := range embeddings {
+		similarity := cosineSimilarity(query, mat.NewVecDense(len(emb.Vector), emb.Vector))
+		heap.Push(h, EmbeddingSimilarity{Embedding: emb, Similarity: similarity})
+		if h.Len() > k {
+			heap.Pop(h)
+		}
+	}
+
+	topK := make([]EmbeddingSimilarity, 0, h.Len())
+	for h.Len() > 0 {
+		topK = append(topK, heap.Pop(h).(EmbeddingSimilarity))
+	}
+	sort.Slice(topK, func(i, j int) bool {
+		return topK[i].Similarity > topK[j].Similarity
+	})
+
+	return topK
+}
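A self-contained sketch of how the new store behaves (toy 3-dimensional vectors; real ones come from `LLM.Embedding`):

```go
package main

import (
	"fmt"

	"gonum.org/v1/gonum/mat"

	"github.com/jmorganca/ollama/vector"
)

func main() {
	docs := []vector.Embedding{
		{Data: "cats purr", Vector: []float64{0.9, 0.1, 0.0}},
		{Data: "dogs bark", Vector: []float64{0.1, 0.9, 0.0}},
		{Data: "cats meow", Vector: []float64{0.8, 0.2, 0.1}},
	}

	query := mat.NewVecDense(3, []float64{1, 0, 0}) // a "cat-like" query vector

	// Most similar first, exactly how GenerateHandler consumes it.
	for _, hit := range vector.TopK(2, query, docs) {
		fmt.Printf("%.3f  %s\n", hit.Similarity, hit.Embedding.Data)
	}
}
```

Because `Less` orders the heap by ascending similarity, popping whenever it grows past k discards the weakest candidate, so the scan keeps only k entries in memory and runs in O(n log k) rather than sorting every embedding.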