Browse Source

x/model: limit part len, not entire len

Limiting the whole name length comes naturally with part name length
restrictions. This aligns with Docker's registry behavior.
Blake Mizerany 1 năm trước cách đây
mục cha
commit
7c7f56a7fb
3 tập tin đã thay đổi với 122 bổ sung70 xóa
  1. 2 8
      x/model/file.go
  2. 84 58
      x/model/name.go
  3. 36 4
      x/model/name_test.go

+ 2 - 8
x/model/file.go

@@ -1,14 +1,8 @@
 // Package model implements the File and Name types for working with and
 // representing Modelfiles and model Names.
 //
-// The Name type is designed for safety and correctness. It is an opaque
-// reference to a model, and holds the parts of a model, casing preserved,
-// but is not directly comparable with other Names since model names can be
-// represented with different caseing depending on the use case.
-//
-// Names should never be compared manually parsed. Instead, use the
-// [Name.EqualFold] method to compare two names in a case-insensitive
-// manner, and [ParseName] to create a Name from a string, safely.
+// The Name type should be used when working with model names, and the File
+// type should be used when working with Modelfiles.
 package model
 
 import (

+ 84 - 58
x/model/name.go

@@ -2,6 +2,7 @@ package model
 
 import (
 	"cmp"
+	"errors"
 	"hash/maphash"
 	"iter"
 	"slices"
@@ -10,21 +11,19 @@ import (
 	"github.com/ollama/ollama/x/types/structs"
 )
 
-const MaxNameLength = 255
+// Errors
+var (
+	// ErrInvalidName is not used by this package, but is exported so that
+	// other packages do not need to invent their own error type when they
+	// need to return an error for an invalid name.
+	ErrIncompleteName = errors.New("incomplete model name")
+)
 
-type NamePart int
+const MaxNamePartLen = 128
 
-// Levels of concreteness
-const (
-	Invalid NamePart = iota
-	Host
-	Namespace
-	Model
-	Tag
-	Build
-)
+type NamePartKind int
 
-var kindNames = map[NamePart]string{
+var kindNames = map[NamePartKind]string{
 	Invalid:   "Invalid",
 	Host:      "Host",
 	Namespace: "Namespace",
@@ -33,12 +32,36 @@ var kindNames = map[NamePart]string{
 	Build:     "Build",
 }
 
-// Name is an opaque reference to a model. It holds the parts of a model,
-// casing preserved, and provides methods for comparing and manipulating
-// them in a case-insensitive manner.
+func (k NamePartKind) String() string {
+	return cmp.Or(kindNames[k], "!(UNKNOWN PART KIND)")
+}
+
+// Levels of concreteness
+const (
+	Invalid NamePartKind = iota
+	Host
+	Namespace
+	Model
+	Tag
+	Build
+)
+
+// Name is an opaque reference to a model. It holds the parts of a model
+// with the case preserved, but is not directly comparable with other Names
+// since model names can be represented with different caseing depending on
+// the use case. For instance, "Mistral" and "mistral" are the same model
+// but each version may have come from different sources (e.g. copied from a
+// Web page, or from a file path).
 //
-// To create a Name, use [ParseName]. To compare two names, use
-// [Name.EqualFold]. To use a name as a key in a map, use [Name.MapHash].
+// Valid Names can ONLY be constructed by calling [ParseName].
+//
+// A Name is valid if and only if is have a valid Model part. The other parts
+// are optional.
+//
+// A Name is considered "complete" if it has all parts present. To check if a
+// Name is complete, use [Name.Complete].
+//
+// To compare two names in a case-insensitive manner, use [Name.EqualFold].
 //
 // The parts of a Name are:
 //
@@ -124,7 +147,7 @@ func ParseName(s string) Name {
 
 // Fill fills in the missing parts of dst with the parts of src.
 //
-// Use this for merging a fully qualified ref with a partial ref, such as
+// Use this for merging a fully qualified Name with a partial Name, such as
 // when filling in a missing parts with defaults.
 //
 // The returned Name will only be valid if dst is valid.
@@ -144,6 +167,23 @@ func (r Name) WithBuild(build string) Name {
 	return r
 }
 
+// Has reports whether the Name has the given part kind.
+func (r Name) Has(kind NamePartKind) bool {
+	switch kind {
+	case Host:
+		return r.host != ""
+	case Namespace:
+		return r.namespace != ""
+	case Model:
+		return r.model != ""
+	case Tag:
+		return r.tag != ""
+	case Build:
+		return r.build != ""
+	}
+	return false
+}
+
 var mapHashSeed = maphash.MakeSeed()
 
 // MapHash returns a case insensitive hash for use in maps and equality
@@ -165,9 +205,10 @@ func (r Name) MapHash() uint64 {
 	return h.Sum64()
 }
 
-// Format returns a string representation of the ref with the given
-// concreteness. If a part is missing, it is replaced with a loud
-// placeholder.
+func (r Name) DisplayModel() string {
+	return r.model
+}
+
 func (r Name) DisplayFull() string {
 	return (Name{
 		host:      cmp.Or(r.host, "!(MISSING DOMAIN)"),
@@ -178,27 +219,7 @@ func (r Name) DisplayFull() string {
 	}).String()
 }
 
-func (r Name) DisplayModel() string {
-	return r.model
-}
-
-func (r Name) Has(kind NamePart) bool {
-	switch kind {
-	case Host:
-		return r.host != ""
-	case Namespace:
-		return r.namespace != ""
-	case Model:
-		return r.model != ""
-	case Tag:
-		return r.tag != ""
-	case Build:
-		return r.build != ""
-	}
-	return false
-}
-
-// DisplayCompact returns a compact display string of the ref with only the
+// DisplayCompact returns a compact display string of the Name with only the
 // model and tag parts.
 func (r Name) DisplayCompact() string {
 	return (Name{
@@ -207,7 +228,7 @@ func (r Name) DisplayCompact() string {
 	}).String()
 }
 
-// DisplayShort returns a short display string of the ref with only the
+// DisplayShort returns a short display string of the Name with only the
 // model, tag, and build parts.
 func (r Name) DisplayShort() string {
 	return (Name{
@@ -217,7 +238,7 @@ func (r Name) DisplayShort() string {
 	}).String()
 }
 
-// DisplayLong returns a long display string of the ref including namespace,
+// DisplayLong returns a long display string of the Name including namespace,
 // model, tag, and build parts.
 func (r Name) DisplayLong() string {
 	return (Name{
@@ -228,7 +249,7 @@ func (r Name) DisplayLong() string {
 	}).String()
 }
 
-// String returns the fully qualified ref string.
+// String returns the fully qualified Name string.
 func (r Name) String() string {
 	var b strings.Builder
 	if r.host != "" {
@@ -251,7 +272,7 @@ func (r Name) String() string {
 	return b.String()
 }
 
-// Complete reports whether the ref is fully qualified. That is it has a
+// Complete reports whether the Name is fully qualified. That is it has a
 // domain, namespace, name, tag, and build.
 func (r Name) Complete() bool {
 	return r.Valid() && !slices.Contains(r.Parts(), "")
@@ -262,7 +283,7 @@ func (r Name) Complete() bool {
 // TODO(bmizerany): LogValue
 // TODO(bmizerany): driver.Value? (MarshalText etc should be enough)
 
-// Parts returns the parts of the ref in order of concreteness.
+// Parts returns the parts of the Name in order of concreteness.
 //
 // The length of the returned slice is always 5.
 func (r Name) Parts() []string {
@@ -287,7 +308,7 @@ func (r Name) EqualFold(o Name) bool {
 	return r.MapHash() == o.MapHash()
 }
 
-// Parts returns a sequence of the parts of a ref string from most specific
+// Parts returns a sequence of the parts of a Name string from most specific
 // to least specific.
 //
 // It normalizes the input string by removing "http://" and "https://" only.
@@ -295,8 +316,8 @@ func (r Name) EqualFold(o Name) bool {
 //
 // As a special case, question marks are ignored so they may be used as
 // placeholders for missing parts in string literals.
-func NameParts(s string) iter.Seq2[NamePart, string] {
-	return func(yield func(NamePart, string) bool) {
+func NameParts(s string) iter.Seq2[NamePartKind, string] {
+	return func(yield func(NamePartKind, string) bool) {
 		if strings.HasPrefix(s, "http://") {
 			s = s[len("http://"):]
 		}
@@ -304,11 +325,11 @@ func NameParts(s string) iter.Seq2[NamePart, string] {
 			s = s[len("https://"):]
 		}
 
-		if len(s) > MaxNameLength || len(s) == 0 {
+		if len(s) > MaxNamePartLen || len(s) == 0 {
 			return
 		}
 
-		yieldValid := func(kind NamePart, part string) bool {
+		yieldValid := func(kind NamePartKind, part string) bool {
 			if !isValidPart(kind, part) {
 				yield(Invalid, "")
 				return false
@@ -316,8 +337,13 @@ func NameParts(s string) iter.Seq2[NamePart, string] {
 			return yield(kind, part)
 		}
 
+		partLen := 0
 		state, j := Build, len(s)
 		for i := len(s) - 1; i >= 0; i-- {
+			if partLen++; partLen > MaxNamePartLen {
+				yield(Invalid, "")
+				return
+			}
 			switch s[i] {
 			case '+':
 				switch state {
@@ -325,7 +351,7 @@ func NameParts(s string) iter.Seq2[NamePart, string] {
 					if !yieldValid(Build, s[i+1:j]) {
 						return
 					}
-					state, j = Tag, i
+					state, j, partLen = Tag, i, 0
 				default:
 					yield(Invalid, "")
 					return
@@ -336,7 +362,7 @@ func NameParts(s string) iter.Seq2[NamePart, string] {
 					if !yieldValid(Tag, s[i+1:j]) {
 						return
 					}
-					state, j = Model, i
+					state, j, partLen = Model, i, 0
 				default:
 					yield(Invalid, "")
 					return
@@ -352,7 +378,7 @@ func NameParts(s string) iter.Seq2[NamePart, string] {
 					if !yieldValid(Namespace, s[i+1:j]) {
 						return
 					}
-					state, j = Host, i
+					state, j, partLen = Host, i, 0
 				default:
 					yield(Invalid, "")
 					return
@@ -373,7 +399,7 @@ func NameParts(s string) iter.Seq2[NamePart, string] {
 	}
 }
 
-// Valid returns true if the ref has a valid nick. To know if a ref is
+// Valid returns true if the Name has a valid nick. To know if a Name is
 // "complete", use Complete.
 func (r Name) Valid() bool {
 	// Parts ensures we only have valid parts, so no need to validate
@@ -382,7 +408,7 @@ func (r Name) Valid() bool {
 }
 
 // isValidPart returns true if given part is valid ascii [a-zA-Z0-9_\.-]
-func isValidPart(kind NamePart, s string) bool {
+func isValidPart(kind NamePartKind, s string) bool {
 	if s == "" {
 		return false
 	}
@@ -394,7 +420,7 @@ func isValidPart(kind NamePart, s string) bool {
 	return true
 }
 
-func isValidByte(kind NamePart, c byte) bool {
+func isValidByte(kind NamePartKind, c byte) bool {
 	if kind == Namespace && c == '.' {
 		return false
 	}

+ 36 - 4
x/model/name_test.go

@@ -52,8 +52,8 @@ var testNames = map[string]Name{
 	"file:///etc/passwd:latest":   {},
 	"file:///etc/passwd:latest+u": {},
 
-	strings.Repeat("a", MaxNameLength):   {model: strings.Repeat("a", MaxNameLength)},
-	strings.Repeat("a", MaxNameLength+1): {},
+	strings.Repeat("a", MaxNamePartLen):   {model: strings.Repeat("a", MaxNamePartLen)},
+	strings.Repeat("a", MaxNamePartLen+1): {},
 }
 
 func TestNameParts(t *testing.T) {
@@ -64,6 +64,34 @@ func TestNameParts(t *testing.T) {
 	}
 }
 
+func TestPartTooLong(t *testing.T) {
+	for i := Host; i <= Build; i++ {
+		t.Run(i.String(), func(t *testing.T) {
+			var p Name
+			switch i {
+			case Host:
+				p.host = strings.Repeat("a", MaxNamePartLen+1)
+			case Namespace:
+				p.namespace = strings.Repeat("a", MaxNamePartLen+1)
+			case Model:
+				p.model = strings.Repeat("a", MaxNamePartLen+1)
+			case Tag:
+				p.tag = strings.Repeat("a", MaxNamePartLen+1)
+			case Build:
+				p.build = strings.Repeat("a", MaxNamePartLen+1)
+			}
+			s := strings.Trim(p.String(), "+:/")
+			if len(s) != MaxNamePartLen+1 {
+				t.Errorf("len(String()) = %d; want %d", len(s), MaxNamePartLen+1)
+				t.Logf("String() = %q", s)
+			}
+			if ParseName(p.String()).Valid() {
+				t.Errorf("Valid(%q) = true; want false", p)
+			}
+		})
+	}
+}
+
 func TestParseName(t *testing.T) {
 	for baseName, want := range testNames {
 		for _, prefix := range []string{"", "https://", "http://"} {
@@ -210,7 +238,7 @@ func FuzzParseName(f *testing.F) {
 		}
 
 		for _, p := range r0.Parts() {
-			if len(p) > MaxNameLength {
+			if len(p) > MaxNamePartLen {
 				t.Errorf("part too long: %q", p)
 			}
 		}
@@ -261,11 +289,15 @@ func ExampleFill() {
 func ExampleName_MapHash() {
 	m := map[uint64]bool{}
 
+	// key 1
 	m[ParseName("mistral:latest+q4").MapHash()] = true
 	m[ParseName("miSTRal:latest+Q4").MapHash()] = true
 	m[ParseName("mistral:LATest+Q4").MapHash()] = true
 
+	// key 2
+	m[ParseName("mistral:LATest").MapHash()] = true
+
 	fmt.Println(len(m))
 	// Output:
-	// 1
+	// 2
 }