Explorar o código

cmd: spinner progress for transfer model data (#6100)

Josh hai 8 meses
pai
achega
f7e3b9190f
Modificáronse 2 ficheiros con 52 adicións e 7 borrados
  1. 42 3
      cmd/cmd.go
  2. 10 4
      progress/spinner.go

+ 42 - 3
cmd/cmd.go

@@ -22,6 +22,7 @@ import (
 	"runtime"
 	"slices"
 	"strings"
+	"sync/atomic"
 	"syscall"
 	"time"
 
@@ -78,6 +79,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	status := "transferring model data"
 	spinner := progress.NewSpinner(status)
 	p.Add(status, spinner)
+	defer p.Stop()
 
 	for i := range modelfile.Commands {
 		switch modelfile.Commands[i].Name {
@@ -112,7 +114,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 				path = tempfile
 			}
 
-			digest, err := createBlob(cmd, client, path)
+			digest, err := createBlob(cmd, client, path, spinner)
 			if err != nil {
 				return err
 			}
@@ -263,13 +265,20 @@ func tempZipFiles(path string) (string, error) {
 	return tempfile.Name(), nil
 }
 
-func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
+func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *progress.Spinner) (string, error) {
 	bin, err := os.Open(path)
 	if err != nil {
 		return "", err
 	}
 	defer bin.Close()
 
+	// Get file info to retrieve the size
+	fileInfo, err := bin.Stat()
+	if err != nil {
+		return "", err
+	}
+	fileSize := fileInfo.Size()
+
 	hash := sha256.New()
 	if _, err := io.Copy(hash, bin); err != nil {
 		return "", err
@@ -279,13 +288,43 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
 		return "", err
 	}
 
+	var pw progressWriter
+	status := "transferring model data 0%"
+	spinner.SetMessage(status)
+
+	done := make(chan struct{})
+	defer close(done)
+
+	go func() {
+		ticker := time.NewTicker(60 * time.Millisecond)
+		defer ticker.Stop()
+		for {
+			select {
+			case <-ticker.C:
+				spinner.SetMessage(fmt.Sprintf("transferring model data %d%%", int(100*pw.n.Load()/fileSize)))
+			case <-done:
+				spinner.SetMessage("transferring model data 100%")
+				return
+			}
+		}
+	}()
+
 	digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
-	if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
+	if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
 		return "", err
 	}
 	return digest, nil
 }
 
+type progressWriter struct {
+	n atomic.Int64
+}
+
+func (w *progressWriter) Write(p []byte) (n int, err error) {
+	w.n.Add(int64(len(p)))
+	return len(p), nil
+}
+
 func RunHandler(cmd *cobra.Command, args []string) error {
 	interactive := true
 

+ 10 - 4
progress/spinner.go

@@ -3,11 +3,12 @@ package progress
 import (
 	"fmt"
 	"strings"
+	"sync/atomic"
 	"time"
 )
 
 type Spinner struct {
-	message      string
+	message      atomic.Value
 	messageWidth int
 
 	parts []string
@@ -21,20 +22,25 @@ type Spinner struct {
 
 func NewSpinner(message string) *Spinner {
 	s := &Spinner{
-		message: message,
 		parts: []string{
 			"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏",
 		},
 		started: time.Now(),
 	}
+	s.SetMessage(message)
 	go s.start()
 	return s
 }
 
+func (s *Spinner) SetMessage(message string) {
+	s.message.Store(message)
+}
+
 func (s *Spinner) String() string {
 	var sb strings.Builder
-	if len(s.message) > 0 {
-		message := strings.TrimSpace(s.message)
+
+	if message, ok := s.message.Load().(string); ok && len(message) > 0 {
+		message := strings.TrimSpace(message)
 		if s.messageWidth > 0 && len(message) > s.messageWidth {
 			message = message[:s.messageWidth]
 		}