Explorar o código

normalization

Roy Han hai 10 meses
pai
achega
c111d8bb51
Modificáronse 4 ficheiros con 63 adicións e 11 borrados
  1. 10 2
      llm/ext_server/server.cpp
  2. 17 0
      llm/ext_server/utils.hpp
  3. 8 6
      llm/server.go
  4. 28 3
      server/routes.go

+ 10 - 2
llm/ext_server/server.cpp

@@ -3185,8 +3185,16 @@ int main(int argc, char **argv) {
                             responses = std::vector<json>(1, result.result_json);
                         }
                         json embeddings = json::array();
-                        for (auto & elem : responses) {
-                            embeddings.push_back(json_value(elem, "embedding", json::array()));
+                        if (body["normalize"]) {
+                            for (auto & elem : responses) {
+                                std::vector<float> embedding = elem.at("embedding").get<std::vector<float>>();
+                                embedding = normalize_vector(embedding, embedding.size());
+                                embeddings.push_back(embedding);
+                            }
+                        } else {
+                            for (auto & elem : responses) {
+                                embeddings.push_back(elem.at("embedding"));
+                            }
                         }
                         // send the result
                         json result = json{{"embedding", embeddings}};

+ 17 - 0
llm/ext_server/utils.hpp

@@ -656,3 +656,20 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector<com
     }
     return out;
 }
+
+// scale vec to unit L2 (Euclidean) norm; an all-zero vec yields zeros. size must not exceed vec.size()
+std::vector<float> normalize_vector(const std::vector<float>& vec, int size) {
+    double sum = 0.0;
+    for (float value : vec) {
+        sum += value * value;
+    }
+    sum = std::sqrt(sum);
+
+    const float norm = sum > 0.0 ? 1.0f / sum : 0.0f;
+
+    std::vector<float> normalized_vec(size);
+    for (int i = 0; i < size; i++) {
+        normalized_vec[i] = vec[i] * norm;
+    }
+    return normalized_vec;
+}

+ 8 - 6
llm/server.go

@@ -843,7 +843,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 }
 
 type EmbedRequest struct {
-	Content []string `json:"content"`
+	Content   []string `json:"content"`
+	Normalize bool     `json:"normalize"`
 }
 
 type EmbedResponse struct {
@@ -865,7 +866,7 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float64, err
 		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
-	data, err := json.Marshal(EmbedRequest{Content: input})
+	data, err := json.Marshal(EmbedRequest{Content: input, Normalize: true})
 	if err != nil {
 		return nil, fmt.Errorf("error marshaling embed data: %w", err)
 	}
@@ -901,11 +902,12 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float64, err
 }
 
 type EmbeddingRequest struct {
-	Content string `json:"content"`
+	Content   string `json:"content"`
+	Normalize bool   `json:"normalize"`
 }
 
 type EmbeddingResponse struct {
-	Embedding []float64 `json:"embedding"`
+	Embedding [][]float64 `json:"embedding"`
 }
 
 func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
@@ -923,7 +925,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
 		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
-	data, err := json.Marshal(TokenizeRequest{Content: prompt})
+	data, err := json.Marshal(EmbeddingRequest{Content: prompt, Normalize: false})
 	if err != nil {
 		return nil, fmt.Errorf("error marshaling embed data: %w", err)
 	}
@@ -955,7 +957,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
 		return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
 	}
 
-	return embedding.Embedding, nil
+	return embedding.Embedding[0], nil
 }
 
 type TokenizeRequest struct {

+ 28 - 3
server/routes.go

@@ -398,12 +398,22 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 			return
 		}
 		embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
-	case []string:
+	case []any:
 		if reqEmbed == nil {
 			c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}})
 			return
 		}
-		embeddings, err = runner.llama.Embed(c.Request.Context(), reqEmbed)
+
+		reqEmbedArray := make([]string, len(reqEmbed))
+		for i, v := range reqEmbed {
+			if s, ok := v.(string); ok {
+				reqEmbedArray[i] = s
+			} else {
+				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
+				return
+			}
+		}
+		embeddings, err = runner.llama.Embed(c.Request.Context(), reqEmbedArray)
 	default:
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
 	}
@@ -414,6 +424,19 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 	}
 
+	// log whether each embedding is normalized (sum of squares within 1e-6 of 1); nothing is enforced
+	for _, e := range embeddings {
+		sum := 0.0
+		for _, v := range e {
+			sum += v * v
+		}
+		if math.Abs(sum-1) > 1e-6 {
+			slog.Info("embedding is not normalized", "sum", sum)
+		} else {
+			slog.Info("embedding is normalized", "sum", sum)
+		}
+	}
+
 	resp := api.EmbedResponse{Embeddings: embeddings}
 	c.JSON(http.StatusOK, resp)
 }
@@ -486,7 +509,9 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 	for _, v := range embedding {
 		sum += v * v
 	}
-	if math.Abs(sum-1) > 1e-6 {
+	if math.Abs(sum-1) < 1e-6 {
+		slog.Info("embedding is normalized", "sum", sum)
+	} else {
 		slog.Info("embedding is not normalized", "sum", sum)
 	}