Browse Source

x/model: add support for @digest

Blake Mizerany 1 year ago
parent
commit
ff68227ca1
2 changed files with 150 additions and 12 deletions
  1. 84 11
      x/model/name.go
  2. 66 1
      x/model/name_test.go

+ 84 - 11
x/model/name.go

@@ -6,6 +6,7 @@ import (
 	"database/sql"
 	"database/sql/driver"
 	"errors"
+	"fmt"
 	"hash/maphash"
 	"io"
 	"iter"
@@ -13,6 +14,7 @@ import (
 	"slices"
 	"strings"
 	"sync"
+	"unicode"
 
 	"github.com/ollama/ollama/x/types/structs"
 )
@@ -105,7 +107,7 @@ type Name struct {
 // ParseName parses s into a Name. The input string must be a valid string
 // representation of a model name in the form:
 //
-//	<host>/<namespace>/<model>:<tag>+<build>
+//	<host>/<namespace>/<model>:<tag>+<build>@<digest-type>-<digest>
 //
 // The name part is required, all others are optional. If a part is missing,
 // it is left empty in the returned Name. If a part is invalid, the zero Name
@@ -120,6 +122,7 @@ type Name struct {
 //	"mistral:7b+x"
 //	"example.com/mike/mistral:latest+Q4_0"
 //	"example.com/bruce/mistral:latest"
+//	"example.com/mistral:7b+Q4_0@sha256-1234567890abcdef"
 //
 // Examples of invalid paths:
 //
@@ -141,10 +144,10 @@ func ParseName(s string) Name {
 		}
 		r.parts[kind] = part
 	}
-	if !r.Valid() {
-		return Name{}
+	if r.Valid() || r.Resolved() {
+		return r
 	}
-	return r
+	return Name{}
 }
 
 // Fill fills in the missing parts of dst with the parts of src.
@@ -238,15 +241,19 @@ var seps = [...]string{
 // WriteTo implements io.WriterTo. It writes the fullest possible display
 // string in form:
 //
-//	<host>/<namespace>/<model>:<tag>+<build>
+//	<host>/<namespace>/<model>:<tag>+<build>@<digest-type>-<digest>
 //
 // Missing parts and their separators are not written.
+//
+// The full digest is always prefixed with "@". That is, if [Name.Valid]
+// reports false and [Name.Resolved] reports true, then the string is
+// written as "@<digest-type>-<digest>".
 func (r Name) WriteTo(w io.Writer) (n int64, err error) {
 	for i := range r.parts {
 		if r.parts[i] == "" {
 			continue
 		}
-		if n > 0 {
+		if n > 0 || NamePart(i) == Digest {
 			n1, err := io.WriteString(w, seps[i-1])
 			n += int64(n1)
 			if err != nil {
@@ -382,6 +389,22 @@ func (r Name) CompleteNoBuild() bool {
 	return !slices.Contains(r.parts[:Build], "")
 }
 
+// Resolved reports whether the Name carries a digest, that is, whether its
+// digest part is non-empty.
+//
+// Being resolved is independent of being valid or complete: a Name may be
+// valid or complete and still have no digest.
+func (r Name) Resolved() bool {
+	return len(r.parts[Digest]) > 0
+}
+
+// Digest returns the digest part of the Name, or the empty string if the
+// Name has none.
+//
+// A non-empty result implies that [Name.Resolved] reports true.
+func (r Name) Digest() string {
+	digest := r.parts[Digest]
+	return digest
+}
+
 // EqualFold reports whether r and o are equivalent model names, ignoring
 // case.
 func (r Name) EqualFold(o Name) bool {
@@ -452,11 +475,29 @@ func Parts(s string) iter.Seq2[NamePart, string] {
 				yield(Invalid, "")
 				return
 			}
+
 			switch s[i] {
 			case '@':
 				switch state {
 				case Digest:
-					if !yieldValid(Digest, s[i+1:j]) {
+					part := s[i+1:]
+					if isValidDigest(part) {
+						if !yield(Digest, part) {
+							return
+						}
+						if i == 0 {
+							// The name is in the form "@digest".
+							// This is valid, so we want to skip
+							// the final validation for any other
+							// state.
+							return
+						}
+					} else {
+						yield(Invalid, "")
 						return
 					}
 					state, j, partLen = Build, i, 0
@@ -552,9 +593,41 @@ func isValidByte(kind NamePart, c byte) bool {
 	return false
 }
 
-func sumLens(a []string) (sum int) {
-	for _, n := range a {
-		sum += len(n)
+// isValidDigest reports whether s has the form "<digest-type>-<digest>",
+// where <digest-type> satisfies isValidDigestType and <digest> satisfies
+// isValidHex. s is split at the first "-", so any further "-" ends up in
+// <digest> and makes it invalid.
+//
+// It does not check that the digest is a valid hash for the given digest
+// type, nor does it restrict the digest type to a known set of types. That
+// is left up to users of this package.
+func isValidDigest(s string) bool {
+	typ, digest, ok := strings.Cut(s, "-")
+	return ok && isValidDigestType(typ) && isValidHex(digest)
+}
+
+// isValidDigestType reports whether s is a non-empty string made up only of
+// ASCII lowercase letters and ASCII digits, i.e. whether it matches
+// [a-z0-9]+.
+func isValidDigestType(s string) bool {
+	if len(s) == 0 {
+		return false
+	}
+	for _, r := range s {
+		// unicode.IsLower and unicode.IsDigit also accept non-ASCII
+		// runes (for example "é" or Arabic-Indic digits), so reject
+		// anything outside ASCII first. Invalid UTF-8 decodes to
+		// U+FFFD and is rejected by the same check.
+		if r > unicode.MaxASCII || (!unicode.IsLower(r) && !unicode.IsDigit(r)) {
+			return false
+		}
 	}
-	return
+	return true
+}
+
+// isValidHex reports whether s is a non-empty string consisting solely of
+// the bytes '0'-'9' and 'a'-'f'. Uppercase hex digits are rejected.
+func isValidHex(s string) bool {
+	if s == "" {
+		return false
+	}
+	for i := 0; i < len(s); i++ {
+		switch b := s[i]; {
+		case '0' <= b && b <= '9':
+			// decimal digit
+		case 'a' <= b && b <= 'f':
+			// lowercase hex letter
+		default:
+			return false
+		}
+	}
+	return true
 }

+ 66 - 1
x/model/name_test.go

@@ -49,8 +49,16 @@ var testNames = map[string]fields{
 	"example.com/ns/mistral:7b+Q4_0": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "Q4_0"},
 	"example.com/ns/mistral:7b+X":    {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"},
 
+	// invalid digest
+	"mistral:latest@invalid256-": {},
+	"mistral:latest@-123":        {},
+	"mistral:latest@!-123":       {},
+	"mistral:latest@1-!":         {},
+	"mistral:latest@":            {},
+
 	// resolved
-	"x@123": {model: "x", digest: "123"},
+	"x@sha123-1": {model: "x", digest: "sha123-1"},
+	"@sha456-2":  {digest: "sha456-2"},
 
 	// preserves case for build
 	"x+b": {model: "x", build: "b"},
@@ -109,6 +117,53 @@ func TestNamePartString(t *testing.T) {
 	}
 }
 
+// TestIsValidDigestType checks isValidDigestType against accepted and
+// rejected digest type strings. Each input runs as its own subtest named
+// after the input.
+func TestIsValidDigestType(t *testing.T) {
+	type testCase struct {
+		in   string
+		want bool
+	}
+	tests := []testCase{
+		// accepted
+		{in: "sha256", want: true},
+		{in: "blake2", want: true},
+
+		// rejected
+		{in: "", want: false},
+		{in: "-sha256", want: false},
+		{in: "sha256-", want: false},
+		{in: "Sha256", want: false},
+		{in: "sha256(", want: false},
+		{in: " sha256", want: false},
+	}
+	for _, tc := range tests {
+		t.Run(tc.in, func(t *testing.T) {
+			got := isValidDigestType(tc.in)
+			if got != tc.want {
+				t.Errorf("isValidDigestType(%q) = %v; want %v", tc.in, got, tc.want)
+			}
+		})
+	}
+}
+
+// TestIsValidDigest checks isValidDigest against well-formed and malformed
+// "<digest-type>-<digest>" strings. Each input runs as its own subtest
+// named after the input.
+func TestIsValidDigest(t *testing.T) {
+	type testCase struct {
+		in   string
+		want bool
+	}
+	tests := []testCase{
+		{in: "", want: false},
+		{in: "sha256-123", want: true},
+		{in: "sha256-1234567890abcdef", want: true},
+		{in: "sha256-1234567890abcdef1234567890abcdeffffffffffffffffffffffffffffffffffffffffff", want: true},
+		{in: "!sha256-123", want: false},
+		{in: "sha256-123!", want: false},
+		{in: "sha256-", want: false},
+		{in: "-123", want: false},
+	}
+	for _, tc := range tests {
+		t.Run(tc.in, func(t *testing.T) {
+			got := isValidDigest(tc.in)
+			if got != tc.want {
+				t.Errorf("isValidDigest(%q) = %v; want %v", tc.in, got, tc.want)
+			}
+		})
+	}
+}
+
 func TestParseName(t *testing.T) {
 	for baseName, want := range testNames {
 		for _, prefix := range []string{"", "https://", "http://"} {
@@ -117,6 +172,10 @@ func TestParseName(t *testing.T) {
 			s := prefix + baseName
 
 			t.Run(s, func(t *testing.T) {
+				for kind, part := range Parts(s) {
+					t.Logf("Part: %s: %q", kind, part)
+				}
+
 				name := ParseName(s)
 				got := fieldsFromName(name)
 				if got != want {
@@ -133,6 +192,12 @@ func TestParseName(t *testing.T) {
 				} else if !name.Valid() && name.DisplayModel() != "" {
 					t.Errorf("Valid() = false; Model() = %q; want empty name", got.model)
 				}
+
+				if name.Resolved() && name.Digest() == "" {
+					t.Errorf("Resolved() = true; Digest() = %q; want non-empty digest", got.digest)
+				} else if !name.Resolved() && name.Digest() != "" {
+					t.Errorf("Resolved() = false; Digest() = %q; want empty digest", got.digest)
+				}
 			})
 		}
 	}