Pārlūkot izejas kodu

server/internal/client/ollama: handle extended names in client/ollama (#9454)

The extended name format is a superset of the name format that only the
client needs to know about, not the server or other dependents of the
name package, so move the split logic into the client package.

Also, take advantage of knowing about the extended name format to allow
the client to use the extended name format when unlinking to verify they
are unlinking the manifest with the content they intend.
Blake Mizerany 2 mēneši atpakaļ
vecāks
revīzija
ee048b76d4

+ 44 - 21
server/internal/client/ollama/registry.go

@@ -212,12 +212,16 @@ type Registry struct {
 	Mask string
 	Mask string
 }
 }
 
 
-func (r *Registry) completeName(name string) names.Name {
+func (r *Registry) parseName(name string) (names.Name, error) {
 	mask := defaultMask
 	mask := defaultMask
 	if r.Mask != "" {
 	if r.Mask != "" {
 		mask = names.Parse(r.Mask)
 		mask = names.Parse(r.Mask)
 	}
 	}
-	return names.Merge(names.Parse(name), mask)
+	n := names.Merge(names.Parse(name), mask)
+	if !n.IsFullyQualified() {
+		return names.Name{}, fmt.Errorf("%w: %q", ErrNameInvalid, name)
+	}
+	return n, nil
 }
 }
 
 
 // DefaultRegistry returns a new Registry configured from the environment. The
 // DefaultRegistry returns a new Registry configured from the environment. The
@@ -306,7 +310,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *
 
 
 	t := traceFromContext(ctx)
 	t := traceFromContext(ctx)
 
 
-	scheme, n, _, err := parseName(name, r.Mask)
+	scheme, n, _, err := r.parseNameExtended(name)
 	if err != nil {
 	if err != nil {
 		// This should never happen since ResolveLocal should have
 		// This should never happen since ResolveLocal should have
 		// already validated the name.
 		// already validated the name.
@@ -400,7 +404,7 @@ func canRetry(err error) bool {
 // typically slower than splitting the model up across layers, and is mostly
 // typically slower than splitting the model up across layers, and is mostly
 // utilized for layers of type equal to "application/vnd.ollama.image".
 // utilized for layers of type equal to "application/vnd.ollama.image".
 func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
 func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
-	scheme, n, _, err := parseName(name, r.Mask)
+	scheme, n, _, err := r.parseNameExtended(name)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -551,9 +555,9 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
 // Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
 // Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
 // before attempting to unlink the model.
 // before attempting to unlink the model.
 func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) {
 func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) {
-	n := r.completeName(name)
-	if !n.IsFullyQualified() {
-		return false, fmt.Errorf("%w: %q", ErrNameInvalid, name)
+	n, err := r.parseName(name)
+	if err != nil {
+		return false, err
 	}
 	}
 	return c.Unlink(n.String())
 	return c.Unlink(n.String())
 }
 }
@@ -626,10 +630,9 @@ type Layer struct {
 	Size      int64       `json:"size"`
 	Size      int64       `json:"size"`
 }
 }
 
 
-// ResolveLocal resolves a name to a Manifest in the local cache. The name is
-// parsed using [names.Split] but the scheme is ignored.
+// ResolveLocal resolves a name to a Manifest in the local cache.
 func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
 func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
-	_, n, d, err := parseName(name, r.Mask)
+	_, n, d, err := r.parseNameExtended(name)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -655,7 +658,7 @@ func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, erro
 
 
 // Resolve resolves a name to a Manifest in the remote registry.
 // Resolve resolves a name to a Manifest in the remote registry.
 func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
 func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
-	scheme, n, d, err := parseName(name, r.Mask)
+	scheme, n, d, err := r.parseNameExtended(name)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -859,7 +862,7 @@ var supportedSchemes = []string{
 
 
 var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", "))
 var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", "))
 
 
-// parseName parses and validates an extended name, returning the scheme, name,
+// parseNameExtended parses and validates an extended name, returning the scheme, name,
 // and digest.
 // and digest.
 //
 //
 // If the scheme is empty, scheme will be "https". If an unsupported scheme is
 // If the scheme is empty, scheme will be "https". If an unsupported scheme is
@@ -870,8 +873,8 @@ var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Jo
 //
 //
 // If the name is not, once merged with the mask, fully qualified,
 // If the name is not, once merged with the mask, fully qualified,
 // [ErrNameInvalid] wrapped with a display friendly message is returned.
 // [ErrNameInvalid] wrapped with a display friendly message is returned.
-func parseName(s string, mask string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
-	scheme, name, digest := names.Split(s)
+func (r *Registry) parseNameExtended(s string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
+	scheme, name, digest := splitExtended(s)
 	scheme = cmp.Or(scheme, "https")
 	scheme = cmp.Or(scheme, "https")
 	if !slices.Contains(supportedSchemes, scheme) {
 	if !slices.Contains(supportedSchemes, scheme) {
 		err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage)
 		err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage)
@@ -894,13 +897,33 @@ func parseName(s string, mask string) (scheme string, _ names.Name, _ blob.Diges
 		}
 		}
 	}
 	}
 
 
-	maskName := defaultMask
-	if mask != "" {
-		maskName = names.Parse(mask)
-	}
-	n := names.Merge(names.Parse(name), maskName)
-	if !n.IsFullyQualified() {
-		return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
+	n, err := r.parseName(name)
+	if err != nil {
+		return "", names.Name{}, blob.Digest{}, err
 	}
 	}
 	return scheme, n, d, nil
 	return scheme, n, d, nil
 }
 }
+
+// splitExtended splits an extended name string into its scheme, name, and digest
+// parts.
+//
+// Examples:
+//
+//	http://ollama.com/bmizerany/smol:latest@digest
+//	https://ollama.com/bmizerany/smol:latest
+//	ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
+//	model@digest
+//	@digest
+func splitExtended(s string) (scheme, name, digest string) {
+	i := strings.Index(s, "://")
+	if i >= 0 {
+		scheme = s[:i]
+		s = s[i+3:]
+	}
+	i = strings.LastIndex(s, "@")
+	if i >= 0 {
+		digest = s[i+1:]
+		s = s[:i]
+	}
+	return scheme, s, digest
+}

+ 92 - 14
server/internal/client/ollama/registry_test.go

@@ -2,6 +2,7 @@ package ollama
 
 
 import (
 import (
 	"bytes"
 	"bytes"
+	"cmp"
 	"context"
 	"context"
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
@@ -91,7 +92,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
 	}
 	}
 
 
 	link := func(name string, manifest string) {
 	link := func(name string, manifest string) {
-		_, n, _, err := parseName(name, r.Mask)
+		n, err := r.parseName(name)
 		if err != nil {
 		if err != nil {
 			panic(err)
 			panic(err)
 		}
 		}
@@ -709,25 +710,16 @@ func TestErrorUnmarshal(t *testing.T) {
 //
 //
 // It is only for testing error messages, not that all invalids and valids are
 // It is only for testing error messages, not that all invalids and valids are
 // covered. Those are in other tests for names.Name and blob.Digest.
 // covered. Those are in other tests for names.Name and blob.Digest.
-func TestParseNameErrors(t *testing.T) {
+func TestParseNameExtendedErrors(t *testing.T) {
 	cases := []struct {
 	cases := []struct {
 		name string
 		name string
 		err  error
 		err  error
 		want string
 		want string
-	}{
-		{"x", nil, ""},
-		{"x@", nil, ""},
-
-		{"", ErrNameInvalid, `invalid or missing name: ""`},
-		{"://", ErrNameInvalid, `invalid or missing name: "://"`},
-		{"x://", ErrNameInvalid, `unsupported scheme: "x": supported schemes are http, https, https+insecure`},
-
-		{"@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`},
-		{"x@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`},
-	}
+	}{}
 
 
+	var r Registry
 	for _, tt := range cases {
 	for _, tt := range cases {
-		_, _, _, err := parseName(tt.name, DefaultMask)
+		_, _, _, err := r.parseNameExtended(tt.name)
 		if !errors.Is(err, tt.err) {
 		if !errors.Is(err, tt.err) {
 			t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err)
 			t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err)
 		}
 		}
@@ -736,3 +728,89 @@ func TestParseNameErrors(t *testing.T) {
 		}
 		}
 	}
 	}
 }
 }
+
+func TestParseNameExtended(t *testing.T) {
+	cases := []struct {
+		in     string
+		scheme string
+		name   string
+		digest string
+		err    string
+	}{
+		{in: "http://m", scheme: "http", name: "m"},
+		{in: "https+insecure://m", scheme: "https+insecure", name: "m"},
+		{in: "http+insecure://m", err: "unsupported scheme"},
+
+		{in: "http://m@sha256:1111111111111111111111111111111111111111111111111111111111111111", scheme: "http", name: "m", digest: "sha256:1111111111111111111111111111111111111111111111111111111111111111"},
+
+		{in: "", err: "invalid or missing name"},
+		{in: "m", scheme: "https", name: "m"},
+		{in: "://", err: "invalid or missing name"},
+		{in: "@sha256:deadbeef", err: "invalid digest"},
+		{in: "@sha256:deadbeef@sha256:deadbeef", err: "invalid digest"},
+	}
+	for _, tt := range cases {
+		t.Run(tt.in, func(t *testing.T) {
+			var r Registry
+			scheme, n, digest, err := r.parseNameExtended(tt.in)
+			if err != nil {
+				if tt.err == "" {
+					t.Errorf("err = %v; want nil", err)
+				} else if !strings.Contains(err.Error(), tt.err) {
+					t.Errorf("err = %v; want %q", err, tt.err)
+				}
+			} else if tt.err != "" {
+				t.Errorf("err = nil; want %q", tt.err)
+			}
+			if err == nil && !n.IsFullyQualified() {
+				t.Errorf("name = %q; want fully qualified", n)
+			}
+
+			if scheme != tt.scheme {
+				t.Errorf("scheme = %q; want %q", scheme, tt.scheme)
+			}
+
+			// smoke-test name is superset of tt.name
+			if !strings.Contains(n.String(), tt.name) {
+				t.Errorf("name = %q; want %q", n, tt.name)
+			}
+
+			tt.digest = cmp.Or(tt.digest, (&blob.Digest{}).String())
+			if digest.String() != tt.digest {
+				t.Errorf("digest = %q; want %q", digest, tt.digest)
+			}
+		})
+	}
+}
+
+func TestUnlink(t *testing.T) {
+	t.Run("found by name", func(t *testing.T) {
+		rc, c := newClient(t, nil)
+
+		// confirm linked
+		_, err := rc.ResolveLocal(c, "single")
+		if err != nil {
+			t.Errorf("unexpected error: %v", err)
+		}
+
+		// unlink
+		_, err = rc.Unlink(c, "single")
+		testutil.Check(t, err)
+
+		// confirm unlinked
+		_, err = rc.ResolveLocal(c, "single")
+		if !errors.Is(err, fs.ErrNotExist) {
+			t.Errorf("err = %v; want fs.ErrNotExist", err)
+		}
+	})
+	t.Run("not found by name", func(t *testing.T) {
+		rc, c := newClient(t, nil)
+		ok, err := rc.Unlink(c, "manifestNotFound")
+		if err != nil {
+			t.Fatal(err)
+		}
+		if ok {
+			t.Error("expected not found")
+		}
+	})
+}