소스 검색

progress cmd

Michael Yang 1 년 전
부모
커밋
1c0e092ead
1개의 변경된 파일120개의 추가작업 그리고 12개의 파일을 삭제
  1. 120 12
      cmd/cmd.go

+ 120 - 12
cmd/cmd.go

@@ -30,6 +30,7 @@ import (
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/format"
 	"github.com/jmorganca/ollama/parser"
+	"github.com/jmorganca/ollama/progress"
 	"github.com/jmorganca/ollama/readline"
 	"github.com/jmorganca/ollama/server"
 	"github.com/jmorganca/ollama/version"
@@ -47,6 +48,15 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
+	p := progress.NewProgress(os.Stderr)
+	defer p.Stop()
+
+	bars := make(map[string]*progress.Bar)
+
+	status := "transferring context"
+	spinner := progress.NewSpinner(status)
+	p.Add(status, spinner)
+
 	modelfile, err := os.ReadFile(filename)
 	if err != nil {
 		return err
@@ -95,16 +105,38 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 		}
 	}
 
-	request := api.CreateRequest{Name: args[0], Path: filename, Modelfile: string(modelfile)}
 	fn := func(resp api.ProgressResponse) error {
-		log.Printf("progress(%s): %s", resp.Digest, resp.Status)
+		if resp.Digest != "" {
+			spinner.Stop()
+
+			bar, ok := bars[resp.Digest]
+			if !ok {
+				bar = progress.NewBar(resp.Status, resp.Total, resp.Completed)
+				bars[resp.Digest] = bar
+				p.Add(resp.Digest, bar)
+			}
+
+			bar.Set(resp.Completed)
+		} else if status != resp.Status {
+			spinner.Stop()
+
+			status = resp.Status
+			spinner = progress.NewSpinner(status)
+			p.Add(status, spinner)
+		}
+
 		return nil
 	}
 
+	request := api.CreateRequest{Name: args[0], Path: filename, Modelfile: string(modelfile)}
 	if err := client.Create(context.Background(), &request, fn); err != nil {
 		return err
 	}
 
+	if spinner != nil {
+		spinner.Stop()
+	}
+
 	return nil
 }
 
@@ -141,13 +173,53 @@ func PushHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
-	request := api.PushRequest{Name: args[0], Insecure: insecure}
+	p := progress.NewProgress(os.Stderr)
+	defer p.Stop()
+
+	bars := make(map[string]*progress.Bar)
+
+	var status string
+	var spinner *progress.Spinner
+
 	fn := func(resp api.ProgressResponse) error {
-		log.Printf("progress(%s): %s", resp.Digest, resp.Status)
+		if resp.Digest != "" {
+			if spinner != nil {
+				spinner.Stop()
+				spinner = nil
+			}
+
+			bar, ok := bars[resp.Digest]
+			if !ok {
+				bar = progress.NewBar(resp.Status, resp.Total, resp.Completed)
+				bars[resp.Digest] = bar
+				p.Add(resp.Digest, bar)
+			}
+
+			bar.Set(resp.Completed)
+		} else if status != resp.Status {
+			if spinner != nil {
+				spinner.Stop()
+				spinner = nil
+			}
+
+			status = resp.Status
+			spinner = progress.NewSpinner(status)
+			p.Add(status, spinner)
+		}
+
 		return nil
 	}
 
-	return client.Push(context.Background(), &request, fn)
+	request := api.PushRequest{Name: args[0], Insecure: insecure}
+	if err := client.Push(context.Background(), &request, fn); err != nil {
+		return err
+	}
+
+	if spinner != nil {
+		spinner.Stop()
+	}
+
+	return nil
 }
 
 func ListHandler(cmd *cobra.Command, args []string) error {
@@ -297,22 +369,58 @@ func PullHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
-	return pull(args[0], insecure)
-}
-
-func pull(model string, insecure bool) error {
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
 		return err
 	}
 
-	request := api.PullRequest{Name: model, Insecure: insecure}
+	p := progress.NewProgress(os.Stderr)
+	defer p.Stop()
+
+	bars := make(map[string]*progress.Bar)
+
+	var status string
+	var spinner *progress.Spinner
+
 	fn := func(resp api.ProgressResponse) error {
-		log.Printf("progress(%s): %s", resp.Digest, resp.Status)
+		if resp.Digest != "" {
+			if spinner != nil {
+				spinner.Stop()
+				spinner = nil
+			}
+
+			bar, ok := bars[resp.Digest]
+			if !ok {
+				bar = progress.NewBar(resp.Status, resp.Total, resp.Completed)
+				bars[resp.Digest] = bar
+				p.Add(resp.Digest, bar)
+			}
+
+			bar.Set(resp.Completed)
+		} else if status != resp.Status {
+			if spinner != nil {
+				spinner.Stop()
+				spinner = nil
+			}
+
+			status = resp.Status
+			spinner = progress.NewSpinner(status)
+			p.Add(status, spinner)
+		}
+
 		return nil
 	}
 
-	return client.Pull(context.Background(), &request, fn)
+	request := api.PullRequest{Name: args[0], Insecure: insecure}
+	if err := client.Pull(context.Background(), &request, fn); err != nil {
+		return err
+	}
+
+	if spinner != nil {
+		spinner.Stop()
+	}
+
+	return nil
 }
 
 func RunGenerate(cmd *cobra.Command, args []string) error {