Browse Source

Add Metrics to `api\embed` response (#5709)

* add prompt tokens to embed response

* rm slog

* metrics

* types

* prompt n

* clean up

* reset submodule

* update tests

* test name

* list metrics
royjhan 9 months ago
parent
commit
1b44d873e7
6 changed files with 39 additions and 15 deletions
  1. 4 0
      api/types.go
  2. 8 0
      integration/embed_test.go
  3. 6 1
      llm/ext_server/server.cpp
  4. 7 6
      llm/server.go
  5. 12 6
      server/routes.go
  6. 2 2
      server/sched_test.go

+ 4 - 0
api/types.go

@@ -267,6 +267,10 @@ type EmbedRequest struct {
 type EmbedResponse struct {
 	Model      string      `json:"model"`
 	Embeddings [][]float32 `json:"embeddings"`
+
+	TotalDuration   time.Duration `json:"total_duration,omitempty"`
+	LoadDuration    time.Duration `json:"load_duration,omitempty"`
+	PromptEvalCount int           `json:"prompt_eval_count,omitempty"`
 }
 
 // EmbeddingRequest is the request passed to [Client.Embeddings].

+ 8 - 0
integration/embed_test.go

@@ -69,6 +69,10 @@ func TestAllMiniLMEmbed(t *testing.T) {
 	if !floatsEqual32(res.Embeddings[0][0], 0.010071031) {
 		t.Fatalf("expected 0.010071031, got %.8f", res.Embeddings[0][0])
 	}
+
+	if res.PromptEvalCount != 8 {
+		t.Fatalf("expected 8 prompt tokens, got %d", res.PromptEvalCount)
+	}
 }
 
 func TestAllMiniLMBatchEmbed(t *testing.T) {
@@ -97,6 +101,10 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
 	if !floatsEqual32(res.Embeddings[0][0], 0.010071031) || !floatsEqual32(res.Embeddings[1][0], -0.009802706) {
 		t.Fatalf("expected 0.010071031 and -0.009802706, got %.8f and %.8f", res.Embeddings[0][0], res.Embeddings[1][0])
 	}
+
+	if res.PromptEvalCount != 16 {
+		t.Fatalf("expected 16 prompt tokens, got %d", res.PromptEvalCount)
+	}
 }
 
 func TestAllMiniLMEmbedTruncate(t *testing.T) {

+ 6 - 1
llm/ext_server/server.cpp

@@ -1221,6 +1221,7 @@ struct llama_server_context
                 res.result_json = json
                 {
                     {"embedding", std::vector<float>(embd, embd + n_embd)},
+                    {"timings",             slot.get_formated_timings()},
                 };
             }
         }
@@ -3203,11 +3204,15 @@ int main(int argc, char **argv) {
 
                     responses = result.result_json.value("results", std::vector<json>{result.result_json});
                     json embeddings = json::array();
+
+                    int prompt_n = 0;
                     for (auto & elem : responses) {
                         embeddings.push_back(elem.at("embedding"));
+                        prompt_n += elem.at("timings").at("prompt_n").get<int>();
                     }
+
                     // send the result
-                    json embedding_res = json{{"embedding", embeddings}};
+                    json embedding_res = json{{"embedding", embeddings}, {"prompt_n", prompt_n}};
                     return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
                 }
             });

+ 7 - 6
llm/server.go

@@ -33,7 +33,7 @@ type LlamaServer interface {
 	Ping(ctx context.Context) error
 	WaitUntilRunning(ctx context.Context) error
 	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
-	Embed(ctx context.Context, input []string) ([][]float32, error)
+	Embed(ctx context.Context, input []string) (*EmbedResponse, error)
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Close() error
@@ -879,10 +879,11 @@ type EmbedRequest struct {
 }
 
 type EmbedResponse struct {
-	Embedding [][]float32 `json:"embedding"`
+	Embedding       [][]float32 `json:"embedding"`
+	PromptEvalCount int         `json:"prompt_n"`
 }
 
-func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) {
+func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, error) {
 	if err := s.sem.Acquire(ctx, 1); err != nil {
 		slog.Error("Failed to acquire semaphore", "error", err)
 		return nil, err
@@ -924,12 +925,12 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, err
 		return nil, fmt.Errorf("%s", body)
 	}
 
-	var embedding EmbedResponse
-	if err := json.Unmarshal(body, &embedding); err != nil {
+	var e EmbedResponse
+	if err := json.Unmarshal(body, &e); err != nil {
 		return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
 	}
 
-	return embedding.Embedding, nil
+	return &e, nil
 }
 
 type TokenizeRequest struct {

+ 12 - 6
server/routes.go

@@ -284,6 +284,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 }
 
 func (s *Server) EmbedHandler(c *gin.Context) {
+	checkpointStart := time.Now()
 	var req api.EmbedRequest
 	err := c.ShouldBindJSON(&req)
 	switch {
@@ -332,6 +333,8 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 	}
 
+	checkpointLoaded := time.Now()
+
 	kvData, err := getKVData(m.ModelPath, false)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -370,13 +373,16 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 	}
 
-	for i, e := range embeddings {
-		embeddings[i] = normalize(e)
+	for i, e := range embeddings.Embedding {
+		embeddings.Embedding[i] = normalize(e)
 	}
 
 	resp := api.EmbedResponse{
-		Model:      req.Model,
-		Embeddings: embeddings,
+		Model:           req.Model,
+		Embeddings:      embeddings.Embedding,
+		TotalDuration:   time.Since(checkpointStart),
+		LoadDuration:    checkpointLoaded.Sub(checkpointStart),
+		PromptEvalCount: embeddings.PromptEvalCount,
 	}
 	c.JSON(http.StatusOK, resp)
 }
@@ -428,9 +434,9 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	embedding := make([]float64, len(embeddings[0]))
+	embedding := make([]float64, len(embeddings.Embedding[0]))
 
-	for i, v := range embeddings[0] {
+	for i, v := range embeddings.Embedding[0] {
 		embedding[i] = float64(v)
 	}
 

+ 2 - 2
server/sched_test.go

@@ -709,7 +709,7 @@ type mockLlm struct {
 	pingResp           error
 	waitResp           error
 	completionResp     error
-	embedResp          [][]float32
+	embedResp          *llm.EmbedResponse
 	embedRespErr       error
 	tokenizeResp       []int
 	tokenizeRespErr    error
@@ -727,7 +727,7 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
 func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
 	return s.completionResp
 }
-func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) {
+func (s *mockLlm) Embed(ctx context.Context, input []string) (*llm.EmbedResponse, error) {
 	return s.embedResp, s.embedRespErr
 }
 func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {