Pārlūkot izejas kodu

Merge pull request #2221 from ollama/mxyng/up-down-ccy

adjust download and upload concurrency based on available bandwidth
Michael Yang 1 gadu atpakaļ
vecāks
revīzija
2e20110e50
2 mainītis faili ar 106 papildinājumiem un 17 dzēšanām
  1. 100 13
      server/download.go
  2. 6 4
      server/upload.go

+ 100 - 13
server/download.go

@@ -20,6 +20,7 @@ import (
 	"time"
 	"time"
 
 
 	"golang.org/x/sync/errgroup"
 	"golang.org/x/sync/errgroup"
+	"golang.org/x/sync/semaphore"
 
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/format"
 	"github.com/jmorganca/ollama/format"
@@ -138,30 +139,29 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
 }
 }
 
 
 func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) {
 func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) {
-	b.err = b.run(ctx, requestURL, opts)
-}
-
-func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
 	defer blobDownloadManager.Delete(b.Digest)
 	defer blobDownloadManager.Delete(b.Digest)
 	ctx, b.CancelFunc = context.WithCancel(ctx)
 	ctx, b.CancelFunc = context.WithCancel(ctx)
 
 
 	file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0o644)
 	file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0o644)
 	if err != nil {
 	if err != nil {
-		return err
+		b.err = err
+		return
 	}
 	}
 	defer file.Close()
 	defer file.Close()
 
 
 	_ = file.Truncate(b.Total)
 	_ = file.Truncate(b.Total)
 
 
-	g, inner := errgroup.WithContext(ctx)
-	g.SetLimit(numDownloadParts)
+	var limit int64 = 2
+	g, inner := NewLimitGroup(ctx, numDownloadParts, limit)
+	go watchDelta(inner, g, &b.Completed, limit)
+
 	for i := range b.Parts {
 	for i := range b.Parts {
 		part := b.Parts[i]
 		part := b.Parts[i]
 		if part.Completed == part.Size {
 		if part.Completed == part.Size {
 			continue
 			continue
 		}
 		}
 
 
-		g.Go(func() error {
+		g.Go(inner, func() error {
 			var err error
 			var err error
 			for try := 0; try < maxRetries; try++ {
 			for try := 0; try < maxRetries; try++ {
 				w := io.NewOffsetWriter(file, part.StartsAt())
 				w := io.NewOffsetWriter(file, part.StartsAt())
@@ -188,26 +188,29 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 	}
 	}
 
 
 	if err := g.Wait(); err != nil {
 	if err := g.Wait(); err != nil {
-		return err
+		b.err = err
+		return
 	}
 	}
 
 
 	// explicitly close the file so we can rename it
 	// explicitly close the file so we can rename it
 	if err := file.Close(); err != nil {
 	if err := file.Close(); err != nil {
-		return err
+		b.err = err
+		return
 	}
 	}
 
 
 	for i := range b.Parts {
 	for i := range b.Parts {
 		if err := os.Remove(file.Name() + "-" + strconv.Itoa(i)); err != nil {
 		if err := os.Remove(file.Name() + "-" + strconv.Itoa(i)); err != nil {
-			return err
+			b.err = err
+			return
 		}
 		}
 	}
 	}
 
 
 	if err := os.Rename(file.Name(), b.Name); err != nil {
 	if err := os.Rename(file.Name(), b.Name); err != nil {
-		return err
+		b.err = err
+		return
 	}
 	}
 
 
 	b.done = true
 	b.done = true
-	return nil
 }
 }
 
 
 func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *registryOptions) error {
 func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *registryOptions) error {
@@ -377,3 +380,87 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
 
 
 	return download.Wait(ctx, opts.fn)
 	return download.Wait(ctx, opts.fn)
 }
 }
+
+type LimitGroup struct {
+	*errgroup.Group
+	*semaphore.Weighted
+	size, limit int64
+}
+
+func NewLimitGroup(ctx context.Context, size, limit int64) (*LimitGroup, context.Context) {
+	g, ctx := errgroup.WithContext(ctx)
+	return &LimitGroup{
+		Group:    g,
+		Weighted: semaphore.NewWeighted(size),
+		size:     size,
+		limit:    limit,
+	}, ctx
+}
+
+func (g *LimitGroup) Go(ctx context.Context, fn func() error) {
+	var weight int64 = 1
+	if g.limit > 0 {
+		weight = g.size / g.limit
+	}
+
+	_ = g.Acquire(ctx, weight)
+	if ctx.Err() != nil {
+		return
+	}
+
+	g.Group.Go(func() error {
+		defer g.Release(weight)
+		return fn()
+	})
+}
+
+func (g *LimitGroup) SetLimit(limit int64) {
+	if limit > g.limit {
+		g.limit = limit
+	}
+}
+
+func watchDelta(ctx context.Context, g *LimitGroup, c *atomic.Int64, limit int64) {
+	var maxDelta float64
+	var buckets []int64
+
+	// 5s ramp up period
+	nextUpdate := time.Now().Add(5 * time.Second)
+
+	ticker := time.NewTicker(time.Second)
+	for {
+		select {
+		case <-ticker.C:
+			buckets = append(buckets, c.Load())
+			if len(buckets) < 2 {
+				continue
+			} else if len(buckets) > 10 {
+				buckets = buckets[1:]
+			}
+
+			delta := float64((buckets[len(buckets)-1] - buckets[0])) / float64(len(buckets))
+			slog.Debug("", "limit", limit, "delta", format.HumanBytes(int64(delta)), "max_delta", format.HumanBytes(int64(maxDelta)))
+
+			if time.Now().Before(nextUpdate) {
+				// quiet period; do not update ccy if recently updated
+				continue
+			} else if maxDelta > 0 {
+				x := delta / maxDelta
+				if x < 1.2 {
+					continue
+				}
+
+				limit += int64(x)
+				slog.Debug("setting", "limit", limit)
+				g.SetLimit(limit)
+			}
+
+			// 3s cooldown period
+			nextUpdate = time.Now().Add(3 * time.Second)
+			maxDelta = delta
+
+		case <-ctx.Done():
+			return
+		}
+	}
+}

+ 6 - 4
server/upload.go

@@ -18,7 +18,6 @@ import (
 
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/format"
 	"github.com/jmorganca/ollama/format"
-	"golang.org/x/sync/errgroup"
 )
 )
 
 
 var blobUploadManager sync.Map
 var blobUploadManager sync.Map
@@ -137,14 +136,17 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
 	}
 	}
 	defer b.file.Close()
 	defer b.file.Close()
 
 
-	g, inner := errgroup.WithContext(ctx)
-	g.SetLimit(numUploadParts)
+	var limit int64 = 2
+	g, inner := NewLimitGroup(ctx, numUploadParts, limit)
+	go watchDelta(inner, g, &b.Completed, limit)
+
 	for i := range b.Parts {
 	for i := range b.Parts {
 		part := &b.Parts[i]
 		part := &b.Parts[i]
 		select {
 		select {
 		case <-inner.Done():
 		case <-inner.Done():
+			break
 		case requestURL := <-b.nextURL:
 		case requestURL := <-b.nextURL:
-			g.Go(func() error {
+			g.Go(inner, func() error {
 				var err error
 				var err error
 				for try := 0; try < maxRetries; try++ {
 				for try := 0; try < maxRetries; try++ {
 					err = b.uploadPart(inner, http.MethodPatch, requestURL, part, opts)
 					err = b.uploadPart(inner, http.MethodPatch, requestURL, part, opts)