Josh Yan 10 months ago
parent
commit
ae65cc8dea
1 changed files with 42 additions and 3 deletions
  1. 42 3
      cmd/cmd.go

+ 42 - 3
cmd/cmd.go

@@ -77,7 +77,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
-	status := "transferring model data"
+	status := ""
 	spinner := progress.NewSpinner(status)
 	p.Add(status, spinner)
 
@@ -113,7 +113,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 
 				path = tempfile
 			}
-
+			spinner.Stop()
 			digest, err := createBlob(cmd, client, path)
 			if err != nil {
 				return err
@@ -274,6 +274,13 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
 	}
 	defer bin.Close()
 
+	// Get file info to retrieve the size
+	fileInfo, err := bin.Stat()
+	if err != nil {
+		return "", err
+	}
+	fileSize := fileInfo.Size()
+
 	hash := sha256.New()
 	if _, err := io.Copy(hash, bin); err != nil {
 		return "", err
@@ -283,6 +290,29 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
 		return "", err
 	}
 
+	var pw progressWriter
+	// Create a progress bar and start a goroutine to update it
+	p := progress.NewProgress(os.Stderr)
+	bar := progress.NewBar("transferring model data...", fileSize, 0)
+	p.Add("", bar)
+
+	ticker := time.NewTicker(60 * time.Millisecond)
+	done := make(chan struct{})
+	defer p.Stop()
+
+	go func() {
+		defer ticker.Stop()
+		for {
+			select {
+			case <-ticker.C:
+				bar.Set(pw.n)
+			case <-done:
+				bar.Set(fileSize)
+				return
+			}
+		}
+	}()
+
 	digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
 
 	// We check if we can find the models directory locally
@@ -312,12 +342,21 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
 	}
 
 	// If at any point copying the blob over locally fails, we default to the copy through the server
-	if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
+	if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
 		return "", err
 	}
 	return digest, nil
 }
 
+type progressWriter struct {
+	n int64
+}
+
+func (w *progressWriter) Write(p []byte) (n int, err error) {
+	w.n += int64(len(p))
+	return len(p), nil
+}
+
 func getLocalPath(ctx context.Context, digest string) (string, error) {
 	ollamaHost := envconfig.Host