瀏覽代碼

add progressbar for model pulls

Patrick Devine 1 年之前
父節點
當前提交
2e1394e405
共有 4 個文件被更改,包括 67 次插入84 次删除
  1. 19 1
      cmd/cmd.go
  2. 48 14
      server/images.go
  3. 0 65
      server/models.go
  4. 0 4
      server/routes.go

+ 19 - 1
cmd/cmd.go

@@ -90,9 +90,27 @@ func RunPull(cmd *cobra.Command, args []string) error {
 func pull(model string) error {
 	client := api.NewClient()
 
+	var bar *progressbar.ProgressBar
+
+	currentLayer := ""
 	request := api.PullRequest{Name: model}
 	fn := func(resp api.PullProgress) error {
-		fmt.Println(resp.Status)
+		if resp.Digest != currentLayer && resp.Digest != "" {
+			if currentLayer != "" {
+				fmt.Println()
+			}
+			currentLayer = resp.Digest
+			layerStr := resp.Digest[7:23] + "..."
+			bar = progressbar.DefaultBytes(
+				int64(resp.Total),
+				"pulling "+layerStr,
+			)
+		} else if resp.Digest == currentLayer && resp.Digest != "" {
+			bar.Set(resp.Completed)
+		} else {
+			currentLayer = ""
+			fmt.Println(resp.Status)
+		}
 		return nil
 	}
 

+ 48 - 14
server/images.go

@@ -5,13 +5,16 @@ import (
 	"crypto/sha256"
 	"encoding/hex"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
+	"io/ioutil"
 	"log"
 	"net/http"
 	"os"
 	"path"
 	"path/filepath"
+	"strconv"
 	"strings"
 
 	"github.com/jmorganca/ollama/api"
@@ -536,7 +539,8 @@ func PullModel(name, username, password string, fn func(status, digest string, T
 
 	for _, layer := range layers {
 		fn("starting download", layer.Digest, total, completed, float64(completed)/float64(total))
-		if err := downloadBlob(DefaultRegistry, repoName, layer.Digest, username, password); err != nil {
+		if err := downloadBlob(DefaultRegistry, repoName, layer.Digest, username, password, fn); err != nil {
+			fn(fmt.Sprintf("error downloading: %v", err), layer.Digest, 0, 0, 0)
 			return err
 		}
 		completed += layer.Size
@@ -717,7 +721,7 @@ func uploadBlob(location string, layer *Layer, username string, password string)
 	return nil
 }
 
-func downloadBlob(registryURL, repoName, digest, username, password string) error {
+func downloadBlob(registryURL, repoName, digest string, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error {
 	home, err := os.UserHomeDir()
 	if err != nil {
 		return err
@@ -732,8 +736,22 @@ func downloadBlob(registryURL, repoName, digest, username, password string) erro
 		return nil
 	}
 
+	var size int64
+
+	fi, err := os.Stat(fp + "-partial")
+	switch {
+	case errors.Is(err, os.ErrNotExist):
+		// noop, file doesn't exist so create it
+	case err != nil:
+		return fmt.Errorf("stat: %w", err)
+	default:
+		size = fi.Size()
+	}
+
 	url := fmt.Sprintf("%s/v2/%s/blobs/%s", registryURL, repoName, digest)
-	headers := map[string]string{}
+	headers := map[string]string{
+		"Range": fmt.Sprintf("bytes=%d-", size),
+	}
 
 	resp, err := makeRequest("GET", url, headers, nil, username, password)
 	if err != nil {
@@ -742,10 +760,8 @@ func downloadBlob(registryURL, repoName, digest, username, password string) erro
 	}
 	defer resp.Body.Close()
 
-	// TODO: handle range requests to make this resumable
-
-	if resp.StatusCode != http.StatusOK {
-		body, _ := io.ReadAll(resp.Body)
+	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
+		body, _ := ioutil.ReadAll(resp.Body)
 		return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body))
 	}
 
@@ -754,16 +770,34 @@ func downloadBlob(registryURL, repoName, digest, username, password string) erro
 		return fmt.Errorf("make blobs directory: %w", err)
 	}
 
-	out, err := os.Create(fp)
+	out, err := os.OpenFile(fp+"-partial", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
 	if err != nil {
-		log.Printf("couldn't create %s", fp)
-		return err
+		panic(err)
 	}
 	defer out.Close()
 
-	_, err = io.Copy(out, resp.Body)
-	if err != nil {
-		return err
+	remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
+	completed := size
+	total := remaining + completed
+
+	for {
+		fn(fmt.Sprintf("Downloading %s", digest), digest, int(total), int(completed), float64(completed)/float64(total))
+		if completed >= total {
+			fmt.Printf("finished downloading\n")
+			err = os.Rename(fp+"-partial", fp)
+			if err != nil {
+				fmt.Printf("error: %v\n", err)
+				fn(fmt.Sprintf("error renaming file: %v", err), digest, int(total), int(completed), 1)
+				return err
+			}
+			break
+		}
+
+		n, err := io.CopyN(out, resp.Body, 8192)
+		if err != nil && !errors.Is(err, io.EOF) {
+			return err
+		}
+		completed += n
 	}
 
 	log.Printf("success getting %s\n", digest)
@@ -790,7 +824,7 @@ func makeRequest(method, url string, headers map[string]string, body io.Reader,
 			if len(via) >= 10 {
 				return fmt.Errorf("too many redirects")
 			}
-			log.Printf("redirected to: %s", req.URL)
+			log.Printf("redirected to: %s\n", req.URL)
 			return nil
 		},
 	}

+ 0 - 65
server/models.go

@@ -1,12 +1,6 @@
 package server
 
 import (
-	"fmt"
-	"os"
-	"path"
-	"path/filepath"
-	"strconv"
-
 	"github.com/jmorganca/ollama/api"
 )
 
@@ -26,62 +20,3 @@ type Model struct {
 	License          string `json:"license"`
 }
 
-func saveModel(model *Model, fn func(total, completed int64)) error {
-	// this models cache directory is created by the server on startup
-
-	client := &http.Client{}
-	req, err := http.NewRequest("GET", model.URL, nil)
-	if err != nil {
-		return fmt.Errorf("failed to download model: %w", err)
-	}
-
-	var size int64
-
-	// completed file doesn't exist, check partial file
-	fi, err := os.Stat(model.TempFile())
-	switch {
-	case errors.Is(err, os.ErrNotExist):
-		// noop, file doesn't exist so create it
-	case err != nil:
-		return fmt.Errorf("stat: %w", err)
-	default:
-		size = fi.Size()
-	}
-
-	req.Header.Add("Range", fmt.Sprintf("bytes=%d-", size))
-
-	resp, err := client.Do(req)
-	if err != nil {
-		return fmt.Errorf("failed to download model: %w", err)
-	}
-	defer resp.Body.Close()
-
-	if resp.StatusCode >= 400 {
-		return fmt.Errorf("failed to download model: %s", resp.Status)
-	}
-
-	out, err := os.OpenFile(model.TempFile(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
-	if err != nil {
-		panic(err)
-	}
-	defer out.Close()
-
-	remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
-	completed := size
-
-	total := remaining + completed
-
-	for {
-		fn(total, completed)
-		if completed >= total {
-			return os.Rename(model.TempFile(), model.FullName())
-		}
-
-		n, err := io.CopyN(out, resp.Body, 8192)
-		if err != nil && !errors.Is(err, io.EOF) {
-			return err
-		}
-
-		completed += n
-	}
-}

+ 0 - 4
server/routes.go

@@ -19,10 +19,6 @@ import (
 	"github.com/jmorganca/ollama/llama"
 )
 
-//go:embed templates/*
-var templatesFS embed.FS
-var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt"))
-
 func cacheDir() string {
 	home, err := os.UserHomeDir()
 	if err != nil {