浏览代码

Merge pull request #1916 from ollama/mxyng/inactivity-monitor

download: add inactivity monitor
Michael Yang 1 年之前
父节点
当前提交
6e0ea5ecc8
共有 1 个文件被更改,包括 74 次插入43 次删除
  1. 74 43
      server/download.go

+ 74 - 43
server/download.go

@@ -25,6 +25,11 @@ import (
 	"github.com/jmorganca/ollama/format"
 	"github.com/jmorganca/ollama/format"
 )
 )
 
 
+const maxRetries = 6
+
+var errMaxRetriesExceeded = errors.New("max retries exceeded")
+var errPartStalled = errors.New("part stalled")
+
 var blobDownloadManager sync.Map
 var blobDownloadManager sync.Map
 
 
 type blobDownload struct {
 type blobDownload struct {
@@ -44,10 +49,11 @@ type blobDownload struct {
 }
 }
 
 
 type blobDownloadPart struct {
 type blobDownloadPart struct {
-	N         int
-	Offset    int64
-	Size      int64
-	Completed int64
+	N           int
+	Offset      int64
+	Size        int64
+	Completed   int64
+	lastUpdated time.Time
 
 
 	*blobDownload `json:"-"`
 	*blobDownload `json:"-"`
 }
 }
@@ -72,6 +78,13 @@ func (p *blobDownloadPart) StopsAt() int64 {
 	return p.Offset + p.Size
 	return p.Offset + p.Size
 }
 }
 
 
+func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
+	n = len(b)
+	p.blobDownload.Completed.Add(int64(n))
+	p.lastUpdated = time.Now()
+	return n, nil
+}
+
 func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
 func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
 	partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
 	partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
 	if err != nil {
 	if err != nil {
@@ -157,6 +170,9 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
 				case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
 				case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
 					// return immediately if the context is canceled or the device is out of space
 					// return immediately if the context is canceled or the device is out of space
 					return err
 					return err
+				case errors.Is(err, errPartStalled):
+					try--
+					continue
 				case err != nil:
 				case err != nil:
 					sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
 					sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
 					slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
 					slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
@@ -195,28 +211,54 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
 }
 }
 
 
 func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
 func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
-	headers := make(http.Header)
-	headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
-	resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
-	if err != nil {
-		return err
-	}
-	defer resp.Body.Close()
+	g, ctx := errgroup.WithContext(ctx)
+	g.Go(func() error {
+		headers := make(http.Header)
+		headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
+		resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
+		if err != nil {
+			return err
+		}
+		defer resp.Body.Close()
 
 
-	n, err := io.Copy(w, io.TeeReader(resp.Body, b))
-	if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
-		// rollback progress
-		b.Completed.Add(-n)
-		return err
-	}
+		n, err := io.Copy(w, io.TeeReader(resp.Body, part))
+		if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
+			// rollback progress
+			b.Completed.Add(-n)
+			return err
+		}
 
 
-	part.Completed += n
-	if err := b.writePart(part.Name(), part); err != nil {
+		part.Completed += n
+		if err := b.writePart(part.Name(), part); err != nil {
+			return err
+		}
+
+		// return nil or context.Canceled or UnexpectedEOF (resumable)
 		return err
 		return err
-	}
+	})
+
+	g.Go(func() error {
+		ticker := time.NewTicker(time.Second)
+		for {
+			select {
+			case <-ticker.C:
+				if part.Completed >= part.Size {
+					return nil
+				}
+
+				if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second {
+					log.Printf("%s part %d stalled; retrying", b.Digest[7:19], part.N)
+					// reset last updated
+					part.lastUpdated = time.Time{}
+					return errPartStalled
+				}
+			case <-ctx.Done():
+				return ctx.Err()
+			}
+		}
+	})
 
 
-	// return nil or context.Canceled or UnexpectedEOF (resumable)
-	return err
+	return g.Wait()
 }
 }
 
 
 func (b *blobDownload) newPart(offset, size int64) error {
 func (b *blobDownload) newPart(offset, size int64) error {
@@ -255,12 +297,6 @@ func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error
 	return json.NewEncoder(partFile).Encode(part)
 	return json.NewEncoder(partFile).Encode(part)
 }
 }
 
 
-func (b *blobDownload) Write(p []byte) (n int, err error) {
-	n = len(p)
-	b.Completed.Add(int64(n))
-	return n, nil
-}
-
 func (b *blobDownload) acquire() {
 func (b *blobDownload) acquire() {
 	b.references.Add(1)
 	b.references.Add(1)
 }
 }
@@ -279,20 +315,19 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
 	for {
 	for {
 		select {
 		select {
 		case <-ticker.C:
 		case <-ticker.C:
+			fn(api.ProgressResponse{
+				Status:    fmt.Sprintf("pulling %s", b.Digest[7:19]),
+				Digest:    b.Digest,
+				Total:     b.Total,
+				Completed: b.Completed.Load(),
+			})
+
+			if b.done || b.err != nil {
+				return b.err
+			}
 		case <-ctx.Done():
 		case <-ctx.Done():
 			return ctx.Err()
 			return ctx.Err()
 		}
 		}
-
-		fn(api.ProgressResponse{
-			Status:    fmt.Sprintf("pulling %s", b.Digest[7:19]),
-			Digest:    b.Digest,
-			Total:     b.Total,
-			Completed: b.Completed.Load(),
-		})
-
-		if b.done || b.err != nil {
-			return b.err
-		}
 	}
 	}
 }
 }
 
 
@@ -303,10 +338,6 @@ type downloadOpts struct {
 	fn      func(api.ProgressResponse)
 	fn      func(api.ProgressResponse)
 }
 }
 
 
-const maxRetries = 6
-
-var errMaxRetriesExceeded = errors.New("max retries exceeded")
-
 // downloadBlob downloads a blob from the registry and stores it in the blobs directory
 // downloadBlob downloads a blob from the registry and stores it in the blobs directory
 func downloadBlob(ctx context.Context, opts downloadOpts) error {
 func downloadBlob(ctx context.Context, opts downloadOpts) error {
 	fp, err := GetBlobsPath(opts.digest)
 	fp, err := GetBlobsPath(opts.digest)