Browse Source

Introduce `/api/embed` endpoint supporting batch embedding (#5127)

* Initial Batch Embedding

* Revert "Initial Batch Embedding"

This reverts commit c22d54895a280b54c727279d85a5fc94defb5a29.

* Initial Draft

* mock up notes

* api/embed draft

* add server function

* check normalization

* clean up

* normalization

* playing around with truncate stuff

* Truncation

* Truncation

* move normalization to go

* Integration Test Template

* Truncation Integration Tests

* Clean up

* use float32

* move normalize

* move normalize test

* refactoring

* integration float32

* input handling and handler testing

* Refactoring of legacy and new

* clear comments

* merge conflicts

* touches

* embedding type 64

* merge conflicts

* fix hanging on single string

* refactoring

* test values

* set context length

* clean up

* testing clean up

* testing clean up

* remove function closure

* Revert "remove function closure"

This reverts commit 55d48c6ed17abe42e7a122e69d603ef0c1506787.

* remove function closure

* remove redundant error check

* clean up

* more clean up

* clean up
royjhan 9 months ago
parent
commit
b9f5e16c80
8 changed files with 453 additions and 31 deletions
  1. 10 1
      api/client.go
  2. 24 0
      api/types.go
  3. 152 0
      integration/embed_test.go
  4. 23 16
      llm/ext_server/server.cpp
  5. 8 8
      llm/server.go
  6. 129 2
      server/routes.go
  7. 103 0
      server/routes_test.go
  8. 4 4
      server/sched_test.go

+ 10 - 1
api/client.go

@@ -347,7 +347,16 @@ func (c *Client) Heartbeat(ctx context.Context) error {
 	return nil
 }
 
-// Embeddings generates embeddings from a model.
+// Embed generates embeddings from a model.
+func (c *Client) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
+	var resp EmbedResponse
+	if err := c.do(ctx, http.MethodPost, "/api/embed", req, &resp); err != nil {
+		return nil, err
+	}
+	return &resp, nil
+}
+
+// Embeddings generates an embedding from a model.
 func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
 	var resp EmbeddingResponse
 	if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {

+ 24 - 0
api/types.go

@@ -173,6 +173,30 @@ type Runner struct {
 	NumThread int   `json:"num_thread,omitempty"`
 }
 
+// EmbedRequest is the request passed to [Client.Embed].
+type EmbedRequest struct {
+	// Model is the model name.
+	Model string `json:"model"`
+
+	// Input is the input to embed.
+	Input any `json:"input"`
+
+	// KeepAlive controls how long the model will stay loaded in memory following
+	// this request.
+	KeepAlive *Duration `json:"keep_alive,omitempty"`
+
+	Truncate *bool `json:"truncate,omitempty"`
+
+	// Options lists model-specific options.
+	Options map[string]interface{} `json:"options"`
+}
+
+// EmbedResponse is the response from [Client.Embed].
+type EmbedResponse struct {
+	Model      string      `json:"model"`
+	Embeddings [][]float32 `json:"embeddings,omitempty"`
+}
+
 // EmbeddingRequest is the request passed to [Client.Embeddings].
 type EmbeddingRequest struct {
 	// Model is the model name.

+ 152 - 0
integration/embed_test.go

@@ -0,0 +1,152 @@
+//go:build integration
+
+package integration
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/ollama/ollama/api"
+)
+
+func TestAllMiniLMEmbed(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+	defer cancel()
+
+	req := api.EmbedRequest{
+		Model: "all-minilm",
+		Input: "why is the sky blue?",
+	}
+
+	res, err := embedTestHelper(ctx, t, req)
+
+	if err != nil {
+		t.Fatalf("error: %v", err)
+	}
+
+	if len(res.Embeddings) != 1 {
+		t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
+	}
+
+	if len(res.Embeddings[0]) != 384 {
+		t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
+	}
+
+	if res.Embeddings[0][0] != 0.010071031 {
+		t.Fatalf("expected 0.010071031, got %f", res.Embeddings[0][0])
+	}
+}
+
+func TestAllMiniLMBatchEmbed(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+	defer cancel()
+
+	req := api.EmbedRequest{
+		Model: "all-minilm",
+		Input: []string{"why is the sky blue?", "why is the grass green?"},
+	}
+
+	res, err := embedTestHelper(ctx, t, req)
+
+	if err != nil {
+		t.Fatalf("error: %v", err)
+	}
+
+	if len(res.Embeddings) != 2 {
+		t.Fatalf("expected 2 embeddings, got %d", len(res.Embeddings))
+	}
+
+	if len(res.Embeddings[0]) != 384 {
+		t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
+	}
+
+	if res.Embeddings[0][0] != 0.010071031 || res.Embeddings[1][0] != -0.009802706 {
+		t.Fatalf("expected 0.010071031 and -0.009802706, got %f and %f", res.Embeddings[0][0], res.Embeddings[1][0])
+	}
+}
+
+func TestAllMiniLmEmbedTruncate(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+	defer cancel()
+
+	truncTrue, truncFalse := true, false
+
+	type testReq struct {
+		Name    string
+		Request api.EmbedRequest
+	}
+
+	reqs := []testReq{
+		{
+			Name: "Target Truncation",
+			Request: api.EmbedRequest{
+				Model: "all-minilm",
+				Input: "why",
+			},
+		},
+		{
+			Name: "Default Truncate",
+			Request: api.EmbedRequest{
+				Model:   "all-minilm",
+				Input:   "why is the sky blue?",
+				Options: map[string]any{"num_ctx": 1},
+			},
+		},
+		{
+			Name: "Explicit Truncate",
+			Request: api.EmbedRequest{
+				Model:    "all-minilm",
+				Input:    "why is the sky blue?",
+				Truncate: &truncTrue,
+				Options:  map[string]any{"num_ctx": 1},
+			},
+		},
+	}
+
+	res := make(map[string]*api.EmbedResponse)
+
+	for _, req := range reqs {
+		response, err := embedTestHelper(ctx, t, req.Request)
+		if err != nil {
+			t.Fatalf("error: %v", err)
+		}
+		res[req.Name] = response
+	}
+
+	if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
+		t.Fatal("expected default request to truncate correctly")
+	}
+
+	if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
+		t.Fatal("expected default request and truncate true request to be the same")
+	}
+
+	// check that truncate set to false returns an error if context length is exceeded
+	_, err := embedTestHelper(ctx, t, api.EmbedRequest{
+		Model:    "all-minilm",
+		Input:    "why is the sky blue?",
+		Truncate: &truncFalse,
+		Options:  map[string]any{"num_ctx": 1},
+	})
+
+	if err == nil {
+		t.Fatal("expected error, got nil")
+	}
+}
+
+func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+	if err := PullIfMissing(ctx, client, req.Model); err != nil {
+		t.Fatalf("failed to pull model %s: %v", req.Model, err)
+	}
+
+	response, err := client.Embed(ctx, &req)
+
+	if err != nil {
+		return nil, err
+	}
+
+	return response, nil
+}

+ 23 - 16
llm/ext_server/server.cpp

@@ -3188,26 +3188,33 @@ int main(int argc, char **argv) {
                     prompt = "";
                 }
 
-                json image_data;
-                if (body.count("image_data") != 0) {
-                    image_data = body["image_data"];
-                }
-                else
-                {
-                    image_data = "";
+                if (prompt.size() == 1) {
+                    prompt = prompt[0];
                 }
 
                 // create and queue the task
-                const int task_id = llama.queue_tasks.get_new_id();
-                llama.queue_results.add_waiting_task_id(task_id);
-                llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, true, -1);
-
-                // get the result
-                task_result result = llama.queue_results.recv(task_id);
-                llama.queue_results.remove_waiting_task_id(task_id);
+                json responses;
+                {
+                    const int id_task = llama.queue_tasks.get_new_id();
+                    llama.queue_results.add_waiting_task_id(id_task);
+                    llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
+
+                    // get the result
+                    task_result result = llama.queue_results.recv(id_task);
+                    llama.queue_results.remove_waiting_task_id(id_task);
+                    if (result.error) {
+                        return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
+                    }
 
-                // send the result
-                return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
+                    responses = result.result_json.value("results", std::vector<json>{result.result_json});
+                    json embeddings = json::array();
+                    for (auto & elem : responses) {
+                        embeddings.push_back(elem.at("embedding"));
+                    }
+                    // send the result
+                    json embedding_res = json{{"embedding", embeddings}};
+                    return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
+                }
             });
 
     // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?

+ 8 - 8
llm/server.go

@@ -33,7 +33,7 @@ type LlamaServer interface {
 	Ping(ctx context.Context) error
 	WaitUntilRunning(ctx context.Context) error
 	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
-	Embedding(ctx context.Context, prompt string) ([]float64, error)
+	Embed(ctx context.Context, input []string) ([][]float32, error)
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Close() error
@@ -867,15 +867,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	return nil
 }
 
-type EmbeddingRequest struct {
-	Content string `json:"content"`
+type EmbedRequest struct {
+	Content []string `json:"content"`
 }
 
-type EmbeddingResponse struct {
-	Embedding []float64 `json:"embedding"`
+type EmbedResponse struct {
+	Embedding [][]float32 `json:"embedding"`
 }
 
-func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) {
 	if err := s.sem.Acquire(ctx, 1); err != nil {
 		slog.Error("Failed to acquire semaphore", "error", err)
 		return nil, err
@@ -890,7 +890,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
 		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
-	data, err := json.Marshal(TokenizeRequest{Content: prompt})
+	data, err := json.Marshal(EmbedRequest{Content: input})
 	if err != nil {
 		return nil, fmt.Errorf("error marshaling embed data: %w", err)
 	}
@@ -917,7 +917,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
 		return nil, fmt.Errorf("%s", body)
 	}
 
-	var embedding EmbeddingResponse
+	var embedding EmbedResponse
 	if err := json.Unmarshal(body, &embedding); err != nil {
 		return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
 	}

+ 129 - 2
server/routes.go

@@ -9,6 +9,7 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
+	"math"
 	"net"
 	"net/http"
 	"net/netip"
@@ -271,6 +272,121 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
+func (s *Server) EmbedHandler(c *gin.Context) {
+	var req api.EmbedRequest
+	err := c.ShouldBindJSON(&req)
+	switch {
+	case errors.Is(err, io.EOF):
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	case err != nil:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	truncate := true
+
+	if req.Truncate != nil && !*req.Truncate {
+		truncate = false
+	}
+
+	var input []string
+
+	switch i := req.Input.(type) {
+	case string:
+		if len(i) > 0 {
+			input = append(input, i)
+		}
+	case []any:
+		for _, v := range i {
+			if _, ok := v.(string); !ok {
+				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
+				return
+			}
+			input = append(input, v.(string))
+		}
+	default:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
+		return
+	}
+
+	if len(input) == 0 {
+		c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
+		return
+	}
+
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
+	if err != nil {
+		handleScheduleError(c, req.Model, err)
+		return
+	}
+
+	kvData, err := getKVData(m.ModelPath, false)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	for i, s := range input {
+		tokens, err := r.Tokenize(c.Request.Context(), s)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
+		if len(tokens) > ctxLen {
+			if !truncate {
+				c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
+				return
+			}
+
+			tokens = tokens[:ctxLen]
+			s, err = r.Detokenize(c.Request.Context(), tokens)
+			if err != nil {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return
+			}
+		}
+
+		input[i] = s
+	}
+	embeddings, err := r.Embed(c.Request.Context(), input)
+
+	if err != nil {
+		slog.Error("embedding generation failed", "error", err)
+		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
+		return
+	}
+
+	for i, e := range embeddings {
+		embeddings[i] = normalize(e)
+	}
+
+	resp := api.EmbedResponse{
+		Model:      req.Model,
+		Embeddings: embeddings,
+	}
+	c.JSON(http.StatusOK, resp)
+}
+
+func normalize(vec []float32) []float32 {
+	var sum float32
+	for _, v := range vec {
+		sum += v * v
+	}
+
+	norm := float32(0.0)
+	if sum > 0 {
+		norm = float32(1.0 / math.Sqrt(float64(sum)))
+	}
+
+	for i := range vec {
+		vec[i] *= norm
+	}
+	return vec
+}
+
 func (s *Server) EmbeddingsHandler(c *gin.Context) {
 	var req api.EmbeddingRequest
 	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
@@ -293,14 +409,24 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
+	embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
+
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
 		return
 	}
 
-	c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding})
+	embedding := make([]float64, len(embeddings[0]))
+
+	for i, v := range embeddings[0] {
+		embedding[i] = float64(v)
+	}
+
+	resp := api.EmbeddingResponse{
+		Embedding: embedding,
+	}
+	c.JSON(http.StatusOK, resp)
 }
 
 func (s *Server) PullModelHandler(c *gin.Context) {
@@ -919,6 +1045,7 @@ func (s *Server) GenerateRoutes() http.Handler {
 	r.POST("/api/pull", s.PullModelHandler)
 	r.POST("/api/generate", s.GenerateHandler)
 	r.POST("/api/chat", s.ChatHandler)
+	r.POST("/api/embed", s.EmbedHandler)
 	r.POST("/api/embeddings", s.EmbeddingsHandler)
 	r.POST("/api/create", s.CreateModelHandler)
 	r.POST("/api/push", s.PushModelHandler)

+ 103 - 0
server/routes_test.go

@@ -7,6 +7,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"math"
 	"net/http"
 	"net/http/httptest"
 	"os"
@@ -272,6 +273,73 @@ func Test_Routes(t *testing.T) {
 				assert.Equal(t, "library", retrieveResp.OwnedBy)
 			},
 		},
+		{
+			Name:   "Embed Handler Empty Input",
+			Method: http.MethodPost,
+			Path:   "/api/embed",
+			Setup: func(t *testing.T, req *http.Request) {
+				embedReq := api.EmbedRequest{
+					Model: "t-bone",
+					Input: "",
+				}
+				jsonData, err := json.Marshal(embedReq)
+				require.NoError(t, err)
+				req.Body = io.NopCloser(bytes.NewReader(jsonData))
+			},
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				if contentType != "application/json; charset=utf-8" {
+					t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
+				}
+				body, err := io.ReadAll(resp.Body)
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				var embedResp api.EmbedResponse
+				err = json.Unmarshal(body, &embedResp)
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				if embedResp.Model != "t-bone" {
+					t.Fatalf("expected model t-bone, got %s", embedResp.Model)
+				}
+
+				if embedResp.Embeddings != nil {
+					t.Fatalf("expected embeddings to be nil, got %v", embedResp.Embeddings)
+				}
+			},
+		},
+		{
+			Name:   "Embed Handler Invalid Input",
+			Method: http.MethodPost,
+			Path:   "/api/embed",
+			Setup: func(t *testing.T, req *http.Request) {
+				embedReq := api.EmbedRequest{
+					Model: "t-bone",
+					Input: 2,
+				}
+				jsonData, err := json.Marshal(embedReq)
+				require.NoError(t, err)
+				req.Body = io.NopCloser(bytes.NewReader(jsonData))
+			},
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				if contentType != "application/json; charset=utf-8" {
+					t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
+				}
+				_, err := io.ReadAll(resp.Body)
+
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				if resp.StatusCode != http.StatusBadRequest {
+					t.Fatalf("expected status code 400, got %d", resp.StatusCode)
+				}
+			},
+		},
 	}
 
 	t.Setenv("OLLAMA_MODELS", t.TempDir())
@@ -420,3 +488,38 @@ func TestShow(t *testing.T) {
 		t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"])
 	}
 }
+
+func TestNormalize(t *testing.T) {
+	type testCase struct {
+		input []float32
+	}
+
+	testCases := []testCase{
+		{input: []float32{1}},
+		{input: []float32{0, 1, 2, 3}},
+		{input: []float32{0.1, 0.2, 0.3}},
+		{input: []float32{-0.1, 0.2, 0.3, -0.4}},
+		{input: []float32{0, 0, 0}},
+	}
+
+	isNormalized := func(vec []float32) (res bool) {
+		sum := 0.0
+		for _, v := range vec {
+			sum += float64(v * v)
+		}
+		if math.Abs(sum-1) > 1e-6 {
+			return sum == 0
+		} else {
+			return true
+		}
+	}
+
+	for _, tc := range testCases {
+		t.Run("", func(t *testing.T) {
+			normalized := normalize(tc.input)
+			if !isNormalized(normalized) {
+				t.Errorf("Vector %v is not normalized", tc.input)
+			}
+		})
+	}
+}

+ 4 - 4
server/sched_test.go

@@ -642,8 +642,8 @@ type mockLlm struct {
 	pingResp           error
 	waitResp           error
 	completionResp     error
-	embeddingResp      []float64
-	embeddingRespErr   error
+	embedResp          [][]float32
+	embedRespErr       error
 	tokenizeResp       []int
 	tokenizeRespErr    error
 	detokenizeResp     string
@@ -660,8 +660,8 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
 func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
 	return s.completionResp
 }
-func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
-	return s.embeddingResp, s.embeddingRespErr
+func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) {
+	return s.embedResp, s.embedRespErr
 }
 func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
 	return s.tokenizeResp, s.tokenizeRespErr