|
@@ -8,6 +8,7 @@ import (
|
|
|
"time"
|
|
|
|
|
|
"github.com/ollama/ollama/api"
|
|
|
+ "github.com/stretchr/testify/require"
|
|
|
)
|
|
|
|
|
|
func TestAllMiniLMEmbed(t *testing.T) {
|
|
@@ -19,7 +20,7 @@ func TestAllMiniLMEmbed(t *testing.T) {
|
|
|
Input: "why is the sky blue?",
|
|
|
}
|
|
|
|
|
|
- res, err := EmbedTestHelper(ctx, t, req)
|
|
|
+ res, err := embedTestHelper(ctx, t, req)
|
|
|
|
|
|
if err != nil {
|
|
|
t.Fatalf("error: %v", err)
|
|
@@ -47,7 +48,7 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
|
|
|
Input: []string{"why is the sky blue?", "why is the grass green?"},
|
|
|
}
|
|
|
|
|
|
- res, err := EmbedTestHelper(ctx, t, req)
|
|
|
+ res, err := embedTestHelper(ctx, t, req)
|
|
|
|
|
|
if err != nil {
|
|
|
t.Fatalf("error: %v", err)
|
|
@@ -107,7 +108,7 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
|
|
|
res := make(map[string]*api.EmbedResponse)
|
|
|
|
|
|
for _, req := range reqs {
|
|
|
- response, err := EmbedTestHelper(ctx, t, req.Request)
|
|
|
+ response, err := embedTestHelper(ctx, t, req.Request)
|
|
|
if err != nil {
|
|
|
t.Fatalf("error: %v", err)
|
|
|
}
|
|
@@ -123,7 +124,7 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
|
|
|
}
|
|
|
|
|
|
// check that truncate set to false returns an error if context length is exceeded
|
|
|
- _, err := EmbedTestHelper(ctx, t, api.EmbedRequest{
|
|
|
+ _, err := embedTestHelper(ctx, t, api.EmbedRequest{
|
|
|
Model: "all-minilm",
|
|
|
Input: "why is the sky blue?",
|
|
|
Truncate: &truncFalse,
|
|
@@ -134,3 +135,17 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
|
|
|
t.Fatalf("expected error, got nil")
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
|
|
|
+ client, _, cleanup := InitServerConnection(ctx, t)
|
|
|
+ defer cleanup()
|
|
|
+ require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
|
|
+
|
|
|
+ response, err := client.Embed(ctx, &req)
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ return response, nil
|
|
|
+}
|