瀏覽代碼

add API tests for list handler (#1535)

Patrick Devine 1 年之前
父節點
當前提交
0174665d0e
共有 1 個文件被更改,包括 61 次插入0 次删除
  1. 61 0
      server/routes_test.go

+ 61 - 0
server/routes_test.go

@@ -2,12 +2,19 @@ package server
 
 
 import (
 import (
 	"context"
 	"context"
+	"encoding/json"
+	"fmt"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
 	"net/http/httptest"
 	"net/http/httptest"
+	"os"
+	"strings"
 	"testing"
 	"testing"
 
 
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
+
+	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/parser"
 )
 )
 
 
 func setupServer(t *testing.T) (*Server, error) {
 func setupServer(t *testing.T) (*Server, error) {
@@ -40,6 +47,55 @@ func Test_Routes(t *testing.T) {
 				assert.Equal(t, `{"version":"0.0.0"}`, string(body))
 				assert.Equal(t, `{"version":"0.0.0"}`, string(body))
 			},
 			},
 		},
 		},
+		{
+			Name:   "Tags Handler (no tags)",
+			Method: http.MethodGet,
+			Path:   "/api/tags",
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				assert.Equal(t, contentType, "application/json; charset=utf-8")
+				body, err := io.ReadAll(resp.Body)
+				assert.Nil(t, err)
+
+				var modelList api.ListResponse
+
+				err = json.Unmarshal(body, &modelList)
+				assert.Nil(t, err)
+
+				assert.Equal(t, 0, len(modelList.Models))
+			},
+		},
+		{
+			Name:   "Tags Handler (yes tags)",
+			Method: http.MethodGet,
+			Path:   "/api/tags",
+			Setup: func(t *testing.T, req *http.Request) {
+				f, err := os.CreateTemp("", "ollama-modelfile")
+				assert.Nil(t, err)
+				defer os.RemoveAll(f.Name())
+
+				modelfile := strings.NewReader(fmt.Sprintf("FROM %s", f.Name()))
+				commands, err := parser.Parse(modelfile)
+				assert.Nil(t, err)
+				fn := func(resp api.ProgressResponse) {}
+				err = CreateModel(context.TODO(), "test", "", commands, fn)
+				assert.Nil(t, err)
+			},
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				assert.Equal(t, contentType, "application/json; charset=utf-8")
+				body, err := io.ReadAll(resp.Body)
+				assert.Nil(t, err)
+
+				var modelList api.ListResponse
+
+				err = json.Unmarshal(body, &modelList)
+				assert.Nil(t, err)
+
+				assert.Equal(t, 1, len(modelList.Models))
+				assert.Equal(t, modelList.Models[0].Name, "test:latest")
+			},
+		},
 	}
 	}
 
 
 	s, err := setupServer(t)
 	s, err := setupServer(t)
@@ -50,6 +106,11 @@ func Test_Routes(t *testing.T) {
 	httpSrv := httptest.NewServer(router)
 	httpSrv := httptest.NewServer(router)
 	t.Cleanup(httpSrv.Close)
 	t.Cleanup(httpSrv.Close)
 
 
+	workDir, err := os.MkdirTemp("", "ollama-test")
+	assert.Nil(t, err)
+	defer os.RemoveAll(workDir)
+	os.Setenv("OLLAMA_MODELS", workDir)
+
 	for _, tc := range testCases {
 	for _, tc := range testCases {
 		u := httpSrv.URL + tc.Path
 		u := httpSrv.URL + tc.Path
 		req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
 		req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)