瀏覽代碼

display pull progress

Bruce MacDonald 1 年之前
父節點
當前提交
7cf5905063
共有 7 個文件被更改,包括 81 次插入19 次删除
  1. 21 4
      api/client.go
  2. 6 0
      api/types.go
  3. 29 2
      cmd/cmd.go
  4. 3 1
      go.mod
  5. 6 2
      go.sum
  6. 15 9
      server/models.go
  7. 1 1
      server/routes.go

+ 21 - 4
api/client.go

@@ -8,6 +8,7 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"strings"
 )
 
 type Client struct {
@@ -65,7 +66,6 @@ func (c *Client) stream(ctx context.Context, method string, path string, reqData
 		if err != nil {
 			break
 		}
-
 		callback(bytes.TrimSuffix(line, []byte("\n")))
 	}
 
@@ -128,10 +128,27 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback fu
 	return &res, nil
 }
 
-func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(token string)) (*PullResponse, error) {
+func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progress PullProgress)) (*PullResponse, error) {
 	var res PullResponse
-	if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(token []byte) {
-		callback(string(token))
+	if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(progressBytes []byte) {
+		/*
+			Events have the following format for progress:
+				event:progress
+				data:{"total":123,"completed":123,"percent":0.1}
+			Need to parse out the data part and unmarshal it.
+		*/
+		eventParts := strings.Split(string(progressBytes), "data:")
+		if len(eventParts) < 2 {
+			// no data part, ignore
+			return
+		}
+		eventData := eventParts[1]
+		var progress PullProgress
+		if err := json.Unmarshal([]byte(eventData), &progress); err != nil {
+			fmt.Println(err)
+			return
+		}
+		callback(progress)
 	}); err != nil {
 		return nil, err
 	}

+ 6 - 0
api/types.go

@@ -22,6 +22,12 @@ type PullRequest struct {
 	Model string `json:"model"`
 }
 
+type PullProgress struct {
+	Total     int     `json:"total"`
+	Completed int     `json:"completed"`
+	Percent   float64 `json:"percent"`
+}
+
 type PullResponse struct {
 	Response string `json:"response"`
 }

+ 29 - 2
cmd/cmd.go

@@ -7,7 +7,9 @@ import (
 	"net"
 	"os"
 	"path"
+	"sync"
 
+	"github.com/gosuri/uiprogress"
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/server"
 	"github.com/spf13/cobra"
@@ -22,6 +24,10 @@ func cacheDir() string {
 	return path.Join(home, ".ollama")
 }
 
+func bytesToGB(bytes int) float64 {
+	return float64(bytes) / float64(1<<30)
+}
+
 func run(model string) error {
 	client, err := NewAPIClient()
 	if err != nil {
@@ -30,8 +36,29 @@ func run(model string) error {
 	pr := api.PullRequest{
 		Model: model,
 	}
-	callback := func(progress string) {
-		fmt.Println(progress)
+	var bar *uiprogress.Bar
+	mutex := &sync.Mutex{}
+	var progressData api.PullProgress
+
+	callback := func(progress api.PullProgress) {
+		mutex.Lock()
+		progressData = progress
+		if bar == nil {
+			uiprogress.Start()                           // start rendering
+			bar = uiprogress.AddBar(int(progress.Total)) // Add a new bar
+
+			// display the total file size and how much has downloaded so far
+			bar.PrependFunc(func(b *uiprogress.Bar) string {
+				return fmt.Sprintf("Downloading: %.2f GB / %.2f GB", bytesToGB(progressData.Completed), bytesToGB(progressData.Total))
+			})
+
+			// display completion percentage
+			bar.AppendFunc(func(b *uiprogress.Bar) string {
+				return fmt.Sprintf(" %d%%", int((float64(progressData.Completed)/float64(progressData.Total))*100))
+			})
+		}
+		bar.Set(int(progress.Completed))
+		mutex.Unlock()
 	}
 	_, err = client.Pull(context.Background(), &pr, callback)
 	return err

+ 3 - 1
go.mod

@@ -4,6 +4,7 @@ go 1.20
 
 require (
 	github.com/gin-gonic/gin v1.9.1
+	github.com/gosuri/uiprogress v0.0.1
 	github.com/spf13/cobra v1.7.0
 )
 
@@ -17,6 +18,7 @@ require (
 	github.com/go-playground/validator/v10 v10.14.0 // indirect
 	github.com/goccy/go-json v0.10.2 // indirect
 	github.com/google/go-cmp v0.5.9 // indirect
+	github.com/gosuri/uilive v0.0.4 // indirect
 	github.com/inconshreveable/mousetrap v1.1.0 // indirect
 	github.com/json-iterator/go v1.1.12 // indirect
 	github.com/klauspost/cpuid/v2 v2.2.4 // indirect
@@ -32,7 +34,7 @@ require (
 	golang.org/x/arch v0.3.0 // indirect
 	golang.org/x/crypto v0.10.0 // indirect
 	golang.org/x/net v0.10.0 // indirect
-	golang.org/x/sys v0.9.0 // indirect
+	golang.org/x/sys v0.10.0 // indirect
 	golang.org/x/text v0.10.0 // indirect
 	google.golang.org/protobuf v1.30.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect

+ 6 - 2
go.sum

@@ -28,6 +28,10 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
 github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
 github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
+github.com/gosuri/uilive v0.0.4 h1:hUEBpQDj8D8jXgtCdBu7sWsy5sbW/5GhuO8KBwJ2jyY=
+github.com/gosuri/uilive v0.0.4/go.mod h1:V/epo5LjjlDE5RJUcqx8dbw+zc93y5Ya3yg8tfZ74VI=
+github.com/gosuri/uiprogress v0.0.1 h1:0kpv/XY/qTmFWl/SkaJykZXrBBzwwadmW8fRb7RJSxw=
+github.com/gosuri/uiprogress v0.0.1/go.mod h1:C1RTYn4Sc7iEyf6j8ft5dyoZ4212h8G1ol9QQluh5+0=
 github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
 github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
@@ -97,8 +101,8 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc
 golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
-golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
+golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
 golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=

+ 15 - 9
server/models.go

@@ -9,15 +9,14 @@ import (
 	"os"
 	"path"
 	"strconv"
+
+	"github.com/jmorganca/ollama/api"
 )
 
 // const directoryURL = "https://ollama.ai/api/models"
+// TODO
 const directoryURL = "https://raw.githubusercontent.com/jmorganca/ollama/go/models.json"
 
-type directoryCtxKey string
-
-var dirCtx directoryCtxKey = "directory"
-
 type Model struct {
 	Name             string `json:"name"`
 	DisplayName      string `json:"display_name"`
@@ -31,7 +30,7 @@ type Model struct {
 	License          string `json:"license"`
 }
 
-func pull(model string, progressCh chan<- string) error {
+func pull(model string, progressCh chan<- api.PullProgress) error {
 	remote, err := getRemote(model)
 	if err != nil {
 		return fmt.Errorf("failed to pull model: %w", err)
@@ -64,7 +63,7 @@ func getRemote(model string) (*Model, error) {
 	return nil, fmt.Errorf("model not found in directory: %s", model)
 }
 
-func saveModel(model *Model, progressCh chan<- string) error {
+func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
 	// this models cache directory is created by the server on startup
 	home, err := os.UserHomeDir()
 	if err != nil {
@@ -130,11 +129,18 @@ func saveModel(model *Model, progressCh chan<- string) error {
 		totalBytes += n
 
 		// send progress updates
-		progressCh <- fmt.Sprintf("Downloaded %d out of %d bytes (%.2f%%)", totalBytes, totalSize, float64(totalBytes)/float64(totalSize)*100)
+		progressCh <- api.PullProgress{
+			Total:     totalSize,
+			Completed: totalBytes,
+			Percent:   float64(totalBytes) / float64(totalSize) * 100,
+		}
 	}
 
-	// send completion message
-	progressCh <- "Download complete!"
+	progressCh <- api.PullProgress{
+		Total:     totalSize,
+		Completed: totalSize,
+		Percent:   100,
+	}
 
 	return nil
 }

+ 1 - 1
server/routes.go

@@ -107,7 +107,7 @@ func Serve(ln net.Listener) error {
 			return
 		}
 
-		progressCh := make(chan string)
+		progressCh := make(chan api.PullProgress)
 		go func() {
 			defer close(progressCh)
 			if err := pull(req.Model, progressCh); err != nil {