소스 검색

add API create/copy handlers (#1541)

Patrick Devine 1 년 전
부모
커밋
86b0dd4b16
1개의 변경된 파일75개의 추가작업 그리고 12개의 파일을 삭제
  1. 75 12
      server/routes_test.go

+ 75 - 12
server/routes_test.go

@@ -1,10 +1,12 @@
 package server
 
 import (
+	"bytes"
 	"context"
 	"encoding/json"
 	"fmt"
 	"io"
+
 	"net/http"
 	"net/http/httptest"
 	"os"
@@ -31,6 +33,20 @@ func Test_Routes(t *testing.T) {
 		Setup    func(t *testing.T, req *http.Request)
 		Expected func(t *testing.T, resp *http.Response)
 	}
+	var tempModelFile string
+
+	createTestModel := func(t *testing.T, name string) {
+		f, err := os.CreateTemp("", "ollama-model")
+		assert.Nil(t, err)
+		defer os.RemoveAll(f.Name())
+
+		modelfile := strings.NewReader(fmt.Sprintf("FROM %s", f.Name()))
+		commands, err := parser.Parse(modelfile)
+		assert.Nil(t, err)
+		fn := func(resp api.ProgressResponse) {}
+		err = CreateModel(context.TODO(), name, "", commands, fn)
+		assert.Nil(t, err)
+	}
 
 	testCases := []testCase{
 		{
@@ -70,16 +86,7 @@ func Test_Routes(t *testing.T) {
 			Method: http.MethodGet,
 			Path:   "/api/tags",
 			Setup: func(t *testing.T, req *http.Request) {
-				f, err := os.CreateTemp("", "ollama-modelfile")
-				assert.Nil(t, err)
-				defer os.RemoveAll(f.Name())
-
-				modelfile := strings.NewReader(fmt.Sprintf("FROM %s", f.Name()))
-				commands, err := parser.Parse(modelfile)
-				assert.Nil(t, err)
-				fn := func(resp api.ProgressResponse) {}
-				err = CreateModel(context.TODO(), "test", "", commands, fn)
-				assert.Nil(t, err)
+				createTestModel(t, "test-model")
 			},
 			Expected: func(t *testing.T, resp *http.Response) {
 				contentType := resp.Header.Get("Content-Type")
@@ -88,12 +95,66 @@ func Test_Routes(t *testing.T) {
 				assert.Nil(t, err)
 
 				var modelList api.ListResponse
-
 				err = json.Unmarshal(body, &modelList)
 				assert.Nil(t, err)
 
 				assert.Equal(t, 1, len(modelList.Models))
-				assert.Equal(t, modelList.Models[0].Name, "test:latest")
+				assert.Equal(t, modelList.Models[0].Name, "test-model:latest")
+			},
+		},
+		{
+			Name:   "Create Model Handler",
+			Method: http.MethodPost,
+			Path:   "/api/create",
+			Setup: func(t *testing.T, req *http.Request) {
+				f, err := os.CreateTemp("", "ollama-model")
+				assert.Nil(t, err)
+				tempModelFile = f.Name()
+
+				stream := false
+				createReq := api.CreateRequest{
+					Name:      "t-bone",
+					Modelfile: fmt.Sprintf("FROM %s", f.Name()),
+					Stream:    &stream,
+				}
+				jsonData, err := json.Marshal(createReq)
+				assert.Nil(t, err)
+
+				req.Body = io.NopCloser(bytes.NewReader(jsonData))
+			},
+			Expected: func(t *testing.T, resp *http.Response) {
+				os.RemoveAll(tempModelFile)
+
+				contentType := resp.Header.Get("Content-Type")
+				assert.Equal(t, "application/json", contentType)
+				_, err := io.ReadAll(resp.Body)
+				assert.Nil(t, err)
+				assert.Equal(t, resp.StatusCode, 200)
+
+				model, err := GetModel("t-bone")
+				assert.Nil(t, err)
+				assert.Equal(t, "t-bone:latest", model.ShortName)
+			},
+		},
+		{
+			Name:   "Copy Model Handler",
+			Method: http.MethodPost,
+			Path:   "/api/copy",
+			Setup: func(t *testing.T, req *http.Request) {
+				createTestModel(t, "hamshank")
+				copyReq := api.CopyRequest{
+					Source:      "hamshank",
+					Destination: "beefsteak",
+				}
+				jsonData, err := json.Marshal(copyReq)
+				assert.Nil(t, err)
+
+				req.Body = io.NopCloser(bytes.NewReader(jsonData))
+			},
+			Expected: func(t *testing.T, resp *http.Response) {
+				model, err := GetModel("beefsteak")
+				assert.Nil(t, err)
+				assert.Equal(t, "beefsteak:latest", model.ShortName)
 			},
 		},
 	}
@@ -121,11 +182,13 @@ func Test_Routes(t *testing.T) {
 		}
 
 		resp, err := httpSrv.Client().Do(req)
+		defer resp.Body.Close()
 		assert.Nil(t, err)
 
 		if tc.Expected != nil {
 			tc.Expected(t, resp)
 		}
+
 	}
 
 }