Bladeren bron

Revamp go based integration tests

This uplevels the integration tests to run the server which can allow
testing an existing server, or a remote server.
Daniel Hiltgen 1 jaar geleden
bovenliggende
commit
949b6c01e0

+ 11 - 0
integration/README.md

@@ -0,0 +1,11 @@
+# Integration Tests
+
+This directory contains integration tests to exercise Ollama end-to-end to verify behavior
+
+By default, these tests are disabled so `go test ./...` will exercise only unit tests.  To run integration tests you must pass the integration tag.  `go test -tags=integration ./...`
+
+
+The integration tests have 2 modes of operating.
+
+1. By default, they will start the server on a random port, run the tests, and then shutdown the server.
+2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote

+ 28 - 0
integration/basic_test.go

@@ -0,0 +1,28 @@
+//go:build integration
+
+package integration
+
+import (
+	"context"
+	"net/http"
+	"testing"
+	"time"
+
+	"github.com/jmorganca/ollama/api"
+)
+
+func TestOrcaMiniBlueSky(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
+	defer cancel()
+	// Set up the test data
+	req := api.GenerateRequest{
+		Model:  "orca-mini",
+		Prompt: "why is the sky blue?",
+		Stream: &stream,
+		Options: map[string]interface{}{
+			"temperature": 0,
+			"seed":        123,
+		},
+	}
+	GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"rayleigh"})
+}

+ 11 - 22
server/llm_image_test.go → integration/llm_image_test.go

@@ -1,49 +1,38 @@
 //go:build integration
 
-package server
+package integration
 
 import (
 	"context"
 	"encoding/base64"
-	"log"
-	"os"
-	"strings"
+	"net/http"
 	"testing"
 	"time"
 
 	"github.com/jmorganca/ollama/api"
-	"github.com/jmorganca/ollama/llm"
-	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
 
 func TestIntegrationMultimodal(t *testing.T) {
-	SkipIFNoTestData(t)
 	image, err := base64.StdEncoding.DecodeString(imageEncoding)
 	require.NoError(t, err)
 	req := api.GenerateRequest{
-		Model:   "llava:7b",
-		Prompt:  "what does the text in this image say?",
-		Options: map[string]interface{}{},
+		Model:  "llava:7b",
+		Prompt: "what does the text in this image say?",
+		Stream: &stream,
+		Options: map[string]interface{}{
+			"seed":        42,
+			"temperature": 0.0,
+		},
 		Images: []api.ImageData{
 			image,
 		},
 	}
+
 	resp := "the ollamas"
-	workDir, err := os.MkdirTemp("", "ollama")
-	require.NoError(t, err)
-	defer os.RemoveAll(workDir)
-	require.NoError(t, llm.Init(workDir))
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
 	defer cancel()
-	opts := api.DefaultOptions()
-	opts.Seed = 42
-	opts.Temperature = 0.0
-	model, llmRunner := PrepareModelForPrompts(t, req.Model, opts)
-	defer llmRunner.Close()
-	response := OneShotPromptResponse(t, ctx, req, model, llmRunner)
-	log.Print(response)
-	assert.Contains(t, strings.ToLower(response), resp)
+	GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp})
 }
 
 const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb

+ 73 - 0
integration/llm_test.go

@@ -0,0 +1,73 @@
+//go:build integration
+
+package integration
+
+import (
+	"context"
+	"net/http"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/jmorganca/ollama/api"
+)
+
+// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server
+//        package to avoid circular dependencies
+
+// WARNING - these tests will fail on mac if you don't manually copy ggml-metal.metal to this dir (./server)
+//
+// TODO - Fix this ^^
+
+var (
+	stream = false
+	req    = [2]api.GenerateRequest{
+		{
+			Model:  "orca-mini",
+			Prompt: "why is the ocean blue?",
+			Stream: &stream,
+			Options: map[string]interface{}{
+				"seed":        42,
+				"temperature": 0.0,
+			},
+		}, {
+			Model:  "orca-mini",
+			Prompt: "what is the origin of the us thanksgiving holiday?",
+			Stream: &stream,
+			Options: map[string]interface{}{
+				"seed":        42,
+				"temperature": 0.0,
+			},
+		},
+	}
+	resp = [2]string{
+		"scattering",
+		"united states thanksgiving",
+	}
+)
+
+func TestIntegrationSimpleOrcaMini(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
+	defer cancel()
+	GenerateTestHelper(ctx, t, &http.Client{}, req[0], []string{resp[0]})
+}
+
+// TODO
+// The server always loads a new runner and closes the old one, which forces serial execution
+// At present this test case fails with concurrency problems.  Eventually we should try to
+// get true concurrency working with n_parallel support in the backend
+func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
+	var wg sync.WaitGroup
+	wg.Add(len(req))
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
+	defer cancel()
+	for i := 0; i < len(req); i++ {
+		go func(i int) {
+			defer wg.Done()
+			GenerateTestHelper(ctx, t, &http.Client{}, req[i], []string{resp[i]})
+		}(i)
+	}
+	wg.Wait()
+}
+
+// TODO - create a parallel test with 2 different models once we support concurrency

+ 190 - 0
integration/utils_test.go

@@ -0,0 +1,190 @@
+//go:build integration
+
+package integration
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"log/slog"
+	"math/rand"
+	"net"
+	"net/http"
+	"os"
+	"path/filepath"
+	"runtime"
+	"strconv"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/app/lifecycle"
+	"github.com/stretchr/testify/assert"
+)
+
+func FindPort() string {
+	port := 0
+	if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
+		var l *net.TCPListener
+		if l, err = net.ListenTCP("tcp", a); err == nil {
+			port = l.Addr().(*net.TCPAddr).Port
+			l.Close()
+		}
+	}
+	if port == 0 {
+		port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
+	}
+	return strconv.Itoa(port)
+}
+
+func GetTestEndpoint() (string, string) {
+	defaultPort := "11434"
+	ollamaHost := os.Getenv("OLLAMA_HOST")
+
+	scheme, hostport, ok := strings.Cut(ollamaHost, "://")
+	if !ok {
+		scheme, hostport = "http", ollamaHost
+	}
+
+	// trim trailing slashes
+	hostport = strings.TrimRight(hostport, "/")
+
+	host, port, err := net.SplitHostPort(hostport)
+	if err != nil {
+		host, port = "127.0.0.1", defaultPort
+		if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
+			host = ip.String()
+		} else if hostport != "" {
+			host = hostport
+		}
+	}
+
+	if os.Getenv("OLLAMA_TEST_EXISTING") == "" && port == defaultPort {
+		port = FindPort()
+	}
+
+	url := fmt.Sprintf("%s:%s", host, port)
+	slog.Info("server connection", "url", url)
+	return scheme, url
+}
+
+// TODO make fanicier, grab logs, etc.
+var serverMutex sync.Mutex
+var serverReady bool
+
+func StartServer(ctx context.Context, ollamaHost string) error {
+	// Make sure the server has been built
+	CLIName, err := filepath.Abs("../ollama")
+	if err != nil {
+		return err
+	}
+
+	if runtime.GOOS == "windows" {
+		CLIName += ".exe"
+	}
+	_, err = os.Stat(CLIName)
+	if err != nil {
+		return fmt.Errorf("CLI missing, did you forget to build first?  %w", err)
+	}
+	serverMutex.Lock()
+	defer serverMutex.Unlock()
+	if serverReady {
+		return nil
+	}
+
+	if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost {
+		slog.Info("setting env", "OLLAMA_HOST", ollamaHost)
+		os.Setenv("OLLAMA_HOST", ollamaHost)
+	}
+
+	slog.Info("starting server", "url", ollamaHost)
+	done, err := lifecycle.SpawnServer(ctx, "../ollama")
+	if err != nil {
+		return fmt.Errorf("failed to start server: %w", err)
+	}
+
+	go func() {
+		<-ctx.Done()
+		serverMutex.Lock()
+		defer serverMutex.Unlock()
+		exitCode := <-done
+		if exitCode > 0 {
+			slog.Warn("server failure", "exit", exitCode)
+		}
+		serverReady = false
+	}()
+
+	// TODO wait only long enough for the server to be responsive...
+	time.Sleep(500 * time.Millisecond)
+
+	serverReady = true
+	return nil
+}
+
+func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) {
+	requestJSON, err := json.Marshal(genReq)
+	if err != nil {
+		t.Fatalf("Error serializing request: %v", err)
+	}
+	defer func() {
+		if t.Failed() && os.Getenv("OLLAMA_TEST_EXISTING") == "" {
+			// TODO
+			fp, err := os.Open(lifecycle.ServerLogFile)
+			if err != nil {
+				slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
+				return
+			}
+			data, err := io.ReadAll(fp)
+			if err != nil {
+				slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
+				return
+			}
+			slog.Warn("SERVER LOG FOLLOWS")
+			os.Stderr.Write(data)
+			slog.Warn("END OF SERVER")
+		}
+		err = os.Remove(lifecycle.ServerLogFile)
+		if err != nil && !os.IsNotExist(err) {
+			slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
+		}
+	}()
+	scheme, testEndpoint := GetTestEndpoint()
+
+	if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
+		assert.NoError(t, StartServer(ctx, testEndpoint))
+	}
+
+	// Make the request and get the response
+	req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON))
+	if err != nil {
+		t.Fatalf("Error creating request: %v", err)
+	}
+
+	// Set the content type for the request
+	req.Header.Set("Content-Type", "application/json")
+
+	// Make the request with the HTTP client
+	response, err := client.Do(req.WithContext(ctx))
+	if err != nil {
+		t.Fatalf("Error making request: %v", err)
+	}
+	body, err := io.ReadAll(response.Body)
+	assert.NoError(t, err)
+	assert.Equal(t, response.StatusCode, 200, string(body))
+
+	// Verify the response is valid JSON
+	var payload api.GenerateResponse
+	err = json.Unmarshal(body, &payload)
+	if err != nil {
+		assert.NoError(t, err, body)
+	}
+
+	// Verify the response contains the expected data
+	for _, resp := range anyResp {
+		assert.Contains(t, strings.ToLower(payload.Response), resp)
+	}
+}

+ 0 - 41
scripts/setup_integration_tests.sh

@@ -1,41 +0,0 @@
-#!/bin/bash
-
-# This script sets up integration tests which run the full stack to verify
-# inference locally
-#
-# To run the relevant tests use
-# go test -tags=integration ./server
-set -e
-set -o pipefail
-
-REPO=$(dirname $0)/../
-export OLLAMA_MODELS=${REPO}/test_data/models
-REGISTRY_SCHEME=https
-REGISTRY=registry.ollama.ai
-TEST_MODELS=("library/orca-mini:latest" "library/llava:7b")
-ACCEPT_HEADER="Accept: application/vnd.docker.distribution.manifest.v2+json"
-
-for model in ${TEST_MODELS[@]}; do
-    TEST_MODEL=$(echo ${model} | cut -f1 -d:)
-    TEST_MODEL_TAG=$(echo ${model} | cut -f2 -d:)
-    mkdir -p ${OLLAMA_MODELS}/manifests/${REGISTRY}/${TEST_MODEL}/
-    mkdir -p ${OLLAMA_MODELS}/blobs/
-
-    echo "Pulling manifest for ${TEST_MODEL}:${TEST_MODEL_TAG}"
-    curl -s --header "${ACCEPT_HEADER}" \
-        -o ${OLLAMA_MODELS}/manifests/${REGISTRY}/${TEST_MODEL}/${TEST_MODEL_TAG} \
-        ${REGISTRY_SCHEME}://${REGISTRY}/v2/${TEST_MODEL}/manifests/${TEST_MODEL_TAG}
-
-    CFG_HASH=$(cat ${OLLAMA_MODELS}/manifests/${REGISTRY}/${TEST_MODEL}/${TEST_MODEL_TAG} | jq -r ".config.digest")
-    echo "Pulling config blob ${CFG_HASH}"
-    curl -L -C - --header "${ACCEPT_HEADER}" \
-        -o ${OLLAMA_MODELS}/blobs/${CFG_HASH} \
-        ${REGISTRY_SCHEME}://${REGISTRY}/v2/${TEST_MODEL}/blobs/${CFG_HASH}
-
-    for LAYER in $(cat ${OLLAMA_MODELS}/manifests/${REGISTRY}/${TEST_MODEL}/${TEST_MODEL_TAG} | jq -r ".layers[].digest"); do
-        echo "Pulling blob ${LAYER}"
-        curl -L -C - --header "${ACCEPT_HEADER}" \
-            -o ${OLLAMA_MODELS}/blobs/${LAYER} \
-            ${REGISTRY_SCHEME}://${REGISTRY}/v2/${TEST_MODEL}/blobs/${LAYER}
-    done
-done

+ 0 - 123
server/llm_test.go

@@ -1,123 +0,0 @@
-//go:build integration
-
-package server
-
-import (
-	"context"
-	"os"
-	"strings"
-	"sync"
-	"testing"
-	"time"
-
-	"github.com/stretchr/testify/assert"
-	"github.com/stretchr/testify/require"
-
-	"github.com/jmorganca/ollama/api"
-	"github.com/jmorganca/ollama/llm"
-)
-
-// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server
-//        package to avoid circular dependencies
-
-// WARNING - these tests will fail on mac if you don't manually copy ggml-metal.metal to this dir (./server)
-//
-// TODO - Fix this ^^
-
-var (
-	req = [2]api.GenerateRequest{
-		{
-			Model:   "orca-mini",
-			Prompt:  "tell me a short story about agi?",
-			Options: map[string]interface{}{},
-		}, {
-			Model:   "orca-mini",
-			Prompt:  "what is the origin of the us thanksgiving holiday?",
-			Options: map[string]interface{}{},
-		},
-	}
-	resp = [2]string{
-		"once upon a time",
-		"united states thanksgiving",
-	}
-)
-
-func TestIntegrationSimpleOrcaMini(t *testing.T) {
-	SkipIFNoTestData(t)
-	workDir, err := os.MkdirTemp("", "ollama")
-	require.NoError(t, err)
-	defer os.RemoveAll(workDir)
-	require.NoError(t, llm.Init(workDir))
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
-	defer cancel()
-	opts := api.DefaultOptions()
-	opts.Seed = 42
-	opts.Temperature = 0.0
-	model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts)
-	defer llmRunner.Close()
-	response := OneShotPromptResponse(t, ctx, req[0], model, llmRunner)
-	assert.Contains(t, strings.ToLower(response), resp[0])
-}
-
-// TODO
-// The server always loads a new runner and closes the old one, which forces serial execution
-// At present this test case fails with concurrency problems.  Eventually we should try to
-// get true concurrency working with n_parallel support in the backend
-func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
-	SkipIFNoTestData(t)
-
-	t.Skip("concurrent prediction on single runner not currently supported")
-
-	workDir, err := os.MkdirTemp("", "ollama")
-	require.NoError(t, err)
-	defer os.RemoveAll(workDir)
-	require.NoError(t, llm.Init(workDir))
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
-	defer cancel()
-	opts := api.DefaultOptions()
-	opts.Seed = 42
-	opts.Temperature = 0.0
-	var wg sync.WaitGroup
-	wg.Add(len(req))
-	model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts)
-	defer llmRunner.Close()
-	for i := 0; i < len(req); i++ {
-		go func(i int) {
-			defer wg.Done()
-			response := OneShotPromptResponse(t, ctx, req[i], model, llmRunner)
-			t.Logf("Prompt: %s\nResponse: %s", req[0].Prompt, response)
-			assert.Contains(t, strings.ToLower(response), resp[i], "error in thread %d (%s)", i, req[i].Prompt)
-		}(i)
-	}
-	wg.Wait()
-}
-
-func TestIntegrationConcurrentRunnersOrcaMini(t *testing.T) {
-	SkipIFNoTestData(t)
-	workDir, err := os.MkdirTemp("", "ollama")
-	require.NoError(t, err)
-	defer os.RemoveAll(workDir)
-	require.NoError(t, llm.Init(workDir))
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
-	defer cancel()
-	opts := api.DefaultOptions()
-	opts.Seed = 42
-	opts.Temperature = 0.0
-	var wg sync.WaitGroup
-	wg.Add(len(req))
-
-	t.Logf("Running %d concurrently", len(req))
-	for i := 0; i < len(req); i++ {
-		go func(i int) {
-			defer wg.Done()
-			model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts)
-			defer llmRunner.Close()
-			response := OneShotPromptResponse(t, ctx, req[i], model, llmRunner)
-			t.Logf("Prompt: %s\nResponse: %s", req[0].Prompt, response)
-			assert.Contains(t, strings.ToLower(response), resp[i], "error in thread %d (%s)", i, req[i].Prompt)
-		}(i)
-	}
-	wg.Wait()
-}
-
-// TODO - create a parallel test with 2 different models once we support concurrency

+ 0 - 75
server/llm_utils_test.go

@@ -1,75 +0,0 @@
-//go:build integration
-
-package server
-
-import (
-	"context"
-	"errors"
-	"os"
-	"path"
-	"runtime"
-	"testing"
-
-	"github.com/jmorganca/ollama/api"
-	"github.com/jmorganca/ollama/llm"
-	"github.com/stretchr/testify/require"
-)
-
-func SkipIFNoTestData(t *testing.T) {
-	modelDir := getModelDir()
-	if _, err := os.Stat(modelDir); errors.Is(err, os.ErrNotExist) {
-		t.Skipf("%s does not exist - skipping integration tests", modelDir)
-	}
-}
-
-func getModelDir() string {
-	_, filename, _, _ := runtime.Caller(0)
-	return path.Dir(path.Dir(filename) + "/../test_data/models/.")
-}
-
-func PrepareModelForPrompts(t *testing.T, modelName string, opts api.Options) (*Model, llm.LLM) {
-	modelDir := getModelDir()
-	os.Setenv("OLLAMA_MODELS", modelDir)
-	model, err := GetModel(modelName)
-	require.NoError(t, err, "GetModel ")
-	err = opts.FromMap(model.Options)
-	require.NoError(t, err, "opts from model ")
-	runner, err := llm.New("unused", model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
-	require.NoError(t, err, "llm.New failed")
-	return model, runner
-}
-
-func OneShotPromptResponse(t *testing.T, ctx context.Context, req api.GenerateRequest, model *Model, runner llm.LLM) string {
-	prompt, err := model.PreResponsePrompt(PromptVars{
-		System: req.System,
-		Prompt: req.Prompt,
-		First:  len(req.Context) == 0,
-	})
-	require.NoError(t, err, "prompt generation failed")
-	success := make(chan bool, 1)
-	response := ""
-	cb := func(r llm.PredictResult) {
-
-		if !r.Done {
-			response += r.Content
-		} else {
-			success <- true
-		}
-	}
-
-	predictReq := llm.PredictOpts{
-		Prompt: prompt,
-		Format: req.Format,
-		Images: req.Images,
-	}
-	err = runner.Predict(ctx, predictReq, cb)
-	require.NoError(t, err, "predict call failed")
-
-	select {
-	case <-ctx.Done():
-		t.Errorf("failed to complete before timeout: \n%s", response)
-		return ""
-	case <-success:
-		return response
-	}
-}