Browse Source

Fix Embed Test Flakes (#5893)

* float cmp

* increase tolerance
royjhan 9 months ago
parent
commit
ac33aa7d37
1 changed files with 54 additions and 5 deletions
  1. 54 5
      integration/embed_test.go

+ 54 - 5
integration/embed_test.go

@@ -4,12 +4,45 @@ package integration
 
 import (
 	"context"
+	"math"
 	"testing"
 	"time"
 
 	"github.com/ollama/ollama/api"
 )
 
+func floatsEqual32(a, b float32) bool {
+	return math.Abs(float64(a-b)) <= 1e-4
+}
+
+func floatsEqual64(a, b float64) bool {
+	return math.Abs(a-b) <= 1e-4
+}
+
+func TestAllMiniLMEmbeddings(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+	defer cancel()
+
+	req := api.EmbeddingRequest{
+		Model:  "all-minilm",
+		Prompt: "why is the sky blue?",
+	}
+
+	res, err := embeddingTestHelper(ctx, t, req)
+
+	if err != nil {
+		t.Fatalf("error: %v", err)
+	}
+
+	if len(res.Embedding) != 384 {
+		t.Fatalf("expected 384 floats, got %d", len(res.Embedding))
+	}
+
+	if !floatsEqual64(res.Embedding[0], 0.06642947345972061) {
+		t.Fatalf("expected 0.06642947345972061, got %.16f", res.Embedding[0])
+	}
+}
+
 func TestAllMiniLMEmbed(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
 	defer cancel()
@@ -33,8 +66,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
 		t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
 	}
 
-	if res.Embeddings[0][0] != 0.010071031 {
-		t.Fatalf("expected 0.010071031, got %f", res.Embeddings[0][0])
+	if !floatsEqual32(res.Embeddings[0][0], 0.010071031) {
+		t.Fatalf("expected 0.010071031, got %.8f", res.Embeddings[0][0])
 	}
 }
 
@@ -61,12 +94,12 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
 		t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
 	}
 
-	if res.Embeddings[0][0] != 0.010071031 || res.Embeddings[1][0] != -0.009802706 {
-		t.Fatalf("expected 0.010071031 and -0.009802706, got %f and %f", res.Embeddings[0][0], res.Embeddings[1][0])
+	if !floatsEqual32(res.Embeddings[0][0], 0.010071031) || !floatsEqual32(res.Embeddings[1][0], -0.009802706) {
+		t.Fatalf("expected 0.010071031 and -0.009802706, got %.8f and %.8f", res.Embeddings[0][0], res.Embeddings[1][0])
 	}
 }
 
-func TestAllMiniLmEmbedTruncate(t *testing.T) {
+func TestAllMiniLMEmbedTruncate(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
 	defer cancel()
 
@@ -135,6 +168,22 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
 	}
 }
 
+func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+	if err := PullIfMissing(ctx, client, req.Model); err != nil {
+		t.Fatalf("failed to pull model %s: %v", req.Model, err)
+	}
+
+	response, err := client.Embeddings(ctx, &req)
+
+	if err != nil {
+		return nil, err
+	}
+
+	return response, nil
+}
+
 func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
 	client, _, cleanup := InitServerConnection(ctx, t)
 	defer cleanup()