Browse Source

server/internal/client/ollama: hold DiskCache on Registry (#9463)

Previously, using a Registry required a DiskCache to be passed in for
use in various methods. This was a bit cumbersome, as the DiskCache is
required for most operations, and the DefaultCache is used in most of
those cases. This change makes the DiskCache an optional field on the
Registry struct.

This also changes DefaultCache to initialize on first use. This is to
not burden clients with the cost of creating a new cache per use, or
having to hold onto a cache for the lifetime of the Registry.

Also, slip in some minor docs updates for Trace.
Blake Mizerany 2 months ago
parent
commit
3519dd1c6e

+ 49 - 16
server/internal/client/ollama/registry.go

@@ -27,6 +27,7 @@ import (
 	"slices"
 	"slices"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
+	"sync"
 	"sync/atomic"
 	"sync/atomic"
 	"time"
 	"time"
 
 
@@ -73,19 +74,22 @@ const (
 	DefaultMaxChunkSize = 8 << 20
 	DefaultMaxChunkSize = 8 << 20
 )
 )
 
 
-// DefaultCache returns a new disk cache for storing models. If the
-// OLLAMA_MODELS environment variable is set, it uses that directory;
-// otherwise, it uses $HOME/.ollama/models.
-func DefaultCache() (*blob.DiskCache, error) {
+var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
 	dir := os.Getenv("OLLAMA_MODELS")
 	dir := os.Getenv("OLLAMA_MODELS")
 	if dir == "" {
 	if dir == "" {
-		home, err := os.UserHomeDir()
-		if err != nil {
-			return nil, err
-		}
+		home, _ := os.UserHomeDir()
+		home = cmp.Or(home, ".")
 		dir = filepath.Join(home, ".ollama", "models")
 		dir = filepath.Join(home, ".ollama", "models")
 	}
 	}
 	return blob.Open(dir)
 	return blob.Open(dir)
+})
+
+// DefaultCache returns the default cache used by the registry. It is
+// configured from the OLLAMA_MODELS environment variable, or defaults to
+// $HOME/.ollama/models, or, if an error occurs obtaining the home directory,
+// it uses the current working directory.
+func DefaultCache() (*blob.DiskCache, error) {
+	return defaultCache()
 }
 }
 
 
 // Error is the standard error returned by Ollama APIs. It can represent a
 // Error is the standard error returned by Ollama APIs. It can represent a
@@ -168,6 +172,10 @@ func CompleteName(name string) string {
 // Registry is a client for performing push and pull operations against an
 // Registry is a client for performing push and pull operations against an
 // Ollama registry.
 // Ollama registry.
 type Registry struct {
 type Registry struct {
+	// Cache is the cache used to store models. If nil, [DefaultCache] is
+	// used.
+	Cache *blob.DiskCache
+
 	// UserAgent is the User-Agent header to send with requests to the
 	// UserAgent is the User-Agent header to send with requests to the
 	// registry. If empty, the User-Agent is determined by HTTPClient.
 	// registry. If empty, the User-Agent is determined by HTTPClient.
 	UserAgent string
 	UserAgent string
@@ -206,12 +214,18 @@ type Registry struct {
 	// It is only used when a layer is larger than [MaxChunkingThreshold].
 	// It is only used when a layer is larger than [MaxChunkingThreshold].
 	MaxChunkSize int64
 	MaxChunkSize int64
 
 
-	// Mask, if set, is the name used to convert non-fully qualified
-	// names to fully qualified names. If empty, the default mask
-	// ("registry.ollama.ai/library/_:latest") is used.
+	// Mask, if set, is the name used to convert non-fully qualified names
+	// to fully qualified names. If empty, [DefaultMask] is used.
 	Mask string
 	Mask string
 }
 }
 
 
+func (r *Registry) cache() (*blob.DiskCache, error) {
+	if r.Cache != nil {
+		return r.Cache, nil
+	}
+	return defaultCache()
+}
+
 func (r *Registry) parseName(name string) (names.Name, error) {
 func (r *Registry) parseName(name string) (names.Name, error) {
 	mask := defaultMask
 	mask := defaultMask
 	if r.Mask != "" {
 	if r.Mask != "" {
@@ -282,12 +296,17 @@ type PushParams struct {
 }
 }
 
 
 // Push pushes the model with the name in the cache to the remote registry.
 // Push pushes the model with the name in the cache to the remote registry.
-func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error {
+func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
 	if p == nil {
 	if p == nil {
 		p = &PushParams{}
 		p = &PushParams{}
 	}
 	}
 
 
-	m, err := r.ResolveLocal(c, cmp.Or(p.From, name))
+	c, err := r.cache()
+	if err != nil {
+		return err
+	}
+
+	m, err := r.ResolveLocal(cmp.Or(p.From, name))
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -403,7 +422,7 @@ func canRetry(err error) bool {
 // chunks of the specified size, and then reassembled and verified. This is
 // chunks of the specified size, and then reassembled and verified. This is
 // typically slower than splitting the model up across layers, and is mostly
 // typically slower than splitting the model up across layers, and is mostly
 // utilized for layers of type equal to "application/vnd.ollama.image".
 // utilized for layers of type equal to "application/vnd.ollama.image".
-func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
+func (r *Registry) Pull(ctx context.Context, name string) error {
 	scheme, n, _, err := r.parseNameExtended(name)
 	scheme, n, _, err := r.parseNameExtended(name)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -417,6 +436,11 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
 		return fmt.Errorf("%w: no layers", ErrManifestInvalid)
 		return fmt.Errorf("%w: no layers", ErrManifestInvalid)
 	}
 	}
 
 
+	c, err := r.cache()
+	if err != nil {
+		return err
+	}
+
 	exists := func(l *Layer) bool {
 	exists := func(l *Layer) bool {
 		info, err := c.Get(l.Digest)
 		info, err := c.Get(l.Digest)
 		return err == nil && info.Size == l.Size
 		return err == nil && info.Size == l.Size
@@ -554,11 +578,15 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
 
 
 // Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
 // Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
 // before attempting to unlink the model.
 // before attempting to unlink the model.
-func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) {
+func (r *Registry) Unlink(name string) (ok bool, _ error) {
 	n, err := r.parseName(name)
 	n, err := r.parseName(name)
 	if err != nil {
 	if err != nil {
 		return false, err
 		return false, err
 	}
 	}
+	c, err := r.cache()
+	if err != nil {
+		return false, err
+	}
 	return c.Unlink(n.String())
 	return c.Unlink(n.String())
 }
 }
 
 
@@ -631,12 +659,17 @@ type Layer struct {
 }
 }
 
 
 // ResolveLocal resolves a name to a Manifest in the local cache.
 // ResolveLocal resolves a name to a Manifest in the local cache.
-func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
+func (r *Registry) ResolveLocal(name string) (*Manifest, error) {
 	_, n, d, err := r.parseNameExtended(name)
 	_, n, d, err := r.parseNameExtended(name)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
+	c, err := r.cache()
+	if err != nil {
+		return nil, err
+	}
 	if !d.IsValid() {
 	if !d.IsValid() {
+		// No digest, so resolve the manifest by name.
 		d, err = c.Resolve(n.String())
 		d, err = c.Resolve(n.String())
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err

+ 52 - 50
server/internal/client/ollama/registry_test.go

@@ -73,6 +73,7 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
 // To simulate a network error, pass a handler that returns a 499 status code.
 // To simulate a network error, pass a handler that returns a 499 status code.
 func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
 func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
 	t.Helper()
 	t.Helper()
+
 	c, err := blob.Open(t.TempDir())
 	c, err := blob.Open(t.TempDir())
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
@@ -86,6 +87,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
 	}
 	}
 
 
 	r := &Registry{
 	r := &Registry{
+		Cache: c,
 		HTTPClient: &http.Client{
 		HTTPClient: &http.Client{
 			Transport: recordRoundTripper(h),
 			Transport: recordRoundTripper(h),
 		},
 		},
@@ -152,55 +154,55 @@ func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
 }
 }
 
 
 func TestPushZero(t *testing.T) {
 func TestPushZero(t *testing.T) {
-	rc, c := newClient(t, okHandler)
-	err := rc.Push(t.Context(), c, "empty", nil)
+	rc, _ := newClient(t, okHandler)
+	err := rc.Push(t.Context(), "empty", nil)
 	if !errors.Is(err, ErrManifestInvalid) {
 	if !errors.Is(err, ErrManifestInvalid) {
 		t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
 		t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
 	}
 	}
 }
 }
 
 
 func TestPushSingle(t *testing.T) {
 func TestPushSingle(t *testing.T) {
-	rc, c := newClient(t, okHandler)
-	err := rc.Push(t.Context(), c, "single", nil)
+	rc, _ := newClient(t, okHandler)
+	err := rc.Push(t.Context(), "single", nil)
 	testutil.Check(t, err)
 	testutil.Check(t, err)
 }
 }
 
 
 func TestPushMultiple(t *testing.T) {
 func TestPushMultiple(t *testing.T) {
-	rc, c := newClient(t, okHandler)
-	err := rc.Push(t.Context(), c, "multiple", nil)
+	rc, _ := newClient(t, okHandler)
+	err := rc.Push(t.Context(), "multiple", nil)
 	testutil.Check(t, err)
 	testutil.Check(t, err)
 }
 }
 
 
 func TestPushNotFound(t *testing.T) {
 func TestPushNotFound(t *testing.T) {
-	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
 		t.Errorf("unexpected request: %v", r)
 		t.Errorf("unexpected request: %v", r)
 	})
 	})
-	err := rc.Push(t.Context(), c, "notfound", nil)
+	err := rc.Push(t.Context(), "notfound", nil)
 	if !errors.Is(err, fs.ErrNotExist) {
 	if !errors.Is(err, fs.ErrNotExist) {
 		t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
 		t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
 	}
 	}
 }
 }
 
 
 func TestPushNullLayer(t *testing.T) {
 func TestPushNullLayer(t *testing.T) {
-	rc, c := newClient(t, nil)
-	err := rc.Push(t.Context(), c, "null", nil)
+	rc, _ := newClient(t, nil)
+	err := rc.Push(t.Context(), "null", nil)
 	if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
 	if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
 		t.Errorf("err = %v; want invalid manifest", err)
 		t.Errorf("err = %v; want invalid manifest", err)
 	}
 	}
 }
 }
 
 
 func TestPushSizeMismatch(t *testing.T) {
 func TestPushSizeMismatch(t *testing.T) {
-	rc, c := newClient(t, nil)
+	rc, _ := newClient(t, nil)
 	ctx, _ := withTraceUnexpected(t.Context())
 	ctx, _ := withTraceUnexpected(t.Context())
-	got := rc.Push(ctx, c, "sizemismatch", nil)
+	got := rc.Push(ctx, "sizemismatch", nil)
 	if got == nil || !strings.Contains(got.Error(), "size mismatch") {
 	if got == nil || !strings.Contains(got.Error(), "size mismatch") {
 		t.Errorf("err = %v; want size mismatch", got)
 		t.Errorf("err = %v; want size mismatch", got)
 	}
 	}
 }
 }
 
 
 func TestPushInvalid(t *testing.T) {
 func TestPushInvalid(t *testing.T) {
-	rc, c := newClient(t, nil)
-	err := rc.Push(t.Context(), c, "invalid", nil)
+	rc, _ := newClient(t, nil)
+	err := rc.Push(t.Context(), "invalid", nil)
 	if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
 	if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
 		t.Errorf("err = %v; want invalid manifest", err)
 		t.Errorf("err = %v; want invalid manifest", err)
 	}
 	}
@@ -208,7 +210,7 @@ func TestPushInvalid(t *testing.T) {
 
 
 func TestPushExistsAtRemote(t *testing.T) {
 func TestPushExistsAtRemote(t *testing.T) {
 	var pushed bool
 	var pushed bool
-	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
 		if strings.Contains(r.URL.Path, "/uploads/") {
 		if strings.Contains(r.URL.Path, "/uploads/") {
 			if !pushed {
 			if !pushed {
 				// First push. Return an uploadURL.
 				// First push. Return an uploadURL.
@@ -236,35 +238,35 @@ func TestPushExistsAtRemote(t *testing.T) {
 
 
 	check := testutil.Checker(t)
 	check := testutil.Checker(t)
 
 
-	err := rc.Push(ctx, c, "single", nil)
+	err := rc.Push(ctx, "single", nil)
 	check(err)
 	check(err)
 
 
 	if !errors.Is(errors.Join(errs...), nil) {
 	if !errors.Is(errors.Join(errs...), nil) {
 		t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
 		t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
 	}
 	}
 
 
-	err = rc.Push(ctx, c, "single", nil)
+	err = rc.Push(ctx, "single", nil)
 	check(err)
 	check(err)
 }
 }
 
 
 func TestPushRemoteError(t *testing.T) {
 func TestPushRemoteError(t *testing.T) {
-	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
 		if strings.Contains(r.URL.Path, "/blobs/") {
 		if strings.Contains(r.URL.Path, "/blobs/") {
 			w.WriteHeader(500)
 			w.WriteHeader(500)
 			io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
 			io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
 			return
 			return
 		}
 		}
 	})
 	})
-	got := rc.Push(t.Context(), c, "single", nil)
+	got := rc.Push(t.Context(), "single", nil)
 	checkErrCode(t, got, 500, "blob_error")
 	checkErrCode(t, got, 500, "blob_error")
 }
 }
 
 
 func TestPushLocationError(t *testing.T) {
 func TestPushLocationError(t *testing.T) {
-	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
 		w.Header().Set("Location", ":///x")
 		w.Header().Set("Location", ":///x")
 		w.WriteHeader(http.StatusAccepted)
 		w.WriteHeader(http.StatusAccepted)
 	})
 	})
-	got := rc.Push(t.Context(), c, "single", nil)
+	got := rc.Push(t.Context(), "single", nil)
 	wantContains := "invalid upload URL"
 	wantContains := "invalid upload URL"
 	if got == nil || !strings.Contains(got.Error(), wantContains) {
 	if got == nil || !strings.Contains(got.Error(), wantContains) {
 		t.Errorf("err = %v; want to contain %v", got, wantContains)
 		t.Errorf("err = %v; want to contain %v", got, wantContains)
@@ -272,14 +274,14 @@ func TestPushLocationError(t *testing.T) {
 }
 }
 
 
 func TestPushUploadRoundtripError(t *testing.T) {
 func TestPushUploadRoundtripError(t *testing.T) {
-	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
 		if r.Host == "blob.store" {
 		if r.Host == "blob.store" {
 			w.WriteHeader(499) // force RoundTrip error on upload
 			w.WriteHeader(499) // force RoundTrip error on upload
 			return
 			return
 		}
 		}
 		w.Header().Set("Location", "http://blob.store/blobs/123")
 		w.Header().Set("Location", "http://blob.store/blobs/123")
 	})
 	})
-	got := rc.Push(t.Context(), c, "single", nil)
+	got := rc.Push(t.Context(), "single", nil)
 	if !errors.Is(got, errRoundTrip) {
 	if !errors.Is(got, errRoundTrip) {
 		t.Errorf("got = %v; want %v", got, errRoundTrip)
 		t.Errorf("got = %v; want %v", got, errRoundTrip)
 	}
 	}
@@ -295,20 +297,20 @@ func TestPushUploadFileOpenError(t *testing.T) {
 			os.Remove(c.GetFile(l.Digest))
 			os.Remove(c.GetFile(l.Digest))
 		},
 		},
 	})
 	})
-	got := rc.Push(ctx, c, "single", nil)
+	got := rc.Push(ctx, "single", nil)
 	if !errors.Is(got, fs.ErrNotExist) {
 	if !errors.Is(got, fs.ErrNotExist) {
 		t.Errorf("got = %v; want fs.ErrNotExist", got)
 		t.Errorf("got = %v; want fs.ErrNotExist", got)
 	}
 	}
 }
 }
 
 
 func TestPushCommitRoundtripError(t *testing.T) {
 func TestPushCommitRoundtripError(t *testing.T) {
-	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
 		if strings.Contains(r.URL.Path, "/blobs/") {
 		if strings.Contains(r.URL.Path, "/blobs/") {
 			panic("unexpected")
 			panic("unexpected")
 		}
 		}
 		w.WriteHeader(499) // force RoundTrip error
 		w.WriteHeader(499) // force RoundTrip error
 	})
 	})
-	err := rc.Push(t.Context(), c, "zero", nil)
+	err := rc.Push(t.Context(), "zero", nil)
 	if !errors.Is(err, errRoundTrip) {
 	if !errors.Is(err, errRoundTrip) {
 		t.Errorf("err = %v; want %v", err, errRoundTrip)
 		t.Errorf("err = %v; want %v", err, errRoundTrip)
 	}
 	}
@@ -322,8 +324,8 @@ func checkNotExist(t *testing.T, err error) {
 }
 }
 
 
 func TestRegistryPullInvalidName(t *testing.T) {
 func TestRegistryPullInvalidName(t *testing.T) {
-	rc, c := newClient(t, nil)
-	err := rc.Pull(t.Context(), c, "://")
+	rc, _ := newClient(t, nil)
+	err := rc.Pull(t.Context(), "://")
 	if !errors.Is(err, ErrNameInvalid) {
 	if !errors.Is(err, ErrNameInvalid) {
 		t.Errorf("err = %v; want %v", err, ErrNameInvalid)
 		t.Errorf("err = %v; want %v", err, ErrNameInvalid)
 	}
 	}
@@ -338,10 +340,10 @@ func TestRegistryPullInvalidManifest(t *testing.T) {
 	}
 	}
 
 
 	for _, resp := range cases {
 	for _, resp := range cases {
-		rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+		rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
 			io.WriteString(w, resp)
 			io.WriteString(w, resp)
 		})
 		})
-		err := rc.Pull(t.Context(), c, "x")
+		err := rc.Pull(t.Context(), "x")
 		if !errors.Is(err, ErrManifestInvalid) {
 		if !errors.Is(err, ErrManifestInvalid) {
 			t.Errorf("err = %v; want invalid manifest", err)
 			t.Errorf("err = %v; want invalid manifest", err)
 		}
 		}
@@ -364,18 +366,18 @@ func TestRegistryPullNotCached(t *testing.T) {
 	})
 	})
 
 
 	// Confirm that the layer does not exist locally
 	// Confirm that the layer does not exist locally
-	_, err := rc.ResolveLocal(c, "model")
+	_, err := rc.ResolveLocal("model")
 	checkNotExist(t, err)
 	checkNotExist(t, err)
 
 
 	_, err = c.Get(d)
 	_, err = c.Get(d)
 	checkNotExist(t, err)
 	checkNotExist(t, err)
 
 
-	err = rc.Pull(t.Context(), c, "model")
+	err = rc.Pull(t.Context(), "model")
 	check(err)
 	check(err)
 
 
 	mw, err := rc.Resolve(t.Context(), "model")
 	mw, err := rc.Resolve(t.Context(), "model")
 	check(err)
 	check(err)
-	mg, err := rc.ResolveLocal(c, "model")
+	mg, err := rc.ResolveLocal("model")
 	check(err)
 	check(err)
 	if !reflect.DeepEqual(mw, mg) {
 	if !reflect.DeepEqual(mw, mg) {
 		t.Errorf("mw = %v; mg = %v", mw, mg)
 		t.Errorf("mw = %v; mg = %v", mw, mg)
@@ -400,7 +402,7 @@ func TestRegistryPullNotCached(t *testing.T) {
 
 
 func TestRegistryPullCached(t *testing.T) {
 func TestRegistryPullCached(t *testing.T) {
 	cached := blob.DigestFromBytes("exists")
 	cached := blob.DigestFromBytes("exists")
-	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
 		if strings.Contains(r.URL.Path, "/blobs/") {
 		if strings.Contains(r.URL.Path, "/blobs/") {
 			w.WriteHeader(499) // should not be called
 			w.WriteHeader(499) // should not be called
 			return
 			return
@@ -423,7 +425,7 @@ func TestRegistryPullCached(t *testing.T) {
 	ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
 	ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
 	defer cancel()
 	defer cancel()
 
 
-	err := rc.Pull(ctx, c, "single")
+	err := rc.Pull(ctx, "single")
 	testutil.Check(t, err)
 	testutil.Check(t, err)
 
 
 	want := []int64{6}
 	want := []int64{6}
@@ -436,30 +438,30 @@ func TestRegistryPullCached(t *testing.T) {
 }
 }
 
 
 func TestRegistryPullManifestNotFound(t *testing.T) {
 func TestRegistryPullManifestNotFound(t *testing.T) {
-	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
 		w.WriteHeader(http.StatusNotFound)
 		w.WriteHeader(http.StatusNotFound)
 	})
 	})
-	err := rc.Pull(t.Context(), c, "notfound")
+	err := rc.Pull(t.Context(), "notfound")
 	checkErrCode(t, err, 404, "")
 	checkErrCode(t, err, 404, "")
 }
 }
 
 
 func TestRegistryPullResolveRemoteError(t *testing.T) {
 func TestRegistryPullResolveRemoteError(t *testing.T) {
-	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
 		w.WriteHeader(http.StatusInternalServerError)
 		w.WriteHeader(http.StatusInternalServerError)
 		io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
 		io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
 	})
 	})
-	err := rc.Pull(t.Context(), c, "single")
+	err := rc.Pull(t.Context(), "single")
 	checkErrCode(t, err, 500, "an_error")
 	checkErrCode(t, err, 500, "an_error")
 }
 }
 
 
 func TestRegistryPullResolveRoundtripError(t *testing.T) {
 func TestRegistryPullResolveRoundtripError(t *testing.T) {
-	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
 		if strings.Contains(r.URL.Path, "/manifests/") {
 		if strings.Contains(r.URL.Path, "/manifests/") {
 			w.WriteHeader(499) // force RoundTrip error
 			w.WriteHeader(499) // force RoundTrip error
 			return
 			return
 		}
 		}
 	})
 	})
-	err := rc.Pull(t.Context(), c, "single")
+	err := rc.Pull(t.Context(), "single")
 	if !errors.Is(err, errRoundTrip) {
 	if !errors.Is(err, errRoundTrip) {
 		t.Errorf("err = %v; want %v", err, errRoundTrip)
 		t.Errorf("err = %v; want %v", err, errRoundTrip)
 	}
 	}
@@ -512,7 +514,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
 
 
 		// Check that we pull all layers that we can.
 		// Check that we pull all layers that we can.
 
 
-		err := rc.Pull(ctx, c, "mixed")
+		err := rc.Pull(ctx, "mixed")
 		if err != nil {
 		if err != nil {
 			t.Fatal(err)
 			t.Fatal(err)
 		}
 		}
@@ -530,7 +532,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
 }
 }
 
 
 func TestRegistryPullChunking(t *testing.T) {
 func TestRegistryPullChunking(t *testing.T) {
-	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
 		t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
 		t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
 		if r.URL.Host != "blob.store" {
 		if r.URL.Host != "blob.store" {
 			// The production registry redirects to the blob store.
 			// The production registry redirects to the blob store.
@@ -568,7 +570,7 @@ func TestRegistryPullChunking(t *testing.T) {
 		},
 		},
 	})
 	})
 
 
-	err := rc.Pull(ctx, c, "remote")
+	err := rc.Pull(ctx, "remote")
 	testutil.Check(t, err)
 	testutil.Check(t, err)
 
 
 	want := []int64{0, 3, 6}
 	want := []int64{0, 3, 6}
@@ -785,27 +787,27 @@ func TestParseNameExtended(t *testing.T) {
 
 
 func TestUnlink(t *testing.T) {
 func TestUnlink(t *testing.T) {
 	t.Run("found by name", func(t *testing.T) {
 	t.Run("found by name", func(t *testing.T) {
-		rc, c := newClient(t, nil)
+		rc, _ := newClient(t, nil)
 
 
 		// confirm linked
 		// confirm linked
-		_, err := rc.ResolveLocal(c, "single")
+		_, err := rc.ResolveLocal("single")
 		if err != nil {
 		if err != nil {
 			t.Errorf("unexpected error: %v", err)
 			t.Errorf("unexpected error: %v", err)
 		}
 		}
 
 
 		// unlink
 		// unlink
-		_, err = rc.Unlink(c, "single")
+		_, err = rc.Unlink("single")
 		testutil.Check(t, err)
 		testutil.Check(t, err)
 
 
 		// confirm unlinked
 		// confirm unlinked
-		_, err = rc.ResolveLocal(c, "single")
+		_, err = rc.ResolveLocal("single")
 		if !errors.Is(err, fs.ErrNotExist) {
 		if !errors.Is(err, fs.ErrNotExist) {
 			t.Errorf("err = %v; want fs.ErrNotExist", err)
 			t.Errorf("err = %v; want fs.ErrNotExist", err)
 		}
 		}
 	})
 	})
 	t.Run("not found by name", func(t *testing.T) {
 	t.Run("not found by name", func(t *testing.T) {
-		rc, c := newClient(t, nil)
-		ok, err := rc.Unlink(c, "manifestNotFound")
+		rc, _ := newClient(t, nil)
+		ok, err := rc.Unlink("manifestNotFound")
 		if err != nil {
 		if err != nil {
 			t.Fatal(err)
 			t.Fatal(err)
 		}
 		}

+ 3 - 0
server/internal/client/ollama/trace.go

@@ -6,6 +6,9 @@ import (
 
 
 // Trace is a set of functions that are called to report progress during blob
 // Trace is a set of functions that are called to report progress during blob
 // downloads and uploads.
 // downloads and uploads.
+//
+// Use [WithTrace] to attach a Trace to a context for use with [Registry.Push]
+// and [Registry.Pull].
 type Trace struct {
 type Trace struct {
 	// Update is called during [Registry.Push] and [Registry.Pull] to
 	// Update is called during [Registry.Push] and [Registry.Pull] to
 	// report the progress of blob uploads and downloads.
 	// report the progress of blob uploads and downloads.

+ 21 - 18
server/internal/cmd/opp/opp.go

@@ -63,25 +63,28 @@ func main() {
 	}
 	}
 	flag.Parse()
 	flag.Parse()
 
 
-	c, err := ollama.DefaultCache()
-	if err != nil {
-		log.Fatal(err)
-	}
-
-	rc, err := ollama.DefaultRegistry()
-	if err != nil {
-		log.Fatal(err)
-	}
-
 	ctx := context.Background()
 	ctx := context.Background()
 
 
-	err = func() error {
+	err := func() error {
 		switch cmd := flag.Arg(0); cmd {
 		switch cmd := flag.Arg(0); cmd {
 		case "pull":
 		case "pull":
-			return cmdPull(ctx, rc, c)
+			rc, err := ollama.DefaultRegistry()
+			if err != nil {
+				log.Fatal(err)
+			}
+
+			return cmdPull(ctx, rc)
 		case "push":
 		case "push":
-			return cmdPush(ctx, rc, c)
+			rc, err := ollama.DefaultRegistry()
+			if err != nil {
+				log.Fatal(err)
+			}
+			return cmdPush(ctx, rc)
 		case "import":
 		case "import":
+			c, err := ollama.DefaultCache()
+			if err != nil {
+				log.Fatal(err)
+			}
 			return cmdImport(ctx, c)
 			return cmdImport(ctx, c)
 		default:
 		default:
 			if cmd == "" {
 			if cmd == "" {
@@ -99,7 +102,7 @@ func main() {
 	}
 	}
 }
 }
 
 
-func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
+func cmdPull(ctx context.Context, rc *ollama.Registry) error {
 	model := flag.Arg(1)
 	model := flag.Arg(1)
 	if model == "" {
 	if model == "" {
 		flag.Usage()
 		flag.Usage()
@@ -145,7 +148,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
 
 
 	errc := make(chan error)
 	errc := make(chan error)
 	go func() {
 	go func() {
-		errc <- rc.Pull(ctx, c, model)
+		errc <- rc.Pull(ctx, model)
 	}()
 	}()
 
 
 	t := time.NewTicker(time.Second)
 	t := time.NewTicker(time.Second)
@@ -161,7 +164,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
 	}
 	}
 }
 }
 
 
-func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
+func cmdPush(ctx context.Context, rc *ollama.Registry) error {
 	args := flag.Args()[1:]
 	args := flag.Args()[1:]
 	flag := flag.NewFlagSet("push", flag.ExitOnError)
 	flag := flag.NewFlagSet("push", flag.ExitOnError)
 	flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
 	flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
@@ -177,7 +180,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
 	}
 	}
 
 
 	from := cmp.Or(*flagFrom, model)
 	from := cmp.Or(*flagFrom, model)
-	m, err := rc.ResolveLocal(c, from)
+	m, err := rc.ResolveLocal(from)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -203,7 +206,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
 		},
 		},
 	})
 	})
 
 
-	return rc.Push(ctx, c, model, &ollama.PushParams{
+	return rc.Push(ctx, model, &ollama.PushParams{
 		From: from,
 		From: from,
 	})
 	})
 }
 }

+ 1 - 3
server/internal/registry/server.go

@@ -11,7 +11,6 @@ import (
 	"log/slog"
 	"log/slog"
 	"net/http"
 	"net/http"
 
 
-	"github.com/ollama/ollama/server/internal/cache/blob"
 	"github.com/ollama/ollama/server/internal/client/ollama"
 	"github.com/ollama/ollama/server/internal/client/ollama"
 )
 )
 
 
@@ -27,7 +26,6 @@ import (
 // directly to the blob disk cache.
 // directly to the blob disk cache.
 type Local struct {
 type Local struct {
 	Client *ollama.Registry // required
 	Client *ollama.Registry // required
-	Cache  *blob.DiskCache  // required
 	Logger *slog.Logger     // required
 	Logger *slog.Logger     // required
 
 
 	// Fallback, if set, is used to handle requests that are not handled by
 	// Fallback, if set, is used to handle requests that are not handled by
@@ -199,7 +197,7 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	ok, err := s.Client.Unlink(s.Cache, p.model())
+	ok, err := s.Client.Unlink(p.model())
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}

+ 3 - 3
server/internal/registry/server_test.go

@@ -42,10 +42,10 @@ func newTestServer(t *testing.T) *Local {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 	rc := &ollama.Registry{
 	rc := &ollama.Registry{
+		Cache:      c,
 		HTTPClient: panicOnRoundTrip,
 		HTTPClient: panicOnRoundTrip,
 	}
 	}
 	l := &Local{
 	l := &Local{
-		Cache:  c,
 		Client: rc,
 		Client: rc,
 		Logger: testutil.Slogger(t),
 		Logger: testutil.Slogger(t),
 	}
 	}
@@ -87,7 +87,7 @@ func TestServerDelete(t *testing.T) {
 
 
 	s := newTestServer(t)
 	s := newTestServer(t)
 
 
-	_, err := s.Client.ResolveLocal(s.Cache, "smol")
+	_, err := s.Client.ResolveLocal("smol")
 	check(err)
 	check(err)
 
 
 	got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
 	got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
@@ -95,7 +95,7 @@ func TestServerDelete(t *testing.T) {
 		t.Fatalf("Code = %d; want 200", got.Code)
 		t.Fatalf("Code = %d; want 200", got.Code)
 	}
 	}
 
 
-	_, err = s.Client.ResolveLocal(s.Cache, "smol")
+	_, err = s.Client.ResolveLocal("smol")
 	if err == nil {
 	if err == nil {
 		t.Fatal("expected smol to have been deleted")
 		t.Fatal("expected smol to have been deleted")
 	}
 	}

+ 2 - 8
server/routes.go

@@ -34,7 +34,6 @@ import (
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/model/models/mllama"
 	"github.com/ollama/ollama/model/models/mllama"
 	"github.com/ollama/ollama/openai"
 	"github.com/ollama/ollama/openai"
-	"github.com/ollama/ollama/server/internal/cache/blob"
 	"github.com/ollama/ollama/server/internal/client/ollama"
 	"github.com/ollama/ollama/server/internal/client/ollama"
 	"github.com/ollama/ollama/server/internal/registry"
 	"github.com/ollama/ollama/server/internal/registry"
 	"github.com/ollama/ollama/template"
 	"github.com/ollama/ollama/template"
@@ -1129,7 +1128,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
 	}
 	}
 }
 }
 
 
-func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Handler, error) {
+func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
 	corsConfig := cors.DefaultConfig()
 	corsConfig := cors.DefaultConfig()
 	corsConfig.AllowWildcard = true
 	corsConfig.AllowWildcard = true
 	corsConfig.AllowBrowserExtensions = true
 	corsConfig.AllowBrowserExtensions = true
@@ -1197,7 +1196,6 @@ func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Ha
 
 
 	// wrap old with new
 	// wrap old with new
 	rs := &registry.Local{
 	rs := &registry.Local{
-		Cache:    c,
 		Client:   rc,
 		Client:   rc,
 		Logger:   slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
 		Logger:   slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
 		Fallback: r,
 		Fallback: r,
@@ -1258,16 +1256,12 @@ func Serve(ln net.Listener) error {
 
 
 	s := &Server{addr: ln.Addr()}
 	s := &Server{addr: ln.Addr()}
 
 
-	c, err := ollama.DefaultCache()
-	if err != nil {
-		return err
-	}
 	rc, err := ollama.DefaultRegistry()
 	rc, err := ollama.DefaultRegistry()
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	h, err := s.GenerateRoutes(c, rc)
+	h, err := s.GenerateRoutes(rc)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}

+ 1 - 7
server/routes_test.go

@@ -23,7 +23,6 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/openai"
 	"github.com/ollama/ollama/openai"
-	"github.com/ollama/ollama/server/internal/cache/blob"
 	"github.com/ollama/ollama/server/internal/client/ollama"
 	"github.com/ollama/ollama/server/internal/client/ollama"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
 	"github.com/ollama/ollama/version"
@@ -490,11 +489,6 @@ func TestRoutes(t *testing.T) {
 	modelsDir := t.TempDir()
 	modelsDir := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", modelsDir)
 	t.Setenv("OLLAMA_MODELS", modelsDir)
 
 
-	c, err := blob.Open(modelsDir)
-	if err != nil {
-		t.Fatalf("failed to open models dir: %v", err)
-	}
-
 	rc := &ollama.Registry{
 	rc := &ollama.Registry{
 		// This is a temporary measure to allow us to move forward,
 		// This is a temporary measure to allow us to move forward,
 		// surfacing any code contacting ollama.com we do not intended
 		// surfacing any code contacting ollama.com we do not intended
@@ -511,7 +505,7 @@ func TestRoutes(t *testing.T) {
 	}
 	}
 
 
 	s := &Server{}
 	s := &Server{}
-	router, err := s.GenerateRoutes(c, rc)
+	router, err := s.GenerateRoutes(rc)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("failed to generate routes: %v", err)
 		t.Fatalf("failed to generate routes: %v", err)
 	}
 	}