Browse Source

generate progress

Michael Yang 1 year ago
parent
commit
4dcf7a59b1
2 changed files with 37 additions and 1 deletions
  1. 10 0
      cmd/cmd.go
  2. 27 1
      progress/progress.go

+ 10 - 0
cmd/cmd.go

@@ -472,6 +472,13 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
 		return err
 	}
 
+	p := progress.NewProgress(os.Stderr)
+	defer p.Stop()
+
+	spinner := progress.NewSpinner("")
+	defer spinner.Stop()
+	p.Add("", spinner)
+
 	var latest api.GenerateResponse
 
 	generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
@@ -502,6 +509,9 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
 
 	request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, Format: format}
 	fn := func(response api.GenerateResponse) error {
+		spinner.Stop()
+		p.StopAndClear()
+
 		latest = response
 
 		if wordWrap {

+ 27 - 1
progress/progress.go

@@ -3,8 +3,12 @@ package progress
 import (
 	"fmt"
 	"io"
+	"os"
+	"strings"
 	"sync"
 	"time"
+
+	"golang.org/x/term"
 )
 
 type State interface {
@@ -26,12 +30,34 @@ func NewProgress(w io.Writer) *Progress {
 	return p
 }
 
-func (p *Progress) Stop() {
+func (p *Progress) Stop() bool {
 	if p.ticker != nil {
 		p.ticker.Stop()
 		p.ticker = nil
 		p.render()
+		return true
 	}
+
+	return false
+}
+
+func (p *Progress) StopAndClear() bool {
+	stopped := p.Stop()
+	if stopped {
+		termWidth, _, err := term.GetSize(int(os.Stderr.Fd()))
+		if err != nil {
+			panic(err)
+		}
+
+		// clear the progress bar by:
+		// 1. reset to beginning of line
+		// 2. move up to the first line of the progress bar
+		// 3. fill the terminal width with spaces
+		// 4. reset to beginning of line
+		fmt.Fprintf(p.w, "\r\033[%dA%s\r", p.pos, strings.Repeat(" ", termWidth))
+	}
+
+	return stopped
 }
 
 func (p *Progress) Add(key string, state State) {