|
@@ -19,7 +19,11 @@ func TestAllMiniLMEmbed(t *testing.T) {
|
|
|
Input: "why is the sky blue?",
|
|
|
}
|
|
|
|
|
|
- res := EmbedTestHelper(ctx, t, req)
|
|
|
+ res, err := EmbedTestHelper(ctx, t, req)
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("error: %v", err)
|
|
|
+ }
|
|
|
|
|
|
if len(res.Embeddings) != 1 {
|
|
|
t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
|
|
@@ -28,6 +32,10 @@ func TestAllMiniLMEmbed(t *testing.T) {
|
|
|
if len(res.Embeddings[0]) != 384 {
|
|
|
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
|
|
|
}
|
|
|
+
|
|
|
+ if res.Embeddings[0][0] != 0.010071029038540258 {
|
|
|
+ t.Fatalf("expected 0.010071029038540258, got %f", res.Embeddings[0][0])
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
func TestAllMiniLMBatchEmbed(t *testing.T) {
|
|
@@ -39,7 +47,11 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
|
|
|
Input: []string{"why is the sky blue?", "why is the grass green?"},
|
|
|
}
|
|
|
|
|
|
- res := EmbedTestHelper(ctx, t, req)
|
|
|
+ res, err := EmbedTestHelper(ctx, t, req)
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("error: %v", err)
|
|
|
+ }
|
|
|
|
|
|
if len(res.Embeddings) != 2 {
|
|
|
t.Fatalf("expected 2 embeddings, got %d", len(res.Embeddings))
|
|
@@ -48,4 +60,77 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
|
|
|
if len(res.Embeddings[0]) != 384 {
|
|
|
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
|
|
|
}
|
|
|
+
|
|
|
+ if res.Embeddings[0][0] != 0.010071029038540258 || res.Embeddings[1][0] != -0.00980270794235093 {
|
|
|
+ t.Fatalf("expected 0.010071029038540258 and -0.00980270794235093, got %f and %f", res.Embeddings[0][0], res.Embeddings[1][0])
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestAllMiniLmEmbedTruncate(t *testing.T) {
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ truncTrue, truncFalse := true, false
|
|
|
+
|
|
|
+ type testReq struct {
|
|
|
+ Name string
|
|
|
+ Request api.EmbedRequest
|
|
|
+ }
|
|
|
+
|
|
|
+ reqs := []testReq{
|
|
|
+ {
|
|
|
+ Name: "Target Truncation",
|
|
|
+ Request: api.EmbedRequest{
|
|
|
+ Model: "all-minilm",
|
|
|
+ Input: "why",
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Name: "Default Truncate",
|
|
|
+ Request: api.EmbedRequest{
|
|
|
+ Model: "all-minilm",
|
|
|
+ Input: "why is the sky blue?",
|
|
|
+ Options: map[string]any{"num_ctx": 1},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Name: "Explicit Truncate",
|
|
|
+ Request: api.EmbedRequest{
|
|
|
+ Model: "all-minilm",
|
|
|
+ Input: "why is the sky blue?",
|
|
|
+ Truncate: &truncTrue,
|
|
|
+ Options: map[string]any{"num_ctx": 1},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ res := make(map[string]*api.EmbedResponse)
|
|
|
+
|
|
|
+ for _, req := range reqs {
|
|
|
+ response, err := EmbedTestHelper(ctx, t, req.Request)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("error: %v", err)
|
|
|
+ }
|
|
|
+ res[req.Name] = response
|
|
|
+ }
|
|
|
+
|
|
|
+ if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
|
|
|
+ t.Fatalf("expected default request to truncate correctly")
|
|
|
+ }
|
|
|
+
|
|
|
+ if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
|
|
|
+ t.Fatalf("expected default request and truncate true request to be the same")
|
|
|
+ }
|
|
|
+
|
|
|
+ // check that truncate set to false returns an error if context length is exceeded
|
|
|
+ _, err := EmbedTestHelper(ctx, t, api.EmbedRequest{
|
|
|
+ Model: "all-minilm",
|
|
|
+ Input: "why is the sky blue?",
|
|
|
+ Truncate: &truncFalse,
|
|
|
+ Options: map[string]any{"num_ctx": 1},
|
|
|
+ })
|
|
|
+
|
|
|
+ if err == nil {
|
|
|
+ t.Fatalf("expected error, got nil")
|
|
|
+ }
|
|
|
}
|