|
@@ -5,13 +5,16 @@ import (
|
|
|
"crypto/sha256"
|
|
|
"encoding/hex"
|
|
|
"encoding/json"
|
|
|
+ "errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
+ "io/ioutil"
|
|
|
"log"
|
|
|
"net/http"
|
|
|
"os"
|
|
|
"path"
|
|
|
"path/filepath"
|
|
|
+ "strconv"
|
|
|
"strings"
|
|
|
|
|
|
"github.com/jmorganca/ollama/api"
|
|
@@ -536,7 +539,8 @@ func PullModel(name, username, password string, fn func(status, digest string, T
|
|
|
|
|
|
for _, layer := range layers {
|
|
|
fn("starting download", layer.Digest, total, completed, float64(completed)/float64(total))
|
|
|
- if err := downloadBlob(DefaultRegistry, repoName, layer.Digest, username, password); err != nil {
|
|
|
+ if err := downloadBlob(DefaultRegistry, repoName, layer.Digest, username, password, fn); err != nil {
|
|
|
+ fn(fmt.Sprintf("error downloading: %v", err), layer.Digest, 0, 0, 0)
|
|
|
return err
|
|
|
}
|
|
|
completed += layer.Size
|
|
@@ -717,7 +721,7 @@ func uploadBlob(location string, layer *Layer, username string, password string)
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-func downloadBlob(registryURL, repoName, digest, username, password string) error {
|
|
|
+func downloadBlob(registryURL, repoName, digest string, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error {
|
|
|
home, err := os.UserHomeDir()
|
|
|
if err != nil {
|
|
|
return err
|
|
@@ -732,8 +736,22 @@ func downloadBlob(registryURL, repoName, digest, username, password string) erro
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
+ var size int64
|
|
|
+
|
|
|
+ fi, err := os.Stat(fp + "-partial")
|
|
|
+ switch {
|
|
|
+ case errors.Is(err, os.ErrNotExist):
|
|
|
+ // noop, file doesn't exist so create it
|
|
|
+ case err != nil:
|
|
|
+ return fmt.Errorf("stat: %w", err)
|
|
|
+ default:
|
|
|
+ size = fi.Size()
|
|
|
+ }
|
|
|
+
|
|
|
url := fmt.Sprintf("%s/v2/%s/blobs/%s", registryURL, repoName, digest)
|
|
|
- headers := map[string]string{}
|
|
|
+ headers := map[string]string{
|
|
|
+ "Range": fmt.Sprintf("bytes=%d-", size),
|
|
|
+ }
|
|
|
|
|
|
resp, err := makeRequest("GET", url, headers, nil, username, password)
|
|
|
if err != nil {
|
|
@@ -742,10 +760,8 @@ func downloadBlob(registryURL, repoName, digest, username, password string) erro
|
|
|
}
|
|
|
defer resp.Body.Close()
|
|
|
|
|
|
- // TODO: handle range requests to make this resumable
|
|
|
-
|
|
|
- if resp.StatusCode != http.StatusOK {
|
|
|
- body, _ := io.ReadAll(resp.Body)
|
|
|
+ if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
|
|
|
+ body, _ := ioutil.ReadAll(resp.Body)
|
|
|
return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body))
|
|
|
}
|
|
|
|
|
@@ -754,16 +770,34 @@ func downloadBlob(registryURL, repoName, digest, username, password string) erro
|
|
|
return fmt.Errorf("make blobs directory: %w", err)
|
|
|
}
|
|
|
|
|
|
- out, err := os.Create(fp)
|
|
|
+ out, err := os.OpenFile(fp+"-partial", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
|
|
|
if err != nil {
|
|
|
- log.Printf("couldn't create %s", fp)
|
|
|
- return err
|
|
|
+ panic(err)
|
|
|
}
|
|
|
defer out.Close()
|
|
|
|
|
|
- _, err = io.Copy(out, resp.Body)
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
+ remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
|
|
|
+ completed := size
|
|
|
+ total := remaining + completed
|
|
|
+
|
|
|
+ for {
|
|
|
+ fn(fmt.Sprintf("Downloading %s", digest), digest, int(total), int(completed), float64(completed)/float64(total))
|
|
|
+ if completed >= total {
|
|
|
+ fmt.Printf("finished downloading\n")
|
|
|
+ err = os.Rename(fp+"-partial", fp)
|
|
|
+ if err != nil {
|
|
|
+ fmt.Printf("error: %v\n", err)
|
|
|
+ fn(fmt.Sprintf("error renaming file: %v", err), digest, int(total), int(completed), 1)
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ break
|
|
|
+ }
|
|
|
+
|
|
|
+ n, err := io.CopyN(out, resp.Body, 8192)
|
|
|
+ if err != nil && !errors.Is(err, io.EOF) {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ completed += n
|
|
|
}
|
|
|
|
|
|
log.Printf("success getting %s\n", digest)
|
|
@@ -790,7 +824,7 @@ func makeRequest(method, url string, headers map[string]string, body io.Reader,
|
|
|
if len(via) >= 10 {
|
|
|
return fmt.Errorf("too many redirects")
|
|
|
}
|
|
|
- log.Printf("redirected to: %s", req.URL)
|
|
|
+ log.Printf("redirected to: %s\n", req.URL)
|
|
|
return nil
|
|
|
},
|
|
|
}
|