Michael Yang 1 year ago
parent
commit
c4a3ccd7ac
3 changed files with 285 additions and 0 deletions
  1. 118 0
      progress/bar.go
  2. 65 0
      progress/progress.go
  3. 102 0
      progress/spinner.go

+ 118 - 0
progress/bar.go

@@ -0,0 +1,118 @@
+package progress
+
+import (
+	"fmt"
+	"os"
+	"strings"
+	"time"
+
+	"github.com/jmorganca/ollama/format"
+	"golang.org/x/term"
+)
+
+type Bar struct {
+	message      string
+	messageWidth int
+
+	maxValue     int64
+	initialValue int64
+	currentValue int64
+
+	started time.Time
+	stopped time.Time
+}
+
+func NewBar(message string, maxValue, initialValue int64) *Bar {
+	return &Bar{
+		message:      message,
+		messageWidth: -1,
+		maxValue:     maxValue,
+		initialValue: initialValue,
+		currentValue: initialValue,
+		started:      time.Now(),
+	}
+}
+
+func (b *Bar) String() string {
+	termWidth, _, err := term.GetSize(int(os.Stderr.Fd()))
+	if err != nil {
+		panic(err)
+	}
+
+	var pre, mid, suf strings.Builder
+
+	if b.message != "" {
+		message := strings.TrimSpace(b.message)
+		if b.messageWidth > 0 && len(message) > b.messageWidth {
+			message = message[:b.messageWidth]
+		}
+
+		fmt.Fprintf(&pre, "%s", message)
+		if b.messageWidth-pre.Len() >= 0 {
+			pre.WriteString(strings.Repeat(" ", b.messageWidth-pre.Len()))
+		}
+
+		pre.WriteString(" ")
+	}
+
+	fmt.Fprintf(&pre, "%.1f%% ", b.percent())
+
+	fmt.Fprintf(&suf, "(%s/%s, %s/s, %s)",
+		format.HumanBytes(b.currentValue),
+		format.HumanBytes(b.maxValue),
+		format.HumanBytes(int64(b.rate())),
+		b.elapsed())
+
+	mid.WriteString("[")
+
+	// pad 3 for last = or > and "] "
+	f := termWidth - pre.Len() - mid.Len() - suf.Len() - 3
+	n := int(float64(f) * b.percent() / 100)
+	if n > 0 {
+		mid.WriteString(strings.Repeat("=", n))
+	}
+
+	if b.currentValue >= b.maxValue {
+		mid.WriteString("=")
+	} else {
+		mid.WriteString(">")
+	}
+
+	if f-n > 0 {
+		mid.WriteString(strings.Repeat(" ", f-n))
+	}
+
+	mid.WriteString("] ")
+
+	return pre.String() + mid.String() + suf.String()
+}
+
+func (b *Bar) Set(value int64) {
+	if value >= b.maxValue {
+		value = b.maxValue
+		b.stopped = time.Now()
+	}
+
+	b.currentValue = value
+}
+
+func (b *Bar) percent() float64 {
+	if b.maxValue > 0 {
+		return float64(b.currentValue) / float64(b.maxValue) * 100
+	}
+
+	return 0
+}
+
+func (b *Bar) rate() float64 {
+	return (float64(b.currentValue) - float64(b.initialValue)) / b.elapsed().Seconds()
+}
+
+func (b *Bar) elapsed() time.Duration {
+	stopped := b.stopped
+	if stopped.IsZero() {
+		stopped = time.Now()
+	}
+
+	return stopped.Sub(b.started).Round(time.Second)
+}

+ 65 - 0
progress/progress.go

@@ -0,0 +1,65 @@
+package progress
+
+import (
+	"fmt"
+	"io"
+	"sync"
+	"time"
+)
+
+type State interface {
+	String() string
+}
+
+type Progress struct {
+	mu  sync.Mutex
+	pos int
+	w   io.Writer
+
+	ticker *time.Ticker
+	states []State
+}
+
+func NewProgress(w io.Writer) *Progress {
+	p := &Progress{pos: -1, w: w}
+	go p.start()
+	return p
+}
+
+func (p *Progress) Stop() {
+	if p.ticker != nil {
+		p.ticker.Stop()
+		p.ticker = nil
+		p.render()
+	}
+}
+
+func (p *Progress) Add(key string, state State) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	p.states = append(p.states, state)
+}
+
+func (p *Progress) render() error {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	fmt.Fprintf(p.w, "\033[%dA", p.pos)
+	for _, state := range p.states {
+		fmt.Fprintln(p.w, state.String())
+	}
+
+	if len(p.states) > 0 {
+		p.pos = len(p.states)
+	}
+
+	return nil
+}
+
+func (p *Progress) start() {
+	p.ticker = time.NewTicker(100 * time.Millisecond)
+	for range p.ticker.C {
+		p.render()
+	}
+}

+ 102 - 0
progress/spinner.go

@@ -0,0 +1,102 @@
+package progress
+
+import (
+	"fmt"
+	"os"
+	"strings"
+	"time"
+
+	"golang.org/x/term"
+)
+
+type Spinner struct {
+	message      string
+	messageWidth int
+
+	parts []string
+
+	value int
+
+	ticker  *time.Ticker
+	started time.Time
+	stopped time.Time
+}
+
+func NewSpinner(message string) *Spinner {
+	s := &Spinner{
+		message: message,
+		parts: []string{
+			"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏",
+		},
+		started: time.Now(),
+	}
+	go s.start()
+	return s
+}
+
+func (s *Spinner) String() string {
+	termWidth, _, err := term.GetSize(int(os.Stderr.Fd()))
+	if err != nil {
+		panic(err)
+	}
+
+	var pre strings.Builder
+	if len(s.message) > 0 {
+		message := strings.TrimSpace(s.message)
+		if s.messageWidth > 0 && len(message) > s.messageWidth {
+			message = message[:s.messageWidth]
+		}
+
+		fmt.Fprintf(&pre, "%s", message)
+		if s.messageWidth-pre.Len() >= 0 {
+			pre.WriteString(strings.Repeat(" ", s.messageWidth-pre.Len()))
+		}
+
+		pre.WriteString(" ")
+	}
+
+	var pad int
+	if s.stopped.IsZero() {
+		// spinner has a string length of 3 but a rune length of 1
+		// in order to align correctly, we need to pad with (3 - 1) = 2 spaces
+		spinner := s.parts[s.value]
+		pre.WriteString(spinner)
+		pad = len(spinner) - len([]rune(spinner))
+	}
+
+	var suf strings.Builder
+	fmt.Fprintf(&suf, "(%s)", s.elapsed())
+
+	var mid strings.Builder
+	f := termWidth - pre.Len() - mid.Len() - suf.Len() + pad
+	if f > 0 {
+		mid.WriteString(strings.Repeat(" ", f))
+	}
+
+	return pre.String() + mid.String() + suf.String()
+}
+
+func (s *Spinner) start() {
+	s.ticker = time.NewTicker(100 * time.Millisecond)
+	for range s.ticker.C {
+		s.value = (s.value + 1) % len(s.parts)
+		if !s.stopped.IsZero() {
+			return
+		}
+	}
+}
+
+func (s *Spinner) Stop() {
+	if s.stopped.IsZero() {
+		s.stopped = time.Now()
+	}
+}
+
+func (s *Spinner) elapsed() time.Duration {
+	stopped := s.stopped
+	if stopped.IsZero() {
+		stopped = time.Now()
+	}
+
+	return stopped.Sub(s.started).Round(time.Second)
+}