
server/internal/client/ollama: persist through chunk download errors (#9923)

Blake Mizerany 1 month ago
parent
commit
c794fef2f2

+ 40 - 21
server/internal/client/ollama/registry.go

@@ -59,6 +59,11 @@ var (
 	// ErrCached is passed to [Trace.PushUpdate] when a layer already
 	// exists. It is a non-fatal error and is never returned by [Registry.Push].
 	ErrCached = errors.New("cached")
+
+	// ErrIncomplete is returned by [Registry.Pull] when a model pull was
+	// incomplete due to one or more layer download failures. Users that
+	// want specific errors should use [WithTrace].
+	ErrIncomplete = errors.New("incomplete")
 )
 
 // Defaults
@@ -271,8 +276,19 @@ func DefaultRegistry() (*Registry, error) {
 
 func UserAgent() string {
 	buildinfo, _ := debug.ReadBuildInfo()
+
+	version := buildinfo.Main.Version
+	if version == "(devel)" {
+		// When using `go run .` the version is "(devel)". This is seen
+		// as an invalid version by ollama.com and so it defaults to
+		// "needs upgrade" for some requests, such as pulls. These
+		// checks can be skipped by using the special version "v0.0.0",
+		// so we set it to that here.
+		version = "v0.0.0"
+	}
+
 	return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
-		buildinfo.Main.Version,
+		version,
 		runtime.GOARCH,
 		runtime.GOOS,
 		runtime.Version(),
@@ -418,13 +434,14 @@ func canRetry(err error) bool {
 //
 // It always calls update with a nil error.
 type trackingReader struct {
-	r io.Reader
-	n *atomic.Int64
+	l      *Layer
+	r      io.Reader
+	update func(l *Layer, n int64, err error)
 }
 
 func (r *trackingReader) Read(p []byte) (n int, err error) {
 	n, err = r.r.Read(p)
-	r.n.Add(int64(n))
+	r.update(r.l, int64(n), nil)
 	return
 }
 
@@ -462,16 +479,20 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 
 	// Send initial layer trace events to allow clients to have an
 	// understanding of work to be done before work starts.
+	var expected int64
 	t := traceFromContext(ctx)
 	for _, l := range layers {
 		t.update(l, 0, nil)
+		expected += l.Size
 	}
 
+	var total atomic.Int64
 	var g errgroup.Group
 	g.SetLimit(r.maxStreams())
 	for _, l := range layers {
 		info, err := c.Get(l.Digest)
 		if err == nil && info.Size == l.Size {
+			total.Add(l.Size)
 			t.update(l, l.Size, ErrCached)
 			continue
 		}
@@ -484,21 +505,25 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 		// TODO(bmizerany): fix this unbounded use of defer
 		defer chunked.Close()
 
-		var progress atomic.Int64
 		for cs, err := range r.chunksums(ctx, name, l) {
 			if err != nil {
-				// Bad chunksums response, update tracing
-				// clients and then bail.
-				t.update(l, progress.Load(), err)
-				return err
+				// The chunksum stream was interrupted, so tell
+				// the trace about it and let in-flight chunk
+				// downloads finish. Once they finish, Pull
+				// returns ErrIncomplete because the total
+				// bytes received will be less than the
+				// expected total.
+				t.update(l, 0, err)
+				break
 			}
 
 			g.Go(func() (err error) {
 				defer func() {
-					if err != nil {
+					if err == nil || errors.Is(err, ErrCached) {
+						total.Add(cs.Chunk.Size())
+					} else {
 						err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
 					}
-					t.update(l, progress.Load(), err)
 				}()
 
 				req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
@@ -522,7 +547,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 				// download rate since it knows better than a
 				// client that is measuring rate based on
 				// wall-clock time-since-last-update.
-				body := &trackingReader{r: res.Body, n: &progress}
+				body := &trackingReader{l: l, r: res.Body, update: t.update}
 
 				return chunked.Put(cs.Chunk, cs.Digest, body)
 			})
@@ -531,6 +556,9 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 	if err := g.Wait(); err != nil {
 		return err
 	}
+	if total.Load() != expected {
+		return fmt.Errorf("%w: received %d/%d", ErrIncomplete, total.Load(), expected)
+	}
 
 	md := blob.DigestFromBytes(m.Data)
 	if err := blob.PutBytes(c, md, m.Data); err != nil {
@@ -757,15 +785,12 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
 		}
 		blobURL := res.Header.Get("Content-Location")
 
-		var size int64
 		s := bufio.NewScanner(res.Body)
 		s.Split(bufio.ScanWords)
 		for {
 			if !s.Scan() {
 				if s.Err() != nil {
 					yield(chunksum{}, s.Err())
-				} else if size != l.Size {
-					yield(chunksum{}, fmt.Errorf("size mismatch: layer size %d != sum of chunks %d", size, l.Size))
 				}
 				return
 			}
@@ -789,12 +814,6 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
 				return
 			}
 
-			size += chunk.Size()
-			if size > l.Size {
-				yield(chunksum{}, fmt.Errorf("chunk size %d exceeds layer size %d", size, l.Size))
-				return
-			}
-
 			cs := chunksum{
 				URL:    blobURL,
 				Chunk:  chunk,
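
The new ErrIncomplete sentinel lets callers distinguish "the pull finished but some chunks never arrived" from hard failures, and since completed layers are detected through the local cache (the ErrCached path above), a follow-up Pull only has to fetch what is still missing. A minimal caller-side sketch for code living inside the ollama module (the client package is internal); the helper name and the attempt count are illustrative, not part of this commit:

package pullsketch

import (
	"context"
	"errors"

	"github.com/ollama/ollama/server/internal/client/ollama"
)

// pullWithRetry is a hypothetical helper that retries Registry.Pull while the
// pull keeps ending with ErrIncomplete. Layers that already completed should
// be found in the cache on the next attempt (reported to the trace as
// ErrCached), so only the missing data is downloaded again.
func pullWithRetry(ctx context.Context, r *ollama.Registry, name string, attempts int) error {
	var err error
	for i := 0; i < attempts; i++ {
		err = r.Pull(ctx, name)
		if err == nil || !errors.Is(err, ollama.ErrIncomplete) {
			return err
		}
	}
	return err
}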

+ 29 - 2
server/internal/client/ollama/registry_test.go

@@ -25,6 +25,28 @@ import (
 	"github.com/ollama/ollama/server/internal/testutil"
 )
 
+func ExampleRegistry_cancelOnFirstError() {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	ctx = WithTrace(ctx, &Trace{
+		Update: func(l *Layer, n int64, err error) {
+			if err != nil {
+				// Discontinue pulling layers if there is an
+				// error instead of continuing to pull more
+				// data.
+				cancel()
+			}
+		},
+	})
+
+	var r Registry
+	if err := r.Pull(ctx, "model"); err != nil {
+		// panic for demo purposes
+		panic(err)
+	}
+}
+
 func TestManifestMarshalJSON(t *testing.T) {
 	// All manifests should contain an "empty" config object.
 	var m Manifest
@@ -813,8 +835,13 @@ func TestPullChunksums(t *testing.T) {
 	)
 	err := rc.Pull(ctx, "test")
 	check(err)
-	if !slices.Equal(reads, []int64{0, 3, 5}) {
-		t.Errorf("reads = %v; want %v", reads, []int64{0, 3, 5})
+	wantReads := []int64{
+		0, // initial signaling of layer pull starting
+		3, // first chunk read
+		2, // second chunk read
+	}
+	if !slices.Equal(reads, wantReads) {
+		t.Errorf("reads = %v; want %v", reads, wantReads)
 	}
 
 	mw, err := rc.Resolve(t.Context(), "test")
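
The updated wantReads values capture the behavioral change in trackingReader: Trace.Update now reports the byte count of each individual read (0 at registration, then 3 and 2 for the two chunks) instead of a running total (0, 3, 5), so trace consumers are expected to accumulate. A small sketch of a consumer under the new delta semantics, with the same import-path assumption as the sketch above; the progress map and helper are illustrative:

package pullsketch

import (
	"context"
	"sync"

	"github.com/ollama/ollama/server/internal/client/ollama"
)

// withProgress returns a context whose Trace accumulates per-read deltas into
// a per-layer byte count, mirroring the `progress[l] += n` change in
// server/internal/registry/server.go below.
func withProgress(ctx context.Context, mu *sync.Mutex, progress map[*ollama.Layer]int64) context.Context {
	return ollama.WithTrace(ctx, &ollama.Trace{
		Update: func(l *ollama.Layer, n int64, err error) {
			mu.Lock()
			progress[l] += n // n is now a delta, not a cumulative total
			mu.Unlock()
		},
	})
}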

+ 20 - 13
server/internal/registry/server.go

@@ -200,7 +200,7 @@ type params struct {
 	//
 	// Unfortunately, this API was designed to be a bit awkward. Stream is
 	// defined to default to true if not present, so we need a way to check
-	// if the client decisively it to false. So, we use a pointer to a
+	// if the client decisively set it to false. So, we use a pointer to a
 	// bool. Gross.
 	//
 	// Use [stream()] to get the correct value for this field.
@@ -280,17 +280,17 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
 	progress := make(map[*ollama.Layer]int64)
 
 	progressCopy := make(map[*ollama.Layer]int64, len(progress))
-	pushUpdate := func() {
+	flushProgress := func() {
 		defer maybeFlush()
 
-		// TODO(bmizerany): This scales poorly with more layers due to
-		// needing to flush out them all in one big update. We _could_
-		// just flush on the changed ones, or just track the whole
-		// download. Needs more thought. This is fine for now.
+		// TODO(bmizerany): Flushing every layer in one update doesn't
+		// scale well. We could flush only the modified layers or track
+		// the full download. Needs further consideration, though it's
+		// fine for now.
 		mu.Lock()
 		maps.Copy(progressCopy, progress)
 		mu.Unlock()
-		for l, n := range progress {
+		for l, n := range progressCopy {
 			enc.Encode(progressUpdateJSON{
 				Digest:    l.Digest,
 				Total:     l.Size,
@@ -298,19 +298,26 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
 			})
 		}
 	}
+	defer flushProgress()
 
-	t := time.NewTicker(time.Hour) // "unstarted" timer
+	t := time.NewTicker(1000 * time.Hour) // "unstarted" timer
 	start := sync.OnceFunc(func() {
-		pushUpdate()
+		flushProgress() // flush initial state
 		t.Reset(100 * time.Millisecond)
 	})
 	ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
 		Update: func(l *ollama.Layer, n int64, err error) {
 			if n > 0 {
-				start() // flush initial state
+				// Block flushing progress updates until every
+				// layer is accounted for. Clients depend on a
+				// complete model size to calculate progress
+				// correctly; if they use an incomplete total,
+				// progress indicators would erratically jump
+				// as new layers are registered.
+				start()
 			}
 			mu.Lock()
-			progress[l] = n
+			progress[l] += n
 			mu.Unlock()
 		},
 	})
@@ -323,9 +330,9 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
 	for {
 		select {
 		case <-t.C:
-			pushUpdate()
+			flushProgress()
 		case err := <-done:
-			pushUpdate()
+			flushProgress()
 			if err != nil {
 				var status string
 				if errors.Is(err, ollama.ErrModelNotFound) {
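
On the server side, handlePull now parks the flush ticker on an effectively infinite interval and only resets it to its real 100ms cadence via sync.OnceFunc once the first byte of progress arrives, so clients see a complete set of layers (and a correct total size) before periodic progress updates start flowing. A self-contained sketch of that scheduling pattern in isolation, assuming Go 1.21+ for sync.OnceFunc; the flush function, the simulated download, and the sleep duration are placeholders:

package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	flush := func() { fmt.Println("flush progress") } // stands in for flushProgress

	// Parked ticker: it will not fire until Reset is called.
	t := time.NewTicker(1000 * time.Hour) // "unstarted" timer
	defer t.Stop()
	start := sync.OnceFunc(func() {
		flush()                         // flush initial state once real progress exists
		t.Reset(100 * time.Millisecond) // then flush on a fixed cadence
	})

	done := make(chan error, 1)
	go func() {
		start() // simulates the first Trace.Update call with n > 0
		time.Sleep(350 * time.Millisecond)
		done <- nil // simulates the pull finishing
	}()

	for {
		select {
		case <-t.C:
			flush()
		case err := <-done:
			flush() // final flush so the client sees the finished state
			if err != nil {
				fmt.Println("pull failed:", err)
			}
			return
		}
	}
}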