Browse Source

dynamically size download parts based on file size

Michael Yang 1 year ago
parent
commit
630bb75d2a
1 changed files with 17 additions and 5 deletions
  1. 17 5
      server/download.go

+ 17 - 5
server/download.go

@@ -20,6 +20,7 @@ import (
 	"golang.org/x/sync/errgroup"
 
 	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/format"
 )
 
 var blobDownloadManager sync.Map
@@ -47,6 +48,12 @@ type blobDownloadPart struct {
 	*blobDownload `json:"-"`
 }
 
+const (
+	numDownloadParts          = 64
+	minDownloadPartSize int64 = 32 * 1000 * 1000
+	maxDownloadPartSize int64 = 256 * 1000 * 1000
+)
+
 func (p *blobDownloadPart) Name() string {
 	return strings.Join([]string{
 		p.blobDownload.Name, "partial", strconv.Itoa(p.N),
@@ -92,9 +99,15 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R
 
 		b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
 
-		var offset int64
-		var size int64 = 64 * 1024 * 1024
+		var size = b.Total / numDownloadParts
+		switch {
+		case size < minDownloadPartSize:
+			size = minDownloadPartSize
+		case size > maxDownloadPartSize:
+			size = maxDownloadPartSize
+		}
 
+		var offset int64
 		for offset < b.Total {
 			if offset+size > b.Total {
 				size = b.Total - offset
@@ -108,7 +121,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R
 		}
 	}
 
-	log.Printf("downloading %s in %d part(s)", b.Digest[7:19], len(b.Parts))
+	log.Printf("downloading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size))
 	return nil
 }
 
@@ -126,8 +139,7 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis
 	file.Truncate(b.Total)
 
 	g, _ := errgroup.WithContext(ctx)
-	// TODO(mxyng): download concurrency should be configurable
-	g.SetLimit(64)
+	g.SetLimit(numDownloadParts)
 	for i := range b.Parts {
 		part := b.Parts[i]
 		if part.Completed == part.Size {