Browse Source

server: parallelize embeddings in API web handler instead of in subprocess runner (#6220)

For simplicity, perform parallelization of embedding requests in the API handler instead of offloading this to the subprocess runner. This keeps the scheduling story simpler as it builds on existing parallel requests, similar to existing text completion functionality.
Jeffrey Morgan 8 tháng trước cách đây
mục cha
commit
15c2d8fe14
4 tập tin đã thay đổi với 53 bổ sung71 xóa
  1. 8 34
      llm/ext_server/server.cpp
  2. 14 18
      llm/server.go
  3. 27 15
      server/routes.go
  4. 4 4
      server/sched_test.go

+ 8 - 34
llm/ext_server/server.cpp

@@ -1223,9 +1223,7 @@ struct llama_server_context
 
                 res.result_json = json
                 {
-                    {"id", res.id},
                     {"embedding", std::vector<float>(embd, embd + n_embd)},
-                    {"timings",             slot.get_formated_timings()},
                 };
             }
         }
@@ -3194,41 +3192,17 @@ int main(int argc, char **argv) {
                     prompt = "";
                 }
 
-                if (prompt.size() == 1) {
-                    prompt = prompt[0];
-                }
-
                 // create and queue the task
-                json responses;
-                {
-                    const int id_task = llama.queue_tasks.get_new_id();
-                    llama.queue_results.add_waiting_task_id(id_task);
-                    llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
-
-                    // get the result
-                    task_result result = llama.queue_results.recv(id_task);
-                    llama.queue_results.remove_waiting_task_id(id_task);
-                    if (result.error) {
-                        return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
-                    }
-
-                    responses = result.result_json.value("results", std::vector<json>{result.result_json});
-                    std::sort(responses.begin(), responses.end(), [](const json& a, const json& b) {
-                        return a["id"] < b["id"];
-                    });
-
-                    json embeddings = json::array();
+                const int task_id = llama.queue_tasks.get_new_id();
+                llama.queue_results.add_waiting_task_id(task_id);
+                llama.request_completion(task_id, {{"prompt", prompt}}, true, -1);
 
-                    int prompt_n = 0;
-                    for (auto & elem : responses) {
-                        embeddings.push_back(elem.at("embedding"));
-                        prompt_n += elem.at("timings").at("prompt_n").get<int>();
-                    }
+                // get the result
+                task_result result = llama.queue_results.recv(task_id);
+                llama.queue_results.remove_waiting_task_id(task_id);
 
-                    // send the result
-                    json embedding_res = json{{"embedding", embeddings}, {"prompt_n", prompt_n}};
-                    return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
-                }
+                // send the result
+                return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
             });
 
     // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?

+ 14 - 18
llm/server.go

@@ -33,7 +33,7 @@ type LlamaServer interface {
 	Ping(ctx context.Context) error
 	WaitUntilRunning(ctx context.Context) error
 	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
-	Embed(ctx context.Context, input []string) (*EmbedResponse, error)
+	Embedding(ctx context.Context, input string) ([]float32, error)
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Close() error
@@ -883,24 +883,20 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	return nil
 }
 
-type EmbedRequest struct {
-	Content []string `json:"content"`
+type EmbeddingRequest struct {
+	Content string `json:"content"`
 }
 
-type EmbedResponse struct {
-	Embedding       [][]float32 `json:"embedding"`
-	PromptEvalCount int         `json:"prompt_n"`
+type EmbeddingResponse struct {
+	Embedding []float32 `json:"embedding"`
 }
 
-func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, error) {
-	// each input will use a slot, so we need to acquire the semaphore for
-	// the number of inputs up to numParallel
-	slots := int64(min(len(input), s.numParallel))
-	if err := s.sem.Acquire(ctx, slots); err != nil {
+func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) {
+	if err := s.sem.Acquire(ctx, 1); err != nil {
 		slog.Error("Failed to acquire semaphore", "error", err)
 		return nil, err
 	}
-	defer s.sem.Release(slots)
+	defer s.sem.Release(1)
 
 	// Make sure the server is ready
 	status, err := s.getServerStatusRetry(ctx)
@@ -910,18 +906,18 @@ func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse,
 		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
-	data, err := json.Marshal(EmbedRequest{Content: input})
+	data, err := json.Marshal(EmbeddingRequest{Content: input})
 	if err != nil {
 		return nil, fmt.Errorf("error marshaling embed data: %w", err)
 	}
 
-	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
+	r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
 	if err != nil {
 		return nil, fmt.Errorf("error creating embed request: %w", err)
 	}
-	req.Header.Set("Content-Type", "application/json")
+	r.Header.Set("Content-Type", "application/json")
 
-	resp, err := http.DefaultClient.Do(req)
+	resp, err := http.DefaultClient.Do(r)
 	if err != nil {
 		return nil, fmt.Errorf("do embedding request: %w", err)
 	}
@@ -937,12 +933,12 @@ func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse,
 		return nil, fmt.Errorf("%s", body)
 	}
 
-	var e EmbedResponse
+	var e EmbeddingResponse
 	if err := json.Unmarshal(body, &e); err != nil {
 		return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
 	}
 
-	return &e, nil
+	return e.Embedding, nil
 }
 
 type TokenizeRequest struct {

+ 27 - 15
server/routes.go

@@ -23,6 +23,7 @@ import (
 
 	"github.com/gin-contrib/cors"
 	"github.com/gin-gonic/gin"
+	"golang.org/x/sync/errgroup"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
@@ -346,6 +347,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 	}
 
+	var count int
 	for i, s := range input {
 		tokens, err := r.Tokenize(c.Request.Context(), s)
 		if err != nil {
@@ -368,25 +370,36 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 			}
 		}
 
+		count += len(tokens)
+
 		input[i] = s
 	}
-	embeddings, err := r.Embed(c.Request.Context(), input)
-	if err != nil {
-		slog.Error("embedding generation failed", "error", err)
-		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
-		return
+
+	var g errgroup.Group
+	embeddings := make([][]float32, len(input))
+	for i, text := range input {
+		g.Go(func() error {
+			embedding, err := r.Embedding(c.Request.Context(), text)
+			if err != nil {
+				return err
+			}
+			embeddings[i] = normalize(embedding)
+			return nil
+		})
 	}
 
-	for i, e := range embeddings.Embedding {
-		embeddings.Embedding[i] = normalize(e)
+	if err := g.Wait(); err != nil {
+		slog.Error("embedding generation failed", "error", err)
+		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embeddings: %v", err)})
+		return
 	}
 
 	resp := api.EmbedResponse{
 		Model:           req.Model,
-		Embeddings:      embeddings.Embedding,
+		Embeddings:      embeddings,
 		TotalDuration:   time.Since(checkpointStart),
 		LoadDuration:    checkpointLoaded.Sub(checkpointStart),
-		PromptEvalCount: embeddings.PromptEvalCount,
+		PromptEvalCount: count,
 	}
 	c.JSON(http.StatusOK, resp)
 }
@@ -430,21 +443,20 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
+	embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
 		return
 	}
 
-	embedding := make([]float64, len(embeddings.Embedding[0]))
-
-	for i, v := range embeddings.Embedding[0] {
-		embedding[i] = float64(v)
+	var e []float64
+	for _, v := range embedding {
+		e = append(e, float64(v))
 	}
 
 	resp := api.EmbeddingResponse{
-		Embedding: embedding,
+		Embedding: e,
 	}
 	c.JSON(http.StatusOK, resp)
 }

+ 4 - 4
server/sched_test.go

@@ -708,8 +708,8 @@ type mockLlm struct {
 	pingResp           error
 	waitResp           error
 	completionResp     error
-	embedResp          *llm.EmbedResponse
-	embedRespErr       error
+	embeddingResp      []float32
+	embeddingRespErr   error
 	tokenizeResp       []int
 	tokenizeRespErr    error
 	detokenizeResp     string
@@ -727,8 +727,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
 	return s.completionResp
 }
 
-func (s *mockLlm) Embed(ctx context.Context, input []string) (*llm.EmbedResponse, error) {
-	return s.embedResp, s.embedRespErr
+func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) {
+	return s.embeddingResp, s.embeddingRespErr
 }
 
 func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {