|
@@ -5,6 +5,7 @@ package integration
|
|
|
import (
|
|
|
"bytes"
|
|
|
"context"
|
|
|
+ "encoding/json"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
@@ -24,6 +25,7 @@ import (
|
|
|
|
|
|
"github.com/ollama/ollama/api"
|
|
|
"github.com/ollama/ollama/app/lifecycle"
|
|
|
+ "github.com/stretchr/testify/assert"
|
|
|
"github.com/stretchr/testify/require"
|
|
|
)
|
|
|
|
|
@@ -285,6 +287,7 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
|
|
// Generate a set of requests
|
|
|
// By default each request uses orca-mini as the model
|
|
|
func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
|
|
+ stream := false
|
|
|
return []api.GenerateRequest{
|
|
|
{
|
|
|
Model: "orca-mini",
|
|
@@ -336,3 +339,83 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
|
|
[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+func EmbeddingTestHelper(ctx context.Context, t *testing.T, client *http.Client, req api.EmbeddingRequest) api.EmbeddingResponse {
|
|
|
+
|
|
|
+ // TODO maybe stuff in an init routine?
|
|
|
+ lifecycle.InitLogging()
|
|
|
+
|
|
|
+ requestJSON, err := json.Marshal(req)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("Error serializing request: %v", err)
|
|
|
+ }
|
|
|
+ defer func() {
|
|
|
+ if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
|
|
+ defer serverProcMutex.Unlock()
|
|
|
+ if t.Failed() {
|
|
|
+ fp, err := os.Open(lifecycle.ServerLogFile)
|
|
|
+ if err != nil {
|
|
|
+ slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ data, err := io.ReadAll(fp)
|
|
|
+ if err != nil {
|
|
|
+ slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ slog.Warn("SERVER LOG FOLLOWS")
|
|
|
+ os.Stderr.Write(data)
|
|
|
+ slog.Warn("END OF SERVER")
|
|
|
+ }
|
|
|
+ err = os.Remove(lifecycle.ServerLogFile)
|
|
|
+ if err != nil && !os.IsNotExist(err) {
|
|
|
+ slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }()
|
|
|
+ scheme, testEndpoint := GetTestEndpoint()
|
|
|
+
|
|
|
+ if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
|
|
+ serverProcMutex.Lock()
|
|
|
+ fp, err := os.CreateTemp("", "ollama-server-*.log")
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("failed to generate log file: %s", err)
|
|
|
+ }
|
|
|
+ lifecycle.ServerLogFile = fp.Name()
|
|
|
+ fp.Close()
|
|
|
+ assert.NoError(t, StartServer(ctx, testEndpoint))
|
|
|
+ }
|
|
|
+
|
|
|
+ err = PullIfMissing(ctx, client, scheme, testEndpoint, req.Model)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("Error pulling model: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Make the request and get the response
|
|
|
+ httpReq, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/embeddings", bytes.NewReader(requestJSON))
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("Error creating request: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Set the content type for the request
|
|
|
+ httpReq.Header.Set("Content-Type", "application/json")
|
|
|
+
|
|
|
+ // Make the request with the HTTP client
|
|
|
+ response, err := client.Do(httpReq.WithContext(ctx))
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("Error making request: %v", err)
|
|
|
+ }
|
|
|
+ defer response.Body.Close()
|
|
|
+ body, err := io.ReadAll(response.Body)
|
|
|
+ assert.NoError(t, err)
|
|
|
+ assert.Equal(t, response.StatusCode, 200, string(body))
|
|
|
+
|
|
|
+ // Verify the response is valid JSON
|
|
|
+ var res api.EmbeddingResponse
|
|
|
+ err = json.Unmarshal(body, &res)
|
|
|
+ if err != nil {
|
|
|
+ assert.NoError(t, err, body)
|
|
|
+ }
|
|
|
+
|
|
|
+ return res
|
|
|
+}
|