Selaa lähdekoodia

server/internal/registry: take over pulls from server package (#9485)

This commit replaces the old pull implementation in the server package
with the new, faster, more robust pull implementation in the registry
package.

The new endpoint, and now the remove endpoint too, are behind the
feature gate "client2", which is enabled only by setting the
OLLAMA_EXPERIMENT environment variable to include "client2".

Currently, the progress indication is wired to perform the same as the
previous implementation to avoid making changes to the CLI, and because
the status reports happen at the start of the download, and the end of
the write to disk, the progress indication is not as smooth as it could
be. This is a known issue and will be addressed in a future change.

This implementation may be ~0.5-1.0% slower in rare cases, depending on
network and disk speed, but is generally MUCH faster and more robust
than its predecessor in all other cases.
Blake Mizerany 1 kuukausi sitten
vanhempi
commit
e2252d0fc6

+ 3 - 3
api/types.go

@@ -361,9 +361,9 @@ type CopyRequest struct {
 // PullRequest is the request passed to [Client.Pull].
 type PullRequest struct {
 	Model    string `json:"model"`
-	Insecure bool   `json:"insecure,omitempty"`
-	Username string `json:"username"`
-	Password string `json:"password"`
+	Insecure bool   `json:"insecure,omitempty"` // Deprecated: ignored
+	Username string `json:"username"`           // Deprecated: ignored
+	Password string `json:"password"`           // Deprecated: ignored
 	Stream   *bool  `json:"stream,omitempty"`
 
 	// Deprecated: set the model name with Model instead

+ 1 - 0
go.mod

@@ -24,6 +24,7 @@ require (
 	github.com/nlpodyssey/gopickle v0.3.0
 	github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
 	golang.org/x/image v0.22.0
+	golang.org/x/tools v0.30.0
 	gonum.org/v1/gonum v0.15.0
 )
 

+ 2 - 0
go.sum

@@ -309,6 +309,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
 golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
 golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
 golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
+golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
+golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

+ 86 - 31
server/internal/client/ollama/registry.go

@@ -45,9 +45,9 @@ import (
 
 // Errors
 var (
-	// ErrManifestNotFound is returned when a manifest is not found in the
+	// ErrModelNotFound is returned when a manifest is not found in the
 	// cache or registry.
-	ErrManifestNotFound = errors.New("manifest not found")
+	ErrModelNotFound = errors.New("model not found")
 
 	// ErrManifestInvalid is returned when a manifest found in a local or
 	// remote cache is invalid.
@@ -114,7 +114,18 @@ type Error struct {
 }
 
 func (e *Error) Error() string {
-	return fmt.Sprintf("registry responded with status %d: %s %s", e.Status, e.Code, e.Message)
+	var b strings.Builder
+	b.WriteString("registry responded with status ")
+	b.WriteString(strconv.Itoa(e.Status))
+	if e.Code != "" {
+		b.WriteString(": code ")
+		b.WriteString(e.Code)
+	}
+	if e.Message != "" {
+		b.WriteString(": ")
+		b.WriteString(e.Message)
+	}
+	return b.String()
 }
 
 func (e *Error) LogValue() slog.Value {
@@ -355,7 +366,7 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
 				n.Model(),
 				l.Digest,
 			)
-			res, err := r.doOK(ctx, "POST", startURL, nil)
+			res, err := r.send(ctx, "POST", startURL, nil)
 			if err != nil {
 				return err
 			}
@@ -379,7 +390,7 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
 			}
 			req.ContentLength = l.Size
 
-			res, err = doOK(r.client(), req)
+			res, err = sendRequest(r.client(), req)
 			if err == nil {
 				res.Body.Close()
 			}
@@ -399,7 +410,7 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
 		n.Model(),
 		n.Tag(),
 	)
-	res, err := r.doOK(ctx, "PUT", path, bytes.NewReader(m.Data))
+	res, err := r.send(ctx, "PUT", path, bytes.NewReader(m.Data))
 	if err == nil {
 		res.Body.Close()
 	}
@@ -448,10 +459,15 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 
 	t := traceFromContext(ctx)
 
-	var g errgroup.Group
+	g, ctx := errgroup.WithContext(ctx)
 	g.SetLimit(r.maxStreams())
 
-	for _, l := range m.Layers {
+	layers := m.Layers
+	if m.Config != nil && m.Config.Digest.IsValid() {
+		layers = append(layers, m.Config)
+	}
+
+	for _, l := range layers {
 		if exists(l) {
 			t.update(l, l.Size, ErrCached)
 			continue
@@ -468,7 +484,9 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 
 		if l.Size <= r.maxChunkingThreshold() {
 			g.Go(func() error {
-				res, err := doOK(r.client(), req)
+				// TODO(bmizerany): retry/backoff like below in
+				// the chunking case
+				res, err := sendRequest(r.client(), req)
 				if err != nil {
 					return err
 				}
@@ -494,19 +512,21 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 			// fire an initial request to get the final URL and
 			// then use that URL for the chunk requests.
 			req.Header.Set("Range", "bytes=0-0")
-			res, err := doOK(r.client(), req)
+			res, err := sendRequest(r.client(), req)
 			if err != nil {
 				return err
 			}
 			res.Body.Close()
 			req = res.Request.WithContext(req.Context())
 
-			streamNo := 0
-			tws := make([]*bufio.Writer, r.maxStreams()-1)
+			wp := writerPool{size: r.maxChunkSize()}
+
 			for chunk := range chunks.Of(l.Size, r.maxChunkSize()) {
+				if ctx.Err() != nil {
+					break
+				}
+
 				ticket := q.Take()
-				bufIdx := streamNo % len(tws)
-				streamNo++
 				g.Go(func() (err error) {
 					defer func() {
 						if err != nil {
@@ -520,23 +540,18 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 						if err != nil {
 							return err
 						}
-
 						err := func() error {
 							req := req.Clone(req.Context())
 							req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
-							res, err := doOK(r.client(), req)
+							res, err := sendRequest(r.client(), req)
 							if err != nil {
 								return err
 							}
 							defer res.Body.Close()
 
-							tw := tws[bufIdx]
-							if tw == nil {
-								tw = bufio.NewWriterSize(nil, int(r.maxChunkSize()))
-								tws[bufIdx] = tw
-							}
+							tw := wp.get()
 							tw.Reset(ticket)
-							defer tw.Reset(nil) // release ticket
+							defer wp.put(tw)
 
 							_, err = io.CopyN(tw, res.Body, chunk.Size())
 							if err != nil {
@@ -595,6 +610,9 @@ type Manifest struct {
 	Name   string   `json:"-"` // the canonical name of the model
 	Data   []byte   `json:"-"` // the raw data of the manifest
 	Layers []*Layer `json:"layers"`
+
+	// For legacy reasons, we still have to download the config layer.
+	Config *Layer `json:"config"`
 }
 
 var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000")
@@ -678,7 +696,7 @@ func (r *Registry) ResolveLocal(name string) (*Manifest, error) {
 	data, err := os.ReadFile(c.GetFile(d))
 	if err != nil {
 		if errors.Is(err, fs.ErrNotExist) {
-			return nil, fmt.Errorf("%w: %s", ErrManifestNotFound, name)
+			return nil, fmt.Errorf("%w: %s", ErrModelNotFound, name)
 		}
 		return nil, err
 	}
@@ -701,7 +719,7 @@ func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error)
 		manifestURL = fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), d)
 	}
 
-	res, err := r.doOK(ctx, "GET", manifestURL, nil)
+	res, err := r.send(ctx, "GET", manifestURL, nil)
 	if err != nil {
 		return nil, err
 	}
@@ -726,7 +744,7 @@ func (r *Registry) client() *http.Client {
 }
 
 // newRequest constructs a new request, ready to use, with the given method,
-// url, and body, presigned with client Key and UserAgent.
+// url, and body, pre-signed with client [Key] and [UserAgent].
 func (r *Registry) newRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) {
 	req, err := http.NewRequestWithContext(ctx, method, url, body)
 	if err != nil {
@@ -745,11 +763,17 @@ func (r *Registry) newRequest(ctx context.Context, method, url string, body io.R
 	return req, nil
 }
 
-// doOK makes a request with the given client and request, and returns the
+// sendRequest makes a request with the given client and request, and returns the
 // response if the status code is 200. If the status code is not 200, an Error
 // is parsed from the response body and returned. If any other error occurs, it
 // is returned.
-func doOK(c *http.Client, r *http.Request) (*http.Response, error) {
+func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error) {
+	defer func() {
+		if err != nil {
+			err = fmt.Errorf("request error %s: %w", r.URL, err)
+		}
+	}()
+
 	if r.URL.Scheme == "https+insecure" {
 		// TODO(bmizerany): clone client.Transport, set
 		// InsecureSkipVerify, etc.
@@ -792,20 +816,26 @@ func doOK(c *http.Client, r *http.Request) (*http.Response, error) {
 			// Use the raw body if we can't parse it as an error object.
 			re.Message = string(out)
 		}
+
+		// coerce MANIFEST_UNKNOWN to ErrModelNotFound
+		if strings.EqualFold(re.Code, "MANIFEST_UNKNOWN") {
+			return nil, ErrModelNotFound
+		}
+
 		re.Status = res.StatusCode
 		return nil, &re
 	}
 	return res, nil
 }
 
-// doOK is a convenience method for making a request with newRequest and
-// passing it to doOK with r.client().
-func (r *Registry) doOK(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
+// send is a convenience method for making a request with newRequest and
+// passing it to sendRequest with r.client().
+func (r *Registry) send(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
 	req, err := r.newRequest(ctx, method, path, body)
 	if err != nil {
 		return nil, err
 	}
-	return doOK(r.client(), req)
+	return sendRequest(r.client(), req)
 }
 
 // makeAuthToken creates an Ollama auth token for the given private key.
@@ -960,3 +990,28 @@ func splitExtended(s string) (scheme, name, digest string) {
 	}
 	return scheme, s, digest
 }
+
+type writerPool struct {
+	size int64 // set by the caller
+
+	mu sync.Mutex
+	ws []*bufio.Writer
+}
+
+func (p *writerPool) get() *bufio.Writer {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	if len(p.ws) == 0 {
+		return bufio.NewWriterSize(nil, int(p.size))
+	}
+	w := p.ws[len(p.ws)-1]
+	p.ws = p.ws[:len(p.ws)-1]
+	return w
+}
+
+func (p *writerPool) put(w *bufio.Writer) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	w.Reset(nil)
+	p.ws = append(p.ws, w)
+}

+ 1 - 1
server/internal/client/ollama/registry_test.go

@@ -608,7 +608,7 @@ func TestInsecureSkipVerify(t *testing.T) {
 	url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name)
 	_, err := rc.Resolve(t.Context(), url)
 	if err == nil || !strings.Contains(err.Error(), "failed to verify") {
-		t.Errorf("err = %v; want cert verifiction failure", err)
+		t.Errorf("err = %v; want cert verification failure", err)
 	}
 
 	url = fmt.Sprintf("https+insecure://%s/%s", s.Listener.Addr(), name)

+ 7 - 3
server/internal/client/ollama/trace.go

@@ -13,9 +13,13 @@ type Trace struct {
 	// Update is called during [Registry.Push] and [Registry.Pull] to
 	// report the progress of blob uploads and downloads.
 	//
-	// It is called once at the beginning of the download with a zero n and
-	// then once per read operation with the number of bytes read so far,
-	// and an error if any.
+	// The n argument is the number of bytes transferred so far, and err is
+	// any error that has occurred. If n == 0, and err is nil, the download
+	// or upload has just started. If err is [ErrCached], the download or
+	// upload has been skipped because the blob is already present in the
+	// local cache or remote registry, respectively. Otherwise, if err is
+	// non-nil, the download or upload has failed. When l.Size == n, and
+	// err is nil, the download or upload has completed.
 	//
 	// A function assigned must be safe for concurrent use. The function is
 	// called synchronously and so should not block or take long to run.

+ 97 - 0
server/internal/registry/server.go

@@ -7,10 +7,14 @@ import (
 	"cmp"
 	"encoding/json"
 	"errors"
+	"fmt"
 	"io"
 	"log/slog"
 	"net/http"
+	"sync"
+	"time"
 
+	"github.com/ollama/ollama/server/internal/cache/blob"
 	"github.com/ollama/ollama/server/internal/client/ollama"
 )
 
@@ -109,6 +113,8 @@ func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
 		switch r.URL.Path {
 		case "/api/delete":
 			return false, s.handleDelete(rec, r)
+		case "/api/pull":
+			return false, s.handlePull(rec, r)
 		default:
 			if s.Fallback != nil {
 				s.Fallback.ServeHTTP(rec, r)
@@ -214,6 +220,97 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
 	return s.Prune()
 }
 
+type progressUpdateJSON struct {
+	Status    string      `json:"status"`
+	Digest    blob.Digest `json:"digest,omitempty,omitzero"`
+	Total     int64       `json:"total,omitempty,omitzero"`
+	Completed int64       `json:"completed,omitempty,omitzero"`
+}
+
+func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
+	if r.Method != "POST" {
+		return errMethodNotAllowed
+	}
+
+	p, err := decodeUserJSON[*params](r.Body)
+	if err != nil {
+		return err
+	}
+
+	maybeFlush := func() {
+		fl, _ := w.(http.Flusher)
+		if fl != nil {
+			fl.Flush()
+		}
+	}
+	defer maybeFlush()
+
+	var mu sync.Mutex
+	enc := json.NewEncoder(w)
+	enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
+
+	ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
+		Update: func(l *ollama.Layer, n int64, err error) {
+			mu.Lock()
+			defer mu.Unlock()
+
+			// TODO(bmizerany): coalesce these updates; writing per
+			// update is expensive
+			enc.Encode(progressUpdateJSON{
+				Digest:    l.Digest,
+				Status:    "pulling",
+				Total:     l.Size,
+				Completed: n,
+			})
+		},
+	})
+
+	done := make(chan error, 1)
+	go func() {
+		// TODO(bmizerany): continue to support non-streaming responses
+		done <- s.Client.Pull(ctx, p.model())
+	}()
+
+	func() {
+		t := time.NewTicker(100 * time.Millisecond)
+		defer t.Stop()
+		for {
+			select {
+			case <-t.C:
+				mu.Lock()
+				maybeFlush()
+				mu.Unlock()
+			case err := <-done:
+				if err != nil {
+					var status string
+					if errors.Is(err, ollama.ErrModelNotFound) {
+						status = fmt.Sprintf("error: model %q not found", p.model())
+						enc.Encode(progressUpdateJSON{Status: status})
+					} else {
+						status = fmt.Sprintf("error: %v", err)
+						enc.Encode(progressUpdateJSON{Status: status})
+					}
+					return
+				}
+
+				// These final updates are not strictly necessary, because they have
+				// already happened at this point. Our pull handler code used to do
+				// these steps after, not during, the pull, and they were slow, so we
+				// wanted to provide feedback to users what was happening. For now, we
+				// keep them to not jar users who are used to seeing them. We can phase
+				// them out with a new and nicer UX later. One without progress bars
+				// and digests that no one cares about.
+				enc.Encode(progressUpdateJSON{Status: "verifying layers"})
+				enc.Encode(progressUpdateJSON{Status: "writing manifest"})
+				enc.Encode(progressUpdateJSON{Status: "success"})
+				return
+			}
+		}
+	}()
+
+	return nil
+}
+
 func decodeUserJSON[T any](r io.Reader) (T, error) {
 	var v T
 	err := json.NewDecoder(r).Decode(&v)

+ 126 - 4
server/internal/registry/server_test.go

@@ -1,17 +1,27 @@
 package registry
 
 import (
+	"bytes"
+	"context"
 	"encoding/json"
+	"fmt"
+	"io"
+	"io/fs"
+	"net"
 	"net/http"
 	"net/http/httptest"
 	"os"
 	"regexp"
 	"strings"
+	"sync"
 	"testing"
 
 	"github.com/ollama/ollama/server/internal/cache/blob"
 	"github.com/ollama/ollama/server/internal/client/ollama"
 	"github.com/ollama/ollama/server/internal/testutil"
+	"golang.org/x/tools/txtar"
+
+	_ "embed"
 )
 
 type panicTransport struct{}
@@ -30,7 +40,7 @@ type bytesResetter interface {
 	Reset()
 }
 
-func newTestServer(t *testing.T) *Local {
+func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local {
 	t.Helper()
 	dir := t.TempDir()
 	err := os.CopyFS(dir, os.DirFS("testdata/models"))
@@ -41,10 +51,25 @@ func newTestServer(t *testing.T) *Local {
 	if err != nil {
 		t.Fatal(err)
 	}
+
+	client := panicOnRoundTrip
+	if upstreamRegistry != nil {
+		s := httptest.NewTLSServer(upstreamRegistry)
+		t.Cleanup(s.Close)
+		tr := s.Client().Transport.(*http.Transport).Clone()
+		tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
+			var d net.Dialer
+			return d.DialContext(ctx, "tcp", s.Listener.Addr().String())
+		}
+		client = &http.Client{Transport: tr}
+	}
+
 	rc := &ollama.Registry{
 		Cache:      c,
-		HTTPClient: panicOnRoundTrip,
+		HTTPClient: client,
+		Mask:       "example.com/library/_:latest",
 	}
+
 	l := &Local{
 		Client: rc,
 		Logger: testutil.Slogger(t),
@@ -85,7 +110,7 @@ func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) {
 func TestServerDelete(t *testing.T) {
 	check := testutil.Checker(t)
 
-	s := newTestServer(t)
+	s := newTestServer(t, nil)
 
 	_, err := s.Client.ResolveLocal("smol")
 	check(err)
@@ -127,8 +152,105 @@ func TestServerDelete(t *testing.T) {
 	}
 }
 
+//go:embed testdata/registry.txt
+var registryTXT []byte
+
+var registryFS = sync.OnceValue(func() fs.FS {
+	// Txtar gets hung up on \r\n line endings, so we need to convert them
+	// to \n when parsing the txtar on Windows.
+	data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
+	a := txtar.Parse(data)
+	fmt.Printf("%q\n", a.Comment)
+	fsys, err := txtar.FS(a)
+	if err != nil {
+		panic(err)
+	}
+	return fsys
+})
+
+func TestServerPull(t *testing.T) {
+	modelsHandler := http.FileServerFS(registryFS())
+	s := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
+		switch r.URL.Path {
+		case "/v2/library/BOOM/manifests/latest":
+			w.WriteHeader(999)
+			io.WriteString(w, `{"error": "boom"}`)
+		case "/v2/library/unknown/manifests/latest":
+			w.WriteHeader(404)
+			io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
+		default:
+			t.Logf("serving file: %s", r.URL.Path)
+			modelsHandler.ServeHTTP(w, r)
+		}
+	})
+
+	checkResponse := func(got *httptest.ResponseRecorder, wantlines string) {
+		t.Helper()
+
+		if got.Code != 200 {
+			t.Fatalf("Code = %d; want 200", got.Code)
+		}
+		gotlines := got.Body.String()
+		t.Logf("got:\n%s", gotlines)
+		for want := range strings.Lines(wantlines) {
+			want = strings.TrimSpace(want)
+			want, unwanted := strings.CutPrefix(want, "!")
+			want = strings.TrimSpace(want)
+			if !unwanted && !strings.Contains(gotlines, want) {
+				t.Fatalf("! missing %q in body", want)
+			}
+			if unwanted && strings.Contains(gotlines, want) {
+				t.Fatalf("! unexpected %q in body", want)
+			}
+		}
+	}
+
+	got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
+	checkResponse(got, `
+		{"status":"pulling manifest"}
+		{"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
+	`)
+
+	got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
+	checkResponse(got, `
+		{"status":"pulling manifest"}
+		{"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
+		{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
+		{"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
+		{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
+		{"status":"verifying layers"}
+		{"status":"writing manifest"}
+		{"status":"success"}
+	`)
+
+	got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
+	checkResponse(got, `
+		{"status":"pulling manifest"}
+		{"status":"error: model \"unknown\" not found"}
+	`)
+
+	got = s.send(t, "DELETE", "/api/pull", `{"model": "smol"}`)
+	checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")
+
+	got = s.send(t, "POST", "/api/pull", `!`)
+	checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")
+
+	got = s.send(t, "POST", "/api/pull", ``)
+	checkErrorResponse(t, got, 400, "bad_request", "empty request body")
+
+	got = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
+	checkResponse(got, `
+		{"status":"pulling manifest"}
+		{"status":"error: invalid or missing name: \"\""}
+
+		!verifying
+		!writing
+		!success
+	`)
+}
+
 func TestServerUnknownPath(t *testing.T) {
-	s := newTestServer(t)
+	s := newTestServer(t, nil)
 	got := s.send(t, "DELETE", "/api/unknown", `{}`)
 	checkErrorResponse(t, got, 404, "not_found", "not found")
 }

+ 0 - 0
server/internal/registry/testdata/models/manifests/registry.ollama.ai/library/smol/latest → server/internal/registry/testdata/models/manifests/example.com/library/smol/latest


+ 22 - 0
server/internal/registry/testdata/registry.txt

@@ -0,0 +1,22 @@
+-- v2/library/smol/manifests/latest --
+{
+  "schemaVersion": 2,
+  "mediaType": "application/vnd.docker.distribution.manifest.v2+json",
+  "config": {
+    "mediaType": "application/vnd.docker.container.image.v1+json",
+    "digest": "sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356",
+    "size": 3
+  },
+  "layers": [
+    {
+      "mediaType": "application/vnd.ollama.image.model",
+      "digest": "sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312",
+      "size": 5
+    }
+  ]
+}
+
+-- v2/library/smol/blobs/sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312 --
+GGUF
+-- v2/library/smol/blobs/sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356 --
+{}

+ 25 - 10
server/routes.go

@@ -42,6 +42,12 @@ import (
 	"github.com/ollama/ollama/version"
 )
 
+func experimentEnabled(name string) bool {
+	return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
+}
+
+var useClient2 = experimentEnabled("client2")
+
 var mode string = gin.DebugMode
 
 type Server struct {
@@ -1173,6 +1179,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
 	r.HEAD("/api/tags", s.ListHandler)
 	r.GET("/api/tags", s.ListHandler)
 	r.POST("/api/show", s.ShowHandler)
+	r.DELETE("/api/delete", s.DeleteHandler)
 
 	// Create
 	r.POST("/api/create", s.CreateHandler)
@@ -1194,16 +1201,19 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
 	r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
 	r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)
 
-	// wrap old with new
-	rs := &registry.Local{
-		Client:   rc,
-		Logger:   slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
-		Fallback: r,
+	if rc != nil {
+		// wrap old with new
+		rs := &registry.Local{
+			Client:   rc,
+			Logger:   slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
+			Fallback: r,
 
-		Prune: PruneLayers,
+			Prune: PruneLayers,
+		}
+		return rs, nil
 	}
 
-	return rs, nil
+	return r, nil
 }
 
 func Serve(ln net.Listener) error {
@@ -1258,15 +1268,20 @@ func Serve(ln net.Listener) error {
 
 	s := &Server{addr: ln.Addr()}
 
-	rc, err := ollama.DefaultRegistry()
-	if err != nil {
-		return err
+	var rc *ollama.Registry
+	if useClient2 {
+		var err error
+		rc, err = ollama.DefaultRegistry()
+		if err != nil {
+			return err
+		}
 	}
 
 	h, err := s.GenerateRoutes(rc)
 	if err != nil {
 		return err
 	}
+
 	http.Handle("/", h)
 
 	ctx, done := context.WithCancel(context.Background())