Browse Source

Truncation Integration Tests

Roy Han 10 tháng trước
mục cha
commit
1a0c8b363c
3 tập tin đã thay đổi với 105 bổ sung và 20 xóa
  1. 87 2
      integration/embed_test.go
  2. 3 3
      integration/utils_test.go
  3. 15 15
      server/routes.go

+ 87 - 2
integration/embed_test.go

@@ -19,7 +19,11 @@ func TestAllMiniLMEmbed(t *testing.T) {
 		Input: "why is the sky blue?",
 	}
 
-	res := EmbedTestHelper(ctx, t, req)
+	res, err := EmbedTestHelper(ctx, t, req)
+
+	if err != nil {
+		t.Fatalf("error: %v", err)
+	}
 
 	if len(res.Embeddings) != 1 {
 		t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
@@ -28,6 +32,10 @@ func TestAllMiniLMEmbed(t *testing.T) {
 	if len(res.Embeddings[0]) != 384 {
 		t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
 	}
+
+	if res.Embeddings[0][0] != 0.010071029038540258 {
+		t.Fatalf("expected 0.010071029038540258, got %f", res.Embeddings[0][0])
+	}
 }
 
 func TestAllMiniLMBatchEmbed(t *testing.T) {
@@ -39,7 +47,11 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
 		Input: []string{"why is the sky blue?", "why is the grass green?"},
 	}
 
-	res := EmbedTestHelper(ctx, t, req)
+	res, err := EmbedTestHelper(ctx, t, req)
+
+	if err != nil {
+		t.Fatalf("error: %v", err)
+	}
 
 	if len(res.Embeddings) != 2 {
 		t.Fatalf("expected 2 embeddings, got %d", len(res.Embeddings))
@@ -48,4 +60,77 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
 	if len(res.Embeddings[0]) != 384 {
 		t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
 	}
+
+	if res.Embeddings[0][0] != 0.010071029038540258 || res.Embeddings[1][0] != -0.00980270794235093 {
+		t.Fatalf("expected 0.010071029038540258 and -0.00980270794235093, got %f and %f", res.Embeddings[0][0], res.Embeddings[1][0])
+	}
+}
+
+func TestAllMiniLmEmbedTruncate(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+	defer cancel()
+
+	truncTrue, truncFalse := true, false
+
+	type testReq struct {
+		Name    string
+		Request api.EmbedRequest
+	}
+
+	reqs := []testReq{
+		{
+			Name: "Target Truncation",
+			Request: api.EmbedRequest{
+				Model: "all-minilm",
+				Input: "why",
+			},
+		},
+		{
+			Name: "Default Truncate",
+			Request: api.EmbedRequest{
+				Model:   "all-minilm",
+				Input:   "why is the sky blue?",
+				Options: map[string]any{"num_ctx": 1},
+			},
+		},
+		{
+			Name: "Explicit Truncate",
+			Request: api.EmbedRequest{
+				Model:    "all-minilm",
+				Input:    "why is the sky blue?",
+				Truncate: &truncTrue,
+				Options:  map[string]any{"num_ctx": 1},
+			},
+		},
+	}
+
+	res := make(map[string]*api.EmbedResponse)
+
+	for _, req := range reqs {
+		response, err := EmbedTestHelper(ctx, t, req.Request)
+		if err != nil {
+			t.Fatalf("error: %v", err)
+		}
+		res[req.Name] = response
+	}
+
+	if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
+		t.Fatalf("expected default request to truncate correctly")
+	}
+
+	if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
+		t.Fatalf("expected default request and truncate true request to be the same")
+	}
+
+	// check that truncate set to false returns an error if context length is exceeded
+	_, err := EmbedTestHelper(ctx, t, api.EmbedRequest{
+		Model:    "all-minilm",
+		Input:    "why is the sky blue?",
+		Truncate: &truncFalse,
+		Options:  map[string]any{"num_ctx": 1},
+	})
+
+	if err == nil {
+		t.Fatalf("expected error, got nil")
+	}
 }

+ 3 - 3
integration/utils_test.go

@@ -342,7 +342,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
 		}
 }
 
-func EmbedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) *api.EmbedResponse {
+func EmbedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
 	client, _, cleanup := InitServerConnection(ctx, t)
 	defer cleanup()
 	require.NoError(t, PullIfMissing(ctx, client, req.Model))
@@ -350,8 +350,8 @@ func EmbedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) *a
 	response, err := client.Embed(ctx, &req)
 
 	if err != nil {
-		t.Fatalf("Error making request: %v", err)
+		return nil, err
 	}
 
-	return response
+	return response, nil
 }

+ 15 - 15
server/routes.go

@@ -395,7 +395,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 	}
 
-	truncate := func(s string) (string, error) {
+	checkFit := func(s string, truncate bool) (string, error) {
 		tokens, err := runner.llama.Tokenize(c.Request.Context(), s)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -403,8 +403,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		}
 
 		if len(tokens) > opts.NumCtx {
-			tokens = tokens[:opts.NumCtx]
-			return runner.llama.Detokenize(c.Request.Context(), tokens)
+			if truncate {
+				tokens = tokens[:opts.NumCtx]
+				return runner.llama.Detokenize(c.Request.Context(), tokens)
+			} else {
+				return "", fmt.Errorf("input length exceeds maximum context length")
+			}
 		}
 
 		return s, nil
@@ -418,12 +422,10 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 			c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}})
 			return
 		}
-		if *req.Truncate {
-			reqEmbed, err = truncate(reqEmbed)
-			if err != nil {
-				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-				return
-			}
+		reqEmbed, err = checkFit(reqEmbed, *req.Truncate)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
 		}
 		embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
 	case []any:
@@ -435,12 +437,10 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		reqEmbedArray := make([]string, len(reqEmbed))
 		for i, v := range reqEmbed {
 			if s, ok := v.(string); ok {
-				if *req.Truncate {
-					s, err = truncate(s)
-					if err != nil {
-						c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-						return
-					}
+				s, err = checkFit(s, *req.Truncate)
+				if err != nil {
+					c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+					return
 				}
 				reqEmbedArray[i] = s
 			} else {