Просмотр исходного кода

Merge pull request #3713 from ollama/mxyng/modelname

update copy handler to use model.Name
Michael Yang 1 год назад
Родитель
Сommit
3450a57d4a
2 измененных файлов с 30 добавлено и 32 удалено
  1. 12 15
      server/images.go
  2. 18 17
      server/routes.go

+ 12 - 15
server/images.go

@@ -29,6 +29,7 @@ import (
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/parser"
+	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
 	"github.com/ollama/ollama/version"
 )
 )
 
 
@@ -701,36 +702,32 @@ func convertModel(name, path string, fn func(resp api.ProgressResponse)) (string
 	return path, nil
 	return path, nil
 }
 }
 
 
-func CopyModel(src, dest string) error {
-	srcModelPath := ParseModelPath(src)
-	srcPath, err := srcModelPath.GetManifestPath()
+func CopyModel(src, dst model.Name) error {
+	manifests, err := GetManifestPath()
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	destModelPath := ParseModelPath(dest)
-	destPath, err := destModelPath.GetManifestPath()
-	if err != nil {
-		return err
-	}
-	if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil {
+	dstpath := filepath.Join(manifests, dst.FilepathNoBuild())
+	if err := os.MkdirAll(filepath.Dir(dstpath), 0o755); err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	// copy the file
-	input, err := os.ReadFile(srcPath)
+	srcpath := filepath.Join(manifests, src.FilepathNoBuild())
+	srcfile, err := os.Open(srcpath)
 	if err != nil {
 	if err != nil {
-		fmt.Println("Error reading file:", err)
 		return err
 		return err
 	}
 	}
+	defer srcfile.Close()
 
 
-	err = os.WriteFile(destPath, input, 0o644)
+	dstfile, err := os.Create(dstpath)
 	if err != nil {
 	if err != nil {
-		fmt.Println("Error reading file:", err)
 		return err
 		return err
 	}
 	}
+	defer dstfile.Close()
 
 
-	return nil
+	_, err = io.Copy(dstfile, srcfile)
+	return err
 }
 }
 
 
 func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}, dryRun bool) error {
 func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}, dryRun bool) error {

+ 18 - 17
server/routes.go

@@ -29,6 +29,7 @@ import (
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/openai"
 	"github.com/ollama/ollama/openai"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/parser"
+	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
 	"github.com/ollama/ollama/version"
 )
 )
 
 
@@ -788,35 +789,35 @@ func (s *Server) ListModelsHandler(c *gin.Context) {
 }
 }
 
 
 func (s *Server) CopyModelHandler(c *gin.Context) {
 func (s *Server) CopyModelHandler(c *gin.Context) {
-	var req api.CopyRequest
-	err := c.ShouldBindJSON(&req)
-	switch {
-	case errors.Is(err, io.EOF):
+	var r api.CopyRequest
+	if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
 		return
 		return
-	case err != nil:
+	} else if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 		return
 	}
 	}
 
 
-	if req.Source == "" || req.Destination == "" {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "source add destination are required"})
-		return
+	src := model.ParseName(r.Source)
+	if !src.IsValid() {
+		_ = c.Error(fmt.Errorf("source %q is invalid", r.Source))
 	}
 	}
 
 
-	if err := ParseModelPath(req.Destination).Validate(); err != nil {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		return
+	dst := model.ParseName(r.Destination)
+	if !dst.IsValid() {
+		_ = c.Error(fmt.Errorf("destination %q is invalid", r.Destination))
 	}
 	}
 
 
-	if err := CopyModel(req.Source, req.Destination); err != nil {
-		if os.IsNotExist(err) {
-			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)})
-		} else {
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		}
+	if len(c.Errors) > 0 {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": c.Errors.Errors()})
 		return
 		return
 	}
 	}
+
+	if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) {
+		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)})
+	} else if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+	}
 }
 }
 
 
 func (s *Server) HeadBlobHandler(c *gin.Context) {
 func (s *Server) HeadBlobHandler(c *gin.Context) {