|
@@ -20,6 +20,7 @@ import (
|
|
"golang.org/x/sync/errgroup"
|
|
"golang.org/x/sync/errgroup"
|
|
|
|
|
|
"github.com/jmorganca/ollama/api"
|
|
"github.com/jmorganca/ollama/api"
|
|
|
|
+ "github.com/jmorganca/ollama/format"
|
|
)
|
|
)
|
|
|
|
|
|
var blobDownloadManager sync.Map
|
|
var blobDownloadManager sync.Map
|
|
@@ -34,6 +35,9 @@ type blobDownload struct {
|
|
Parts []*blobDownloadPart
|
|
Parts []*blobDownloadPart
|
|
|
|
|
|
context.CancelFunc
|
|
context.CancelFunc
|
|
|
|
+
|
|
|
|
+ done bool
|
|
|
|
+ err error
|
|
references atomic.Int32
|
|
references atomic.Int32
|
|
}
|
|
}
|
|
|
|
|
|
@@ -46,6 +50,12 @@ type blobDownloadPart struct {
|
|
*blobDownload `json:"-"`
|
|
*blobDownload `json:"-"`
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+const (
|
|
|
|
+ numDownloadParts = 64
|
|
|
|
+ minDownloadPartSize int64 = 32 * 1000 * 1000
|
|
|
|
+ maxDownloadPartSize int64 = 256 * 1000 * 1000
|
|
|
|
+)
|
|
|
|
+
|
|
func (p *blobDownloadPart) Name() string {
|
|
func (p *blobDownloadPart) Name() string {
|
|
return strings.Join([]string{
|
|
return strings.Join([]string{
|
|
p.blobDownload.Name, "partial", strconv.Itoa(p.N),
|
|
p.blobDownload.Name, "partial", strconv.Itoa(p.N),
|
|
@@ -91,9 +101,15 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R
|
|
|
|
|
|
b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
|
|
b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
|
|
|
|
|
|
- var offset int64
|
|
|
|
- var size int64 = 64 * 1024 * 1024
|
|
|
|
|
|
+ var size = b.Total / numDownloadParts
|
|
|
|
+ switch {
|
|
|
|
+ case size < minDownloadPartSize:
|
|
|
|
+ size = minDownloadPartSize
|
|
|
|
+ case size > maxDownloadPartSize:
|
|
|
|
+ size = maxDownloadPartSize
|
|
|
|
+ }
|
|
|
|
|
|
|
|
+ var offset int64
|
|
for offset < b.Total {
|
|
for offset < b.Total {
|
|
if offset+size > b.Total {
|
|
if offset+size > b.Total {
|
|
size = b.Total - offset
|
|
size = b.Total - offset
|
|
@@ -107,11 +123,15 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
- log.Printf("downloading %s in %d part(s)", b.Digest[7:19], len(b.Parts))
|
|
|
|
|
|
+ log.Printf("downloading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size))
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
|
|
|
|
-func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) (err error) {
|
|
|
|
|
|
+func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) {
|
|
|
|
+ b.err = b.run(ctx, requestURL, opts)
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
|
|
defer blobDownloadManager.Delete(b.Digest)
|
|
defer blobDownloadManager.Delete(b.Digest)
|
|
|
|
|
|
ctx, b.CancelFunc = context.WithCancel(ctx)
|
|
ctx, b.CancelFunc = context.WithCancel(ctx)
|
|
@@ -124,9 +144,8 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis
|
|
|
|
|
|
file.Truncate(b.Total)
|
|
file.Truncate(b.Total)
|
|
|
|
|
|
- g, ctx := errgroup.WithContext(ctx)
|
|
|
|
- // TODO(mxyng): download concurrency should be configurable
|
|
|
|
- g.SetLimit(64)
|
|
|
|
|
|
+ g, _ := errgroup.WithContext(ctx)
|
|
|
|
+ g.SetLimit(numDownloadParts)
|
|
for i := range b.Parts {
|
|
for i := range b.Parts {
|
|
part := b.Parts[i]
|
|
part := b.Parts[i]
|
|
if part.Completed == part.Size {
|
|
if part.Completed == part.Size {
|
|
@@ -168,7 +187,12 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
- return os.Rename(file.Name(), b.Name)
|
|
|
|
|
|
+ if err := os.Rename(file.Name(), b.Name); err != nil {
|
|
|
|
+ return err
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ b.done = true
|
|
|
|
+ return nil
|
|
}
|
|
}
|
|
|
|
|
|
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
|
|
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
|
|
@@ -267,11 +291,8 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
|
|
Completed: b.Completed.Load(),
|
|
Completed: b.Completed.Load(),
|
|
})
|
|
})
|
|
|
|
|
|
- if b.Completed.Load() >= b.Total {
|
|
|
|
- // wait for the file to get renamed
|
|
|
|
- if _, err := os.Stat(b.Name); err == nil {
|
|
|
|
- return nil
|
|
|
|
- }
|
|
|
|
|
|
+ if b.done || b.err != nil {
|
|
|
|
+ return b.err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|