瀏覽代碼

add rm command for models (#151)

Patrick Devine 1 年之前
父節點
當前提交
e7a393de54
共有 5 個文件被更改,包括 166 次插入25 次删除
  1. 13 0
      api/client.go
  2. 8 4
      api/types.go
  3. 34 11
      cmd/cmd.go
  4. 77 0
      server/images.go
  5. 34 10
      server/routes.go

+ 13 - 0
api/client.go

@@ -210,3 +210,16 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) {
 	}
 	return &lr, nil
 }
+
+type DeleteProgressFunc func(ProgressResponse) error
+
+func (c *Client) Delete(ctx context.Context, req *DeleteRequest, fn DeleteProgressFunc) error {
+	return c.stream(ctx, http.MethodDelete, "/api/delete", req, func(bts []byte) error {
+		var resp ProgressResponse
+		if err := json.Unmarshal(bts, &resp); err != nil {
+			return err
+		}
+
+		return fn(resp)
+	})
+}

+ 8 - 4
api/types.go

@@ -37,6 +37,10 @@ type CreateProgress struct {
 	Status string `json:"status"`
 }
 
+type DeleteRequest struct {
+	Name string `json:"name"`
+}
+
 type PullRequest struct {
 	Name     string `json:"name"`
 	Username string `json:"username"`
@@ -44,10 +48,10 @@ type PullRequest struct {
 }
 
 type ProgressResponse struct {
-	Status    string  `json:"status"`
-	Digest    string  `json:"digest,omitempty"`
-	Total     int     `json:"total,omitempty"`
-	Completed int     `json:"completed,omitempty"`
+	Status    string `json:"status"`
+	Digest    string `json:"digest,omitempty"`
+	Total     int    `json:"total,omitempty"`
+	Completed int    `json:"completed,omitempty"`
 }
 
 type PushRequest struct {

+ 34 - 11
cmd/cmd.go

@@ -25,7 +25,7 @@ import (
 	"github.com/jmorganca/ollama/server"
 )
 
-func create(cmd *cobra.Command, args []string) error {
+func CreateHandler(cmd *cobra.Command, args []string) error {
 	filename, _ := cmd.Flags().GetString("file")
 	filename, err := filepath.Abs(filename)
 	if err != nil {
@@ -59,7 +59,7 @@ func create(cmd *cobra.Command, args []string) error {
 	return nil
 }
 
-func RunRun(cmd *cobra.Command, args []string) error {
+func RunHandler(cmd *cobra.Command, args []string) error {
 	mp := server.ParseModelPath(args[0])
 	fp, err := mp.GetManifestPath(false)
 	if err != nil {
@@ -86,7 +86,7 @@ func RunRun(cmd *cobra.Command, args []string) error {
 	return RunGenerate(cmd, args)
 }
 
-func push(cmd *cobra.Command, args []string) error {
+func PushHandler(cmd *cobra.Command, args []string) error {
 	client := api.NewClient()
 
 	request := api.PushRequest{Name: args[0]}
@@ -101,7 +101,7 @@ func push(cmd *cobra.Command, args []string) error {
 	return nil
 }
 
-func list(cmd *cobra.Command, args []string) error {
+func ListHandler(cmd *cobra.Command, args []string) error {
 	client := api.NewClient()
 
 	models, err := client.List(context.Background())
@@ -131,7 +131,22 @@ func list(cmd *cobra.Command, args []string) error {
 	return nil
 }
 
-func RunPull(cmd *cobra.Command, args []string) error {
+func DeleteHandler(cmd *cobra.Command, args []string) error {
+	client := api.NewClient()
+
+	request := api.DeleteRequest{Name: args[0]}
+	fn := func(resp api.ProgressResponse) error {
+		fmt.Println(resp.Status)
+		return nil
+	}
+
+	if err := client.Delete(context.Background(), &request, fn); err != nil {
+		return err
+	}
+	return nil
+}
+
+func PullHandler(cmd *cobra.Command, args []string) error {
 	return pull(args[0])
 }
 
@@ -290,7 +305,7 @@ func generateInteractive(cmd *cobra.Command, model string) error {
 		switch {
 		case strings.HasPrefix(line, "/list"):
 			args := strings.Fields(line)
-			if err := list(cmd, args[1:]); err != nil {
+			if err := ListHandler(cmd, args[1:]); err != nil {
 				return err
 			}
 
@@ -387,7 +402,7 @@ func NewCLI() *cobra.Command {
 		Use:   "create MODEL",
 		Short: "Create a model from a Modelfile",
 		Args:  cobra.MinimumNArgs(1),
-		RunE:  create,
+		RunE:  CreateHandler,
 	}
 
 	createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile (default \"Modelfile\")")
@@ -396,7 +411,7 @@ func NewCLI() *cobra.Command {
 		Use:   "run MODEL [PROMPT]",
 		Short: "Run a model",
 		Args:  cobra.MinimumNArgs(1),
-		RunE:  RunRun,
+		RunE:  RunHandler,
 	}
 
 	runCmd.Flags().Bool("verbose", false, "Show timings for response")
@@ -412,21 +427,28 @@ func NewCLI() *cobra.Command {
 		Use:   "pull MODEL",
 		Short: "Pull a model from a registry",
 		Args:  cobra.MinimumNArgs(1),
-		RunE:  RunPull,
+		RunE:  PullHandler,
 	}
 
 	pushCmd := &cobra.Command{
 		Use:   "push MODEL",
 		Short: "Push a model to a registry",
 		Args:  cobra.MinimumNArgs(1),
-		RunE:  push,
+		RunE:  PushHandler,
 	}
 
 	listCmd := &cobra.Command{
 		Use:   "list",
 		Aliases: []string{"ls"},
 		Short: "List models",
-		RunE:  list,
+		RunE:  ListHandler,
+	}
+
+	deleteCmd := &cobra.Command{
+		Use:   "rm",
+		Short: "Remove a model",
+		Args:  cobra.MinimumNArgs(1),
+		RunE:  DeleteHandler,
 	}
 
 	rootCmd.AddCommand(
@@ -436,6 +458,7 @@ func NewCLI() *cobra.Command {
 		pullCmd,
 		pushCmd,
 		listCmd,
+		deleteCmd,
 	)
 
 	return rootCmd

+ 77 - 0
server/images.go

@@ -487,6 +487,83 @@ func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
 	return layer, nil
 }
 
+func DeleteModel(name string, fn func(api.ProgressResponse)) error {
+	mp := ParseModelPath(name)
+
+	manifest, err := GetManifest(mp)
+	if err != nil {
+		fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
+		return err
+	}
+	deleteMap := make(map[string]bool)
+	for _, layer := range manifest.Layers {
+		deleteMap[layer.Digest] = true
+	}
+	deleteMap[manifest.Config.Digest] = true
+
+	fp, err := GetManifestPath()
+	if err != nil {
+		fn(api.ProgressResponse{Status: "problem getting manifest path"})
+		return err
+	}
+	err = filepath.Walk(fp, func(path string, info os.FileInfo, err error) error {
+		if err != nil {
+			fn(api.ProgressResponse{Status: "problem walking manifest dir"})
+			return err
+		}
+		if !info.IsDir() {
+			path := path[len(fp)+1:]
+			slashIndex := strings.LastIndex(path, "/")
+			if slashIndex == -1 {
+				return nil
+			}
+			tag := path[:slashIndex] + ":" + path[slashIndex+1:]
+			fmp := ParseModelPath(tag)
+
+			// skip the manifest we're trying to delete
+			if mp.GetFullTagname() == fmp.GetFullTagname() {
+				return nil
+			}
+
+			// save (i.e. delete from the deleteMap) any files used in other manifests
+			manifest, err := GetManifest(fmp)
+			if err != nil {
+				log.Printf("skipping file: %s", fp)
+				return nil
+			}
+			for _, layer := range manifest.Layers {
+				delete(deleteMap, layer.Digest)
+			}
+			delete(deleteMap, manifest.Config.Digest)
+		}
+		return nil
+	})
+
+	// only delete the files which are still in the deleteMap
+	for k, v := range deleteMap {
+		if v {
+			err := os.Remove(k)
+			if err != nil {
+				log.Printf("couldn't remove file '%s': %v", k, err)
+				continue
+			}
+		}
+	}
+
+	fp, err = mp.GetManifestPath(false)
+	if err != nil {
+		return err
+	}
+	err = os.Remove(fp)
+	if err != nil {
+		log.Printf("couldn't remove manifest file '%s': %v", fp, err)
+		return err
+	}
+	fn(api.ProgressResponse{Status: fmt.Sprintf("deleted '%s'", name)})
+
+	return nil
+}
+
 func PushModel(name, username, password string, fn func(api.ProgressResponse)) error {
 	mp := ParseModelPath(name)
 

+ 34 - 10
server/routes.go

@@ -18,7 +18,7 @@ import (
 	"github.com/jmorganca/ollama/llama"
 )
 
-func generate(c *gin.Context) {
+func GenerateHandler(c *gin.Context) {
 	start := time.Now()
 
 	var req api.GenerateRequest
@@ -78,7 +78,7 @@ func generate(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
-func pull(c *gin.Context) {
+func PullModelHandler(c *gin.Context) {
 	var req api.PullRequest
 	if err := c.ShouldBindJSON(&req); err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -100,7 +100,7 @@ func pull(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
-func push(c *gin.Context) {
+func PushModelHandler(c *gin.Context) {
 	var req api.PushRequest
 	if err := c.ShouldBindJSON(&req); err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -122,7 +122,7 @@ func push(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
-func create(c *gin.Context) {
+func CreateModelHandler(c *gin.Context) {
 	var req api.CreateRequest
 	if err := c.ShouldBindJSON(&req); err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
@@ -146,7 +146,30 @@ func create(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
-func list(c *gin.Context) {
+func DeleteModelHandler(c *gin.Context) {
+	var req api.DeleteRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	ch := make(chan any)
+	go func() {
+		defer close(ch)
+		fn := func(r api.ProgressResponse) {
+			ch <- r
+		}
+
+		if err := DeleteModel(req.Name, fn); err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+	}()
+
+	streamResponse(c, ch)
+}
+
+func ListModelsHandler(c *gin.Context) {
 	var models []api.ListResponseModel
 	fp, err := GetManifestPath()
 	if err != nil {
@@ -199,11 +222,12 @@ func Serve(ln net.Listener) error {
 		c.String(http.StatusOK, "Ollama is running")
 	})
 
-	r.POST("/api/pull", pull)
-	r.POST("/api/generate", generate)
-	r.POST("/api/create", create)
-	r.POST("/api/push", push)
-	r.GET("/api/tags", list)
+	r.POST("/api/pull", PullModelHandler)
+	r.POST("/api/generate", GenerateHandler)
+	r.POST("/api/create", CreateModelHandler)
+	r.POST("/api/push", PushModelHandler)
+	r.GET("/api/tags", ListModelsHandler)
+	r.DELETE("/api/delete", DeleteModelHandler)
 
 	log.Printf("Listening on %s", ln.Addr())
 	s := &http.Server{