瀏覽代碼

testing clean up

Roy Han 9 月之前
父節點
當前提交
53e9576f46
共有 2 個文件被更改,包括 8 次插入7 次删除
  1. 6 5
      integration/embed_test.go
  2. 2 2
      server/routes_test.go

+ 6 - 5
integration/embed_test.go

@@ -8,7 +8,6 @@ import (
 	"time"
 
 	"github.com/ollama/ollama/api"
-	"github.com/stretchr/testify/require"
 )
 
 func TestAllMiniLMEmbed(t *testing.T) {
@@ -116,11 +115,11 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
 	}
 
 	if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
-		t.Fatalf("expected default request to truncate correctly")
+		t.Fatal("expected default request to truncate correctly")
 	}
 
 	if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
-		t.Fatalf("expected default request and truncate true request to be the same")
+		t.Fatal("expected default request and truncate true request to be the same")
 	}
 
 	// check that truncate set to false returns an error if context length is exceeded
@@ -132,14 +131,16 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
 	})
 
 	if err == nil {
-		t.Fatalf("expected error, got nil")
+		t.Fatal("expected error, got nil")
 	}
 }
 
 func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
 	client, _, cleanup := InitServerConnection(ctx, t)
 	defer cleanup()
-	require.NoError(t, PullIfMissing(ctx, client, req.Model))
+	if err := PullIfMissing(ctx, client, req.Model); err != nil {
+		t.Fatalf("failed to pull model %s: %v", req.Model, err)
+	}
 
 	response, err := client.Embed(ctx, &req)
 

+ 2 - 2
server/routes_test.go

@@ -483,7 +483,7 @@ func TestNormalize(t *testing.T) {
 		{input: []float32{0, 0, 0}},
 	}
 
-	assertNorm := func(vec []float32) (res bool) {
+	isNormalized := func(vec []float32) (res bool) {
 		sum := 0.0
 		for _, v := range vec {
 			sum += float64(v * v)
@@ -498,7 +498,7 @@ func TestNormalize(t *testing.T) {
 	for _, tc := range testCases {
 		t.Run("", func(t *testing.T) {
 			normalized := normalize(tc.input)
-			if !assertNorm(normalized) {
+			if !isNormalized(normalized) {
 				t.Errorf("Vector %v is not normalized", tc.input)
 			}
 		})