Browse Source

Merge pull request #4413 from ollama/mxyng/name-check

check if name exists before create/pull/copy
Michael Yang 11 tháng trước cách đây
mục cha
commit
96bc232b43
2 tập tin đã thay đổi với 135 bổ sung30 xóa
  1. 34 8
      server/routes.go
  2. 101 22
      server/routes_test.go

+ 34 - 8
server/routes.go

@@ -421,13 +421,14 @@ func (s *Server) PullModelHandler(c *gin.Context) {
 		return
 	}
 
-	var model string
-	if req.Model != "" {
-		model = req.Model
-	} else if req.Name != "" {
-		model = req.Name
-	} else {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
+	name := model.ParseName(cmp.Or(req.Model, req.Name))
+	if !name.IsValid() {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
+		return
+	}
+
+	if err := checkNameExists(name); err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
@@ -445,7 +446,7 @@ func (s *Server) PullModelHandler(c *gin.Context) {
 		ctx, cancel := context.WithCancel(c.Request.Context())
 		defer cancel()
 
-		if err := PullModel(ctx, model, regOpts, fn); err != nil {
+		if err := PullModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
@@ -507,6 +508,21 @@ func (s *Server) PushModelHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
+func checkNameExists(name model.Name) error {
+	names, err := Manifests()
+	if err != nil {
+		return err
+	}
+
+	for n := range names {
+		if strings.EqualFold(n.Filepath(), name.Filepath()) && n != name {
+			return fmt.Errorf("a model with that name already exists")
+		}
+	}
+
+	return nil
+}
+
 func (s *Server) CreateModelHandler(c *gin.Context) {
 	var req api.CreateRequest
 	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
@@ -523,6 +539,11 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
 		return
 	}
 
+	if err := checkNameExists(name); err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
 	if req.Path == "" && req.Modelfile == "" {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
 		return
@@ -771,6 +792,11 @@ func (s *Server) CopyModelHandler(c *gin.Context) {
 		return
 	}
 
+	if err := checkNameExists(dst); err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
 	if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) {
 		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)})
 	} else if err != nil {

+ 101 - 22
server/routes_test.go

@@ -21,35 +21,35 @@ import (
 	"github.com/ollama/ollama/version"
 )
 
-func Test_Routes(t *testing.T) {
-	type testCase struct {
-		Name     string
-		Method   string
-		Path     string
-		Setup    func(t *testing.T, req *http.Request)
-		Expected func(t *testing.T, resp *http.Response)
-	}
+func createTestFile(t *testing.T, name string) string {
+	t.Helper()
 
-	createTestFile := func(t *testing.T, name string) string {
-		t.Helper()
+	f, err := os.CreateTemp(t.TempDir(), name)
+	assert.Nil(t, err)
+	defer f.Close()
 
-		f, err := os.CreateTemp(t.TempDir(), name)
-		assert.Nil(t, err)
-		defer f.Close()
+	err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
+	assert.Nil(t, err)
 
-		err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
-		assert.Nil(t, err)
+	err = binary.Write(f, binary.LittleEndian, uint32(3))
+	assert.Nil(t, err)
 
-		err = binary.Write(f, binary.LittleEndian, uint32(3))
-		assert.Nil(t, err)
+	err = binary.Write(f, binary.LittleEndian, uint64(0))
+	assert.Nil(t, err)
 
-		err = binary.Write(f, binary.LittleEndian, uint64(0))
-		assert.Nil(t, err)
+	err = binary.Write(f, binary.LittleEndian, uint64(0))
+	assert.Nil(t, err)
 
-		err = binary.Write(f, binary.LittleEndian, uint64(0))
-		assert.Nil(t, err)
+	return f.Name()
+}
 
-		return f.Name()
+func Test_Routes(t *testing.T) {
+	type testCase struct {
+		Name     string
+		Method   string
+		Path     string
+		Setup    func(t *testing.T, req *http.Request)
+		Expected func(t *testing.T, resp *http.Response)
 	}
 
 	createTestModel := func(t *testing.T, name string) {
@@ -237,3 +237,82 @@ func Test_Routes(t *testing.T) {
 		})
 	}
 }
+
+func TestCase(t *testing.T) {
+	t.Setenv("OLLAMA_MODELS", t.TempDir())
+
+	cases := []string{
+		"mistral",
+		"llama3:latest",
+		"library/phi3:q4_0",
+		"registry.ollama.ai/library/gemma:q5_K_M",
+		// TODO: host:port currently fails on windows (#4107)
+		// "localhost:5000/alice/bob:latest",
+	}
+
+	var s Server
+	for _, tt := range cases {
+		t.Run(tt, func(t *testing.T) {
+			w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+				Name:      tt,
+				Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
+				Stream:    &stream,
+			})
+
+			if w.Code != http.StatusOK {
+				t.Fatalf("expected status 200 got %d", w.Code)
+			}
+
+			expect, err := json.Marshal(map[string]string{"error": "a model with that name already exists"})
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			t.Run("create", func(t *testing.T) {
+				w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
+					Name:      strings.ToUpper(tt),
+					Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
+					Stream:    &stream,
+				})
+
+				if w.Code != http.StatusBadRequest {
+					t.Fatalf("expected status 500 got %d", w.Code)
+				}
+
+				if !bytes.Equal(w.Body.Bytes(), expect) {
+					t.Fatalf("expected error %s got %s", expect, w.Body.String())
+				}
+			})
+
+			t.Run("pull", func(t *testing.T) {
+				w := createRequest(t, s.PullModelHandler, api.PullRequest{
+					Name:   strings.ToUpper(tt),
+					Stream: &stream,
+				})
+
+				if w.Code != http.StatusBadRequest {
+					t.Fatalf("expected status 500 got %d", w.Code)
+				}
+
+				if !bytes.Equal(w.Body.Bytes(), expect) {
+					t.Fatalf("expected error %s got %s", expect, w.Body.String())
+				}
+			})
+
+			t.Run("copy", func(t *testing.T) {
+				w := createRequest(t, s.CopyModelHandler, api.CopyRequest{
+					Source:      tt,
+					Destination: strings.ToUpper(tt),
+				})
+
+				if w.Code != http.StatusBadRequest {
+					t.Fatalf("expected status 500 got %d", w.Code)
+				}
+
+				if !bytes.Equal(w.Body.Bytes(), expect) {
+					t.Fatalf("expected error %s got %s", expect, w.Body.String())
+				}
+			})
+		})
+	}
+}