server/internal/client/ollama: cache completed chunks

This change adds tracking of download chunks during the pull process so
that subsequent pulls can skip downloading already completed chunks.
This works across restarts of ollama.

Currently, download state will be lost if a prune is triggered during a
pull (e.g. restart or remove). This issue should be addressed in a
follow-up PR.
Blake Mizerany, 1 month ago
commit cf440b49d7
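
In rough terms, each completed chunk gets a small checkpoint entry in the blob cache, keyed by the layer digest, chunk digest, and byte range; the next pull looks the key up before issuing the range request. A minimal sketch of that flow, assuming a hypothetical downloadChunk helper and illustrative parameters (the key format and the blob.DigestFromBytes / Get / PutBytes calls mirror the diff below):

package ollama

import (
	"fmt"

	"github.com/ollama/ollama/server/internal/cache/blob"
)

// pullChunkWithCheckpoint skips the download when the chunk's checkpoint key
// is already in the cache, and records the key only after a successful
// download. Illustrative only; not the code from the commit.
func pullChunkWithCheckpoint(c *blob.DiskCache, layer, chunk blob.Digest, start, end int64, downloadChunk func() error) error {
	key := fmt.Sprintf("v1 pull chunksum %s %s %d-%d", layer, chunk, start, end)
	keyDigest := blob.DigestFromBytes(key)

	if _, err := c.Get(keyDigest); err == nil {
		return nil // chunk completed on a previous pull; skip the range request
	}
	if err := downloadChunk(); err != nil {
		return err
	}
	// Best-effort checkpoint: a failed write only costs a redownload next time.
	_ = blob.PutBytes(c, keyDigest, key)
	return nil
}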

+ 53 - 2
server/internal/client/ollama/registry.go

@@ -514,13 +514,40 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 				break
 			}
 
+			cacheKey := fmt.Sprintf(
+				"v1 pull chunksum %s %s %d-%d",
+				l.Digest,
+				cs.Digest,
+				cs.Chunk.Start,
+				cs.Chunk.End,
+			)
+			cacheKeyDigest := blob.DigestFromBytes(cacheKey)
+			_, err := c.Get(cacheKeyDigest)
+			if err == nil {
+				received.Add(cs.Chunk.Size())
+				t.update(l, cs.Chunk.Size(), ErrCached)
+				continue
+			}
+
 			wg.Add(1)
 			g.Go(func() (err error) {
 				defer func() {
 					if err == nil {
+						// Ignore cache key write errors for now. We've already
+						// reported to trace that the chunk is complete.
+						//
+						// Ideally, we should only report completion to trace
+						// after successful cache commit. This current approach
+						// works but could trigger unnecessary redownloads if
+						// the checkpoint key is missing on next pull.
+						//
+						// Not incorrect, just suboptimal - fix this in a
+						// future update.
+						_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
+
 						received.Add(cs.Chunk.Size())
 					} else {
-						err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
+						t.update(l, 0, err)
 					}
 					wg.Done()
 				}()
@@ -563,7 +590,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 		return err
 	}
 	if received.Load() != expected {
-		return fmt.Errorf("%w: received %d/%d", ErrIncomplete, received.Load(), expected)
+		return fmt.Errorf("%w: received %d/%d bytes", ErrIncomplete, received.Load(), expected)
 	}
 
 	md := blob.DigestFromBytes(m.Data)
@@ -608,6 +635,30 @@ func (m *Manifest) Layer(d blob.Digest) *Layer {
 	return nil
 }
 
+func (m *Manifest) All() iter.Seq[*Layer] {
+	return func(yield func(*Layer) bool) {
+		if !yield(m.Config) {
+			return
+		}
+		for _, l := range m.Layers {
+			if !yield(l) {
+				return
+			}
+		}
+	}
+}
+
+func (m *Manifest) Size() int64 {
+	var size int64
+	if m.Config != nil {
+		size += m.Config.Size
+	}
+	for _, l := range m.Layers {
+		size += l.Size
+	}
+	return size
+}
+
 // MarshalJSON implements json.Marshaler.
 //
 // NOTE: It adds an empty config object to the manifest, which is required by
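
The new All iterator yields the config layer first and then each content layer, so callers can do per-blob bookkeeping with a single range loop, and Size reports the manifest's total byte count. A rough usage sketch, assuming the ollama package context and a hypothetical countCached helper (not part of the change):

// countCached reports how many of the manifest's blobs are already present
// in the local cache. Illustrative only.
func countCached(c *blob.DiskCache, m *Manifest) (cached int) {
	for l := range m.All() {
		if l == nil {
			continue // Config may be nil on a bare manifest
		}
		if _, err := c.Get(l.Digest); err == nil {
			cached++ // blob already present locally
		}
	}
	return cached
}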

+ 475 - 0
server/internal/client/ollama/registry_synctest_test.go

@@ -0,0 +1,475 @@
+//go:build goexperiment.synctest
+
+package ollama
+
+import (
+	"bufio"
+	"cmp"
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"io/fs"
+	"net"
+	"net/http"
+	"os"
+	"strings"
+	"sync/atomic"
+	"testing"
+	"testing/synctest"
+
+	"github.com/ollama/ollama/server/internal/cache/blob"
+)
+
+func newHTTPClient(cn net.Conn) *http.Client {
+	return &http.Client{
+		Transport: &http.Transport{
+			DialContext: func(context.Context, string, string) (net.Conn, error) {
+				return cn, nil
+			},
+		},
+	}
+}
+
+type clientTester struct {
+	t          *testing.T
+	rc         *Registry
+	sc         net.Conn
+	br         *bufio.Reader
+	inProgress atomic.Int64
+}
+
+// newClientTester creates a clientTester with a new pipe connection. If the
+// provided cache is nil, a new cache is created with t.TempDir().
+func newClientTester(t *testing.T, c *blob.DiskCache) *clientTester {
+	t.Helper()
+
+	cc, sc := net.Pipe()
+
+	if c == nil {
+		var err error
+		c, err = blob.Open(t.TempDir())
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	return &clientTester{
+		t: t,
+		rc: &Registry{
+			Cache:             c,
+			ChunkingThreshold: 2, // set low for ease of testing
+			HTTPClient:        newHTTPClient(cc),
+		},
+		sc: sc,
+		br: bufio.NewReader(sc),
+	}
+}
+
+func (ct *clientTester) setMaxStreams(n int) {
+	ct.rc.MaxStreams = n
+}
+
+func (ct *clientTester) close() {
+	if err := ct.sc.Close(); err != nil {
+		ct.t.Fatal("error closing conn:", err)
+	}
+}
+
+func (ct *clientTester) running() bool {
+	return ct.inProgress.Load() > 0
+}
+
+// pull starts a pull.
+// It tracks the number of in-progress pulls for use with [running].
+// It prefixes the name with "http://example.com/".
+// It does not wait for the pull to complete.
+func (ct *clientTester) pull(ctx context.Context, name string) error {
+	ct.inProgress.Add(1)
+	defer ct.inProgress.Add(-1)
+	return ct.rc.Pull(ctx, fmt.Sprintf("http://example.com/%s", name))
+}
+
+// await reads the next request from the clientTester's bufio.Reader and returns
+// it.
+// If wantPath is not empty, it checks that the request's URL.Path matches
+// wantPath.
+func (ct *clientTester) await(wantPath string) *http.Request {
+	ct.t.Helper()
+	req, err := http.ReadRequest(ct.br)
+	if err != nil {
+		ct.t.Fatal("error reading request:", err)
+	}
+	if wantPath != "" && req.URL.Path != wantPath {
+		ct.t.Fatalf("request = %v; want %v", req.URL.Path, wantPath)
+	}
+	return req
+}
+
+func (ct *clientTester) respond(code int, body string) {
+	ct.t.Helper()
+	err := (&http.Response{
+		ProtoMajor: 1,
+		ProtoMinor: 1,
+
+		StatusCode: code,
+
+		ContentLength: int64(len(body)),
+		Body:          io.NopCloser(strings.NewReader(body)),
+	}).Write(ct.sc)
+	if err != nil {
+		ct.t.Fatal("error writing response:", err)
+	}
+}
+
+type stringReadCloser struct {
+	length int
+	io.Reader
+}
+
+func (r *stringReadCloser) Close() error { return nil }
+
+func stringBody(format string, args ...any) io.ReadCloser {
+	s := fmt.Sprintf(format, args...)
+	return &stringReadCloser{len(s), strings.NewReader(s)}
+}
+
+func (ct *clientTester) respondWith(res *http.Response) {
+	ct.t.Helper()
+	if b, ok := res.Body.(*stringReadCloser); ok {
+		res.ContentLength = int64(b.length)
+	}
+	if res.Body != nil && res.Body != http.NoBody && res.ContentLength == 0 {
+		panic("response with Body must have ContentLength")
+	}
+	res.ProtoMajor = cmp.Or(res.ProtoMajor, 1)
+	res.ProtoMinor = cmp.Or(res.ProtoMinor, 1)
+	res.StatusCode = cmp.Or(res.StatusCode, 200)
+	err := res.Write(ct.sc)
+	if err != nil {
+		ct.t.Fatal("error writing response:", err)
+	}
+}
+
+func checkBlob(t *testing.T, c *blob.DiskCache, d blob.Digest, content string) {
+	t.Helper()
+	info, err := c.Get(d)
+	if err != nil {
+		t.Fatalf("Get(%v) = %v", d, err)
+	}
+	if int(info.Size) != len(content) {
+		t.Errorf("info.Size = %v; want %d", info.Size, len(content))
+	}
+	data, err := os.ReadFile(c.GetFile(d))
+	if err != nil {
+		t.Fatalf("ReadFile = %v", err)
+	}
+	if string(data) != content {
+		t.Errorf("data = %q; want %q", data, content)
+	}
+}
+
+func TestPull(t *testing.T) {
+	t.Run("single", func(t *testing.T) {
+		synctest.Run(func() {
+			ctx := context.Background()
+			ctx = WithTrace(ctx, &Trace{
+				Update: func(l *Layer, n int64, err error) {
+					if errors.Is(err, ErrCached) {
+						t.Errorf("unexpected ErrCached for %v", l.Digest)
+					}
+				},
+			})
+
+			c, err := blob.Open(t.TempDir())
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			ct := newClientTester(t, c)
+			defer ct.close()
+
+			go func() {
+				err := ct.pull(ctx, "library/abc")
+				if err != nil {
+					t.Errorf("pull = %v", err)
+				}
+			}()
+
+			const content = "a" // below chunking threshold
+			sendManifest := func() {
+				ct.respond(200, fmt.Sprintf(`{"layers":[{"digest":%q,"size":%d}]}`, blob.DigestFromBytes(content), len(content)))
+			}
+
+			d := blob.DigestFromBytes(content)
+			ct.await("/v2/library/abc/manifests/latest")
+			sendManifest()
+			synctest.Wait()
+			if !ct.running() {
+				t.Error("pull is not running")
+			}
+
+			// cache should be empty
+			_, err = c.Get(d)
+			if !errors.Is(err, fs.ErrNotExist) {
+				t.Fatalf("Get(%v) = %v; want fs.ErrNotExist", d, err)
+			}
+
+			// blob request/response
+			ct.await("/v2/library/abc/blobs/" + d.String())
+			ct.respond(200, content)
+			synctest.Wait()
+			if ct.running() {
+				t.Error("pull is still running")
+			}
+			checkBlob(t, c, d, content)
+			_, err = c.Resolve("example.com/library/abc:latest")
+			if err != nil {
+				t.Errorf("expected manifest to be linked: %v", err)
+			}
+
+			// repull should be cached
+			ctx = WithTrace(ctx, &Trace{
+				Update: func(l *Layer, n int64, err error) {
+					if n > 0 && !errors.Is(err, ErrCached) {
+						t.Errorf("unexpected error: %v", err)
+					}
+				},
+			})
+
+			go func() {
+				err := ct.pull(ctx, "library/abc")
+				if err != nil {
+					t.Errorf("pull = %v", err)
+				}
+			}()
+
+			ct.await("/v2/library/abc/manifests/latest")
+			sendManifest()
+			synctest.Wait()
+			if ct.running() {
+				t.Error("pull is still running")
+			}
+		})
+
+		t.Run("chunked", func(t *testing.T) {
+			synctest.Run(func() {
+				c, err := blob.Open(t.TempDir())
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				ct := newClientTester(t, c)
+				defer ct.close()
+
+				ctx := WithTrace(t.Context(), &Trace{
+					Update: func(l *Layer, n int64, err error) {
+						if err != nil {
+							t.Errorf("unexpected error: %v", err)
+						}
+					},
+				})
+
+				go func() {
+					ct.setMaxStreams(1)
+					err := ct.pull(ctx, "library/abc")
+					if err != nil {
+						t.Errorf("pull = %v", err)
+					}
+				}()
+
+				const content = "abc" // above chunking threshold
+
+				d := blob.DigestFromBytes(content)
+				sendManifest := func() {
+					ct.respond(200, fmt.Sprintf(`{"layers":[{"digest":%q,"size":%d}]}`, d, len(content)))
+				}
+				ct.await("/v2/library/abc/manifests/latest")
+				sendManifest()
+				ct.await("/v2/library/abc/chunksums/" + d.String())
+
+				s0 := blob.DigestFromBytes("ab")
+				s1 := blob.DigestFromBytes("c")
+				ct.respondWith(&http.Response{
+					Header: http.Header{
+						"Content-Location": []string{"http://example.com/v2/library/abc/blobs/" + d.String()},
+					},
+					Body: stringBody(`
+						%s 0-1
+						%s 2-2
+					`, s0, s1),
+				})
+
+				for i := range 2 {
+					t.Logf("checking range request %d", i)
+
+					req := ct.await("/v2/library/abc/blobs/" + d.String())
+					switch rng := req.Header.Get("Range"); rng {
+					case "bytes=0-1":
+						ct.respond(200, "ab")
+					case "bytes=2-2":
+						ct.respond(200, "c")
+					default:
+						t.Errorf("unexpected range: %q", rng)
+					}
+				}
+
+				synctest.Wait()
+				if ct.running() {
+					t.Error("pull is still running")
+				}
+				checkBlob(t, ct.rc.Cache, d, content)
+				_, err = c.Resolve("example.com/library/abc:latest")
+				if err != nil {
+					t.Errorf("expected manifest to be linked: %v", err)
+				}
+			})
+		})
+	})
+
+	t.Run("errors", func(t *testing.T) {
+		synctest.Run(func() {
+			c, err := blob.Open(t.TempDir())
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			ct := newClientTester(t, c)
+			defer ct.close()
+
+			type update struct {
+				l   *Layer
+				n   int64
+				err error
+			}
+
+			var got strings.Builder
+			ctx := WithTrace(t.Context(), &Trace{
+				Update: func(l *Layer, n int64, err error) {
+					fmt.Fprintf(&got, "%v %d %v\n", l.Digest.Short(), n, err)
+				},
+			})
+
+			go func() {
+				ct.setMaxStreams(1)
+				err := ct.pull(ctx, "library/abc")
+				if err != nil {
+					t.Errorf("pull = %v", err)
+				}
+			}()
+
+			// makeManifest makes a single layer manifest using
+			// content and returns the digest of the content, and
+			// the content of the manifest.
+			makeManifest := func(content string) (blob.Digest, string) {
+				d := blob.DigestFromBytes(content)
+				return d, fmt.Sprintf(`{"layers":[{"digest":%q,"size":%d}]}`, d, len(content))
+			}
+
+			ct.await("/v2/library/abc/manifests/latest")
+			d, man := makeManifest("a")
+			ct.respond(200, man)
+			ct.await("/v2/library/abc/blobs/" + d.String())
+			ct.respond(200, "a")
+			synctest.Wait()
+			if ct.running() {
+				t.Error("pull is still running")
+			}
+			var want strings.Builder
+			want.WriteString(d.Short() + " 0 <nil>\n") // initial announcement
+			want.WriteString(d.Short() + " 1 <nil>\n") // final
+			if got.String() != want.String() {
+				t.Errorf("\ngot:\n%s\nwant:\n%s", got.String(), want.String())
+			}
+
+			// error on manifest fetch
+			got.Reset()
+			done := make(chan error)
+			go func() { done <- ct.pull(ctx, "library/abc") }()
+			ct.await("/v2/library/abc/manifests/latest")
+			ct.respond(400, `some error`)
+			synctest.Wait()
+			if ct.running() {
+				t.Error("pull is still running")
+			}
+			err = <-done
+			if err == nil || !strings.Contains(err.Error(), "some error") {
+				t.Errorf("err = %v; want some error", err)
+			}
+			if got.String() != "" {
+				t.Errorf("\nunexpected traces:\n%s", got.String())
+			}
+
+			// error on blob fetch
+			got.Reset()
+			go func() { done <- ct.pull(ctx, "library/abc") }()
+			ct.await("/v2/library/abc/manifests/latest")
+			d, man = makeManifest("b")
+			ct.respond(200, man)
+			ct.await("/v2/library/abc/blobs/" + d.String())
+			ct.respond(501, `blob store error`)
+			synctest.Wait()
+			if ct.running() {
+				t.Error("pull is still running")
+			}
+			err = <-done
+			if err == nil || !strings.Contains(err.Error(), "blob store error") {
+				t.Errorf("err = %v; want some error", err)
+			}
+
+			// check we get a trace error on blob fetch after some
+			// progress and that one chunksum error does not
+			// prevent the next chunksum request.
+			got.Reset()
+			go func() { done <- ct.pull(ctx, "library/abc") }()
+			ct.await("/v2/library/abc/manifests/latest")
+			d, man = makeManifest("ccc")
+			ct.respond(200, man)
+			ct.await("/v2/library/abc/chunksums/" + d.String())
+			ct.respondWith(&http.Response{
+				Header: http.Header{
+					"Content-Location": []string{"http://example.com/v2/library/abc/blobs/" + d.String()},
+				},
+				Body: stringBody(`
+					%[1]s 0-0
+					%[1]s 1-1
+					%[1]s 2-2
+				`, blob.DigestFromBytes("c")),
+			})
+			req := ct.await("/v2/library/abc/blobs/" + d.String())
+			if rng := req.Header.Get("Range"); rng != "bytes=0-0" {
+				t.Errorf("unexpected range: %q", rng)
+			}
+			ct.respond(200, "c")
+			req = ct.await("/v2/library/abc/blobs/" + d.String())
+			if rng := req.Header.Get("Range"); rng != "bytes=1-1" {
+				t.Errorf("unexpected range: %q", rng)
+			}
+			ct.respond(501, `blob store error`)
+			req = ct.await("/v2/library/abc/blobs/" + d.String())
+			if rng := req.Header.Get("Range"); rng != "bytes=2-2" {
+				t.Errorf("unexpected range: %q", rng)
+			}
+			ct.respond(501, `blob store error`)
+			synctest.Wait()
+			if ct.running() {
+				t.Error("pull is still running")
+			}
+			err = <-done
+			if err == nil || !strings.Contains(err.Error(), "blob store error") {
+				t.Errorf("err = %v; want some error", err)
+			}
+
+			var errorsSeen int
+			for line := range strings.Lines(got.String()) {
+				if strings.Contains(line, "blob store error") {
+					errorsSeen++
+				}
+			}
+			if errorsSeen != 2 {
+				t.Errorf("errorsSeen = %d; want 2", errorsSeen)
+			}
+		})
+	})
+
+}
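
The new test file is gated behind the goexperiment.synctest build tag, so it only compiles when the synctest experiment is enabled at build time; something like the following should run it (exact invocation depends on the Go toolchain version):

GOEXPERIMENT=synctest go test ./server/internal/client/ollama/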

+ 1 - 77
server/internal/client/ollama/registry_test.go

@@ -17,8 +17,8 @@ import (
 	"reflect"
 	"slices"
 	"strings"
-	"sync"
 	"testing"
+
 	"time"
 
 	"github.com/ollama/ollama/server/internal/cache/blob"
@@ -790,79 +790,3 @@ func TestUnlink(t *testing.T) {
 		}
 	})
 }
-
-func TestPullChunksums(t *testing.T) {
-	check := testutil.Checker(t)
-
-	content := "hello"
-	var chunksums string
-	contentDigest := func() blob.Digest {
-		return blob.DigestFromBytes(content)
-	}
-	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
-		switch {
-		case strings.Contains(r.URL.Path, "/manifests/latest"):
-			fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":%d}]}`, contentDigest(), len(content))
-		case strings.HasSuffix(r.URL.Path, "/chunksums/"+contentDigest().String()):
-			loc := fmt.Sprintf("http://blob.store/v2/library/test/blobs/%s", contentDigest())
-			w.Header().Set("Content-Location", loc)
-			io.WriteString(w, chunksums)
-		case strings.Contains(r.URL.Path, "/blobs/"+contentDigest().String()):
-			http.ServeContent(w, r, contentDigest().String(), time.Time{}, strings.NewReader(content))
-		default:
-			t.Errorf("unexpected request: %v", r)
-			http.NotFound(w, r)
-		}
-	})
-
-	rc.MaxStreams = 1        // prevent concurrent chunk downloads
-	rc.ChunkingThreshold = 1 // for all blobs to be chunked
-
-	var mu sync.Mutex
-	var reads []int64
-	ctx := WithTrace(t.Context(), &Trace{
-		Update: func(l *Layer, n int64, err error) {
-			t.Logf("Update: %v %d %v", l, n, err)
-			mu.Lock()
-			reads = append(reads, n)
-			mu.Unlock()
-		},
-	})
-
-	chunksums = fmt.Sprintf("%s 0-2\n%s 3-4\n",
-		blob.DigestFromBytes("hel"),
-		blob.DigestFromBytes("lo"),
-	)
-	err := rc.Pull(ctx, "test")
-	check(err)
-	wantReads := []int64{
-		0, // initial signaling of layer pull starting
-		3, // first chunk read
-		2, // second chunk read
-	}
-	if !slices.Equal(reads, wantReads) {
-		t.Errorf("reads = %v; want %v", reads, wantReads)
-	}
-
-	mw, err := rc.Resolve(t.Context(), "test")
-	check(err)
-	mg, err := rc.ResolveLocal("test")
-	check(err)
-	if !reflect.DeepEqual(mw, mg) {
-		t.Errorf("mw = %v; mg = %v", mw, mg)
-	}
-	for i := range mg.Layers {
-		_, err = c.Get(mg.Layers[i].Digest)
-		if err != nil {
-			t.Errorf("Get(%v): %v", mg.Layers[i].Digest, err)
-		}
-	}
-
-	// missing chunks
-	content = "llama"
-	chunksums = fmt.Sprintf("%s 0-1\n", blob.DigestFromBytes("ll"))
-	err = rc.Pull(ctx, "missingchunks")
-	if err == nil {
-		t.Error("expected error because of missing chunks")
-	}
-}

BIN
server/internal/client/ollama/testdata/ollama.com/v2/library/smol/blobs/sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99


+ 1 - 0
server/internal/client/ollama/testdata/ollama.com/v2/library/smol/blobs/sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116

@@ -0,0 +1 @@
+{"model_format":"gguf","model_family":"unknown","model_families":["unknown"],"model_type":"0","file_type":"unknown","architecture":"amd64","os":"linux","rootfs":{"type":"layers","diff_ids":["sha256:a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"]}}

+ 1 - 0
server/internal/client/ollama/testdata/ollama.com/v2/library/smol/manifests/latest

@@ -0,0 +1 @@
+{"schemaVersion":2,"mediaType":"application/vnd.docker.distribution.manifest.v2+json","config":{"mediaType":"application/vnd.docker.container.image.v1+json","digest":"sha256:ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116","size":267},"layers":[{"mediaType":"application/vnd.ollama.image.model","digest":"sha256:a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99","size":24}]}