浏览代码

Integration tests conditionally pull

If images aren't present, pull them.
Also fixes the expected responses
Daniel Hiltgen 1 年之前
父节点
当前提交
7b6cbc10ec
共有 4 个文件被更改,包括 70 次插入10 次删除
  1. 1 1
      integration/basic_test.go
  2. 1 1
      integration/llm_image_test.go
  3. 7 7
      integration/llm_test.go
  4. 61 1
      integration/utils_test.go

+ 1 - 1
integration/basic_test.go

@@ -12,7 +12,7 @@ import (
 )
 
 func TestOrcaMiniBlueSky(t *testing.T) {
-	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
 	defer cancel()
 	// Set up the test data
 	req := api.GenerateRequest{

+ 1 - 1
integration/llm_image_test.go

@@ -30,7 +30,7 @@ func TestIntegrationMultimodal(t *testing.T) {
 	}
 
 	resp := "the ollamas"
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
 	defer cancel()
 	GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp})
 }

+ 7 - 7
integration/llm_test.go

@@ -40,16 +40,16 @@ var (
 			},
 		},
 	}
-	resp = [2]string{
-		"scattering",
-		"united states thanksgiving",
+	resp = [2][]string{
+		[]string{"sunlight"},
+		[]string{"england", "english", "massachusetts", "pilgrims"},
 	}
 )
 
 func TestIntegrationSimpleOrcaMini(t *testing.T) {
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
 	defer cancel()
-	GenerateTestHelper(ctx, t, &http.Client{}, req[0], []string{resp[0]})
+	GenerateTestHelper(ctx, t, &http.Client{}, req[0], resp[0])
 }
 
 // TODO
@@ -59,12 +59,12 @@ func TestIntegrationSimpleOrcaMini(t *testing.T) {
 func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
 	var wg sync.WaitGroup
 	wg.Add(len(req))
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
 	defer cancel()
 	for i := 0; i < len(req); i++ {
 		go func(i int) {
 			defer wg.Done()
-			GenerateTestHelper(ctx, t, &http.Client{}, req[i], []string{resp[i]})
+			GenerateTestHelper(ctx, t, &http.Client{}, req[i], resp[i])
 		}(i)
 	}
 	wg.Wait()

+ 61 - 1
integration/utils_test.go

@@ -125,6 +125,55 @@ func StartServer(ctx context.Context, ollamaHost string) error {
 	return nil
 }
 
+func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error {
+	slog.Debug("checking status of model", "model", modelName)
+	showReq := &api.ShowRequest{Name: modelName}
+	requestJSON, err := json.Marshal(showReq)
+	if err != nil {
+		return err
+	}
+
+	req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/show", bytes.NewReader(requestJSON))
+	if err != nil {
+		return err
+	}
+
+	// Make the request with the HTTP client
+	response, err := client.Do(req.WithContext(ctx))
+	if err != nil {
+		return err
+	}
+	defer response.Body.Close()
+	if response.StatusCode == 200 {
+		slog.Info("model already present", "model", modelName)
+		return nil
+	}
+	slog.Info("model missing", "status", response.StatusCode)
+
+	pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
+	requestJSON, err = json.Marshal(pullReq)
+	if err != nil {
+		return err
+	}
+
+	req, err = http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/pull", bytes.NewReader(requestJSON))
+	if err != nil {
+		return err
+	}
+	slog.Info("pulling", "model", modelName)
+
+	response, err = client.Do(req.WithContext(ctx))
+	if err != nil {
+		return err
+	}
+	defer response.Body.Close()
+	if response.StatusCode != 200 {
+		return fmt.Errorf("failed to pull model") // TODO more details perhaps
+	}
+	slog.Info("model pulled", "model", modelName)
+	return nil
+}
+
 func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) {
 	requestJSON, err := json.Marshal(genReq)
 	if err != nil {
@@ -158,6 +207,11 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client,
 		assert.NoError(t, StartServer(ctx, testEndpoint))
 	}
 
+	err = PullIfMissing(ctx, client, scheme, testEndpoint, genReq.Model)
+	if err != nil {
+		t.Fatalf("Error pulling model: %v", err)
+	}
+
 	// Make the request and get the response
 	req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON))
 	if err != nil {
@@ -172,6 +226,7 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client,
 	if err != nil {
 		t.Fatalf("Error making request: %v", err)
 	}
+	defer response.Body.Close()
 	body, err := io.ReadAll(response.Body)
 	assert.NoError(t, err)
 	assert.Equal(t, response.StatusCode, 200, string(body))
@@ -184,7 +239,12 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client,
 	}
 
 	// Verify the response contains the expected data
+	atLeastOne := false
 	for _, resp := range anyResp {
-		assert.Contains(t, strings.ToLower(payload.Response), resp)
+		if strings.Contains(strings.ToLower(payload.Response), resp) {
+			atLeastOne = true
+			break
+		}
 	}
+	assert.True(t, atLeastOne, "none of %v found in %s", anyResp, payload.Response)
 }