Explorar o código

server/internal/client/ollama: fix file descriptor management in Pull (#9931)

Close chunked writers as soon as downloads complete, rather than
deferring closure until Pull exits. This prevents exhausting file
descriptors when pulling many layers.

Instead of unbounded defers, use a WaitGroup and background goroutine
to close each chunked writer as soon as its downloads finish.

Also rename 'total' to 'received' for clarity.
Blake Mizerany hai 1 mes
pai
achega
ce929984a3
Modificáronse 1 ficheiros con 31 adicións e 25 borrados
  1. 31 25
      server/internal/client/ollama/registry.go

+ 31 - 25
server/internal/client/ollama/registry.go

@@ -486,44 +486,43 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 		expected += l.Size
 	}
 
-	var total atomic.Int64
+	var received atomic.Int64
 	var g errgroup.Group
 	g.SetLimit(r.maxStreams())
 	for _, l := range layers {
 		info, err := c.Get(l.Digest)
 		if err == nil && info.Size == l.Size {
-			total.Add(l.Size)
+			received.Add(l.Size)
 			t.update(l, l.Size, ErrCached)
 			continue
 		}
 
+		var wg sync.WaitGroup
 		chunked, err := c.Chunked(l.Digest, l.Size)
 		if err != nil {
 			t.update(l, 0, err)
 			continue
 		}
-		// TODO(bmizerany): fix this unbounded use of defer
-		defer chunked.Close()
 
 		for cs, err := range r.chunksums(ctx, name, l) {
 			if err != nil {
-				// Chunksum stream was interrupted, so tell
-				// trace about it, and let in-flight chunk
-				// downloads finish. Once they finish, return
-				// ErrIncomplete, which is triggered by the
-				// fact that the total bytes received is less
-				// than the expected bytes.
+				// Chunksum stream interrupted. Note in trace
+				// log and let in-flight downloads complete.
+				// This will naturally trigger ErrIncomplete
+				// since received < expected bytes.
 				t.update(l, 0, err)
 				break
 			}
 
+			wg.Add(1)
 			g.Go(func() (err error) {
 				defer func() {
-					if err == nil || errors.Is(err, ErrCached) {
-						total.Add(cs.Chunk.Size())
+					if err == nil {
+						received.Add(cs.Chunk.Size())
 					} else {
 						err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
 					}
+					wg.Done()
 				}()
 
 				req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
@@ -537,27 +536,34 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 				}
 				defer res.Body.Close()
 
-				// Count bytes towards progress, as they
-				// arrive, so that our bytes piggyback other
-				// chunk updates on completion.
-				//
-				// This tactic is enough to show "smooth"
-				// progress given the current CLI client. In
-				// the near future, the server should report
-				// download rate since it knows better than a
-				// client that is measuring rate based on
-				// wall-clock time-since-last-update.
 				body := &trackingReader{l: l, r: res.Body, update: t.update}
-
 				return chunked.Put(cs.Chunk, cs.Digest, body)
 			})
 		}
+
+		// Close writer immediately after downloads finish, not at Pull
+		// exit. Using defer would keep file descriptors open until all
+		// layers complete, potentially exhausting system limits with
+		// many layers.
+		//
+		// The WaitGroup tracks when all chunks finish downloading,
+		// allowing precise writer closure in a background goroutine.
+		// Each layer briefly uses one extra goroutine while at most
+		// maxStreams()-1 chunks download in parallel.
+		//
+		// This caps file descriptors at maxStreams() instead of
+		// growing with layer count.
+		g.Go(func() error {
+			wg.Wait()
+			chunked.Close()
+			return nil
+		})
 	}
 	if err := g.Wait(); err != nil {
 		return err
 	}
-	if total.Load() != expected {
-		return fmt.Errorf("%w: received %d/%d", ErrIncomplete, total.Load(), expected)
+	if received.Load() != expected {
+		return fmt.Errorf("%w: received %d/%d", ErrIncomplete, received.Load(), expected)
 	}
 
 	md := blob.DigestFromBytes(m.Data)