Roy Han 9 月之前
父節點
當前提交
bcb63e6e0e
共有 4 個文件被更改,包括 24 次插入和 29 次删除
  1. 19 4
      integration/embed_test.go
  2. 0 14
      integration/utils_test.go
  3. 1 7
      llm/ext_server/server.cpp
  4. 4 4
      server/routes.go

+ 19 - 4
integration/embed_test.go

@@ -8,6 +8,7 @@ import (
 	"time"
 
 	"github.com/ollama/ollama/api"
+	"github.com/stretchr/testify/require"
 )
 
 func TestAllMiniLMEmbed(t *testing.T) {
@@ -19,7 +20,7 @@ func TestAllMiniLMEmbed(t *testing.T) {
 		Input: "why is the sky blue?",
 	}
 
-	res, err := EmbedTestHelper(ctx, t, req)
+	res, err := embedTestHelper(ctx, t, req)
 
 	if err != nil {
 		t.Fatalf("error: %v", err)
@@ -47,7 +48,7 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
 		Input: []string{"why is the sky blue?", "why is the grass green?"},
 	}
 
-	res, err := EmbedTestHelper(ctx, t, req)
+	res, err := embedTestHelper(ctx, t, req)
 
 	if err != nil {
 		t.Fatalf("error: %v", err)
@@ -107,7 +108,7 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
 	res := make(map[string]*api.EmbedResponse)
 
 	for _, req := range reqs {
-		response, err := EmbedTestHelper(ctx, t, req.Request)
+		response, err := embedTestHelper(ctx, t, req.Request)
 		if err != nil {
 			t.Fatalf("error: %v", err)
 		}
@@ -123,7 +124,7 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
 	}
 
 	// check that truncate set to false returns an error if context length is exceeded
-	_, err := EmbedTestHelper(ctx, t, api.EmbedRequest{
+	_, err := embedTestHelper(ctx, t, api.EmbedRequest{
 		Model:    "all-minilm",
 		Input:    "why is the sky blue?",
 		Truncate: &truncFalse,
@@ -134,3 +135,17 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
 		t.Fatalf("expected error, got nil")
 	}
 }
+
+func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+	require.NoError(t, PullIfMissing(ctx, client, req.Model))
+
+	response, err := client.Embed(ctx, &req)
+
+	if err != nil {
+		return nil, err
+	}
+
+	return response, nil
+}

+ 0 - 14
integration/utils_test.go

@@ -341,17 +341,3 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
 			[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
 		}
 }
-
-func EmbedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
-	client, _, cleanup := InitServerConnection(ctx, t)
-	defer cleanup()
-	require.NoError(t, PullIfMissing(ctx, client, req.Model))
-
-	response, err := client.Embed(ctx, &req)
-
-	if err != nil {
-		return nil, err
-	}
-
-	return response, nil
-}

+ 1 - 7
llm/ext_server/server.cpp

@@ -3199,13 +3199,7 @@ int main(int argc, char **argv) {
                     task_result result = llama.queue_results.recv(id_task);
                     llama.queue_results.remove_waiting_task_id(id_task);
                     if (!result.error) {
-                        if (result.result_json.count("results")) {
-                            // result for multi-task
-                            responses = result.result_json.at("results");
-                        } else {
-                            // result for single task
-                            responses = std::vector<json>(1, result.result_json);
-                        }
+                        responses = result.result_json.value("results", std::vector<json>{result.result_json});
                         json embeddings = json::array();
                         for (auto & elem : responses) {
                             embeddings.push_back(elem.at("embedding"));

+ 4 - 4
server/routes.go

@@ -9,7 +9,6 @@ import (
 	"io"
 	"io/fs"
 	"log/slog"
-	"math"
 	"net"
 	"net/http"
 	"net/netip"
@@ -21,6 +20,7 @@ import (
 	"syscall"
 	"time"
 
+	"github.com/chewxy/math32"
 	"github.com/gin-contrib/cors"
 	"github.com/gin-gonic/gin"
 
@@ -443,14 +443,14 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 }
 
 func normalize(vec []float32) []float32 {
-	var sum float64
+	var sum float32
 	for _, v := range vec {
-		sum += float64(v * v)
+		sum += v * v
 	}
 
 	norm := float32(0.0)
 	if sum > 0 {
-		norm = float32(1.0 / math.Sqrt(sum))
+		norm = float32(1.0 / math32.Sqrt(sum))
 	}
 
 	for i := range vec {