Pārlūkot izejas kodu

add create, pull, and push

Patrick Devine 1 gadu atpakaļ
vecāks
revīzija
6e2be5a8a0
3 mainītis faili ar 114 papildinājumiem un 12 dzēšanām
  1. 33 10
      api/client.go
  2. 80 1
      cmd/cmd.go
  3. 1 1
      server/routes.go

+ 33 - 10
api/client.go

@@ -107,15 +107,38 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate
 type PullProgressFunc func(PullProgress) error
 
 func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
-	/*
-		return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error {
-			var resp PullProgress
-			if err := json.Unmarshal(bts, &resp); err != nil {
-				return err
-			}
+	return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error {
+		var resp PullProgress
+		if err := json.Unmarshal(bts, &resp); err != nil {
+			return err
+		}
 
-			return fn(resp)
-		})
-	*/
-	return nil
+		return fn(resp)
+	})
+}
+
+type PushProgressFunc func(PushProgress) error
+
+func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error {
+	return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error {
+		var resp PushProgress
+		if err := json.Unmarshal(bts, &resp); err != nil {
+			return err
+		}
+
+		return fn(resp)
+	})
+}
+
+type CreateProgressFunc func(CreateProgress) error
+
+func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
+	return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
+		var resp CreateProgress
+		if err := json.Unmarshal(bts, &resp); err != nil {
+			return err
+		}
+
+		return fn(resp)
+	})
 }

+ 80 - 1
cmd/cmd.go

@@ -30,6 +30,23 @@ func cacheDir() string {
 	return filepath.Join(home, ".ollama")
 }
 
+func create(cmd *cobra.Command, args []string) error {
+	filename, _ := cmd.Flags().GetString("file")
+	client := api.NewClient()
+
+	request := api.CreateRequest{Name: args[0], Path: filename}
+	fn := func(resp api.CreateProgress) error {
+		fmt.Println(resp.Status)
+		return nil
+	}
+
+	if err := client.Create(context.Background(), &request, fn); err != nil {
+		return err
+	}
+
+	return nil
+}
+
 func RunRun(cmd *cobra.Command, args []string) error {
 	_, err := os.Stat(args[0])
 	switch {
@@ -51,8 +68,37 @@ func RunRun(cmd *cobra.Command, args []string) error {
 	return RunGenerate(cmd, args)
 }
 
+func push(cmd *cobra.Command, args []string) error {
+	client := api.NewClient()
+
+	request := api.PushRequest{Name: args[0]}
+	fn := func(resp api.PushProgress) error {
+		fmt.Println(resp.Status)
+		return nil
+	}
+
+	if err := client.Push(context.Background(), &request, fn); err != nil {
+		return err
+	}
+	return nil
+}
+
+func RunPull(cmd *cobra.Command, args []string) error {
+	return pull(args[0])
+}
+
 func pull(model string) error {
-	// TODO add this back
+	client := api.NewClient()
+
+	request := api.PullRequest{Name: model}
+	fn := func(resp api.PullProgress) error {
+		fmt.Println(resp.Status)
+		return nil
+	}
+
+	if err := client.Pull(context.Background(), &request, fn); err != nil {
+		return err
+	}
 	return nil
 }
 
@@ -199,6 +245,15 @@ func NewCLI() *cobra.Command {
 
 	cobra.EnableCommandSorting = false
 
+	createCmd := &cobra.Command{
+		Use:   "create MODEL",
+		Short: "Create a model from a Modelfile",
+		Args:  cobra.MinimumNArgs(1),
+		RunE:  create,
+	}
+
+	createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile (default \"Modelfile\")")
+
 	runCmd := &cobra.Command{
 		Use:   "run MODEL [PROMPT]",
 		Short: "Run a model",
@@ -215,9 +270,33 @@ func NewCLI() *cobra.Command {
 		RunE:    RunServer,
 	}
 
+	pullCmd := &cobra.Command{
+		Use:   "pull MODEL",
+		Short: "Pull a model from a registry",
+		Args:  cobra.MinimumNArgs(1),
+		RunE:  RunPull,
+	}
+
+	pushCmd := &cobra.Command{
+		Use:   "push MODEL",
+		Short: "Push a model to a registry",
+		Args:  cobra.MinimumNArgs(1),
+		RunE:  push,
+	}
+
+	rootCmd.AddCommand(
+		serveCmd,
+		createCmd,
+		runCmd,
+		pullCmd,
+		pushCmd,
+	)
+
 	rootCmd.AddCommand(
 		serveCmd,
+		createCmd,
 		runCmd,
+		pullCmd,
 	)
 
 	return rootCmd

+ 1 - 1
server/routes.go

@@ -116,7 +116,7 @@ func pull(c *gin.Context) {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
 		}
-	}
+	}()
 
 	streamResponse(c, ch)
 }