瀏覽代碼

x/model: replace part fields with array of parts

This makes building strings and reasoning about parts easier.
Blake Mizerany 1 年之前
父節點
當前提交
14a6f85e9e
共有 2 個文件被更改,包括 136 次插入146 次删除
  1. 101 108
      x/model/name.go
  2. 35 38
      x/model/name_test.go

+ 101 - 108
x/model/name.go

@@ -1,9 +1,11 @@
 package model
 
 import (
+	"bytes"
 	"cmp"
 	"errors"
 	"hash/maphash"
+	"io"
 	"iter"
 	"log/slog"
 	"slices"
@@ -41,12 +43,15 @@ func (k NamePart) String() string {
 
 // Levels of concreteness
 const (
-	Invalid NamePart = iota
-	Host
+	Host NamePart = iota
 	Namespace
 	Model
 	Tag
 	Build
+
+	NumParts = Build + 1
+
+	Invalid = NamePart(-1)
 )
 
 // Name is an opaque reference to a model. It holds the parts of a model
@@ -84,13 +89,8 @@ const (
 //
 // To update parts of a Name with defaults, use [Fill].
 type Name struct {
-	_ structs.Incomparable
-
-	host      string
-	namespace string
-	model     string
-	tag       string
-	build     string
+	_     structs.Incomparable
+	parts [NumParts]string
 }
 
 // ParseName parses s into a Name. The input string must be a valid string
@@ -127,20 +127,10 @@ type Name struct {
 func ParseName(s string) Name {
 	var r Name
 	for kind, part := range NameParts(s) {
-		switch kind {
-		case Host:
-			r.host = part
-		case Namespace:
-			r.namespace = part
-		case Model:
-			r.model = part
-		case Tag:
-			r.tag = part
-		case Build:
-			r.build = part
-		case Invalid:
+		if kind == Invalid {
 			return Name{}
 		}
+		r.parts[kind] = part
 	}
 	if !r.Valid() {
 		return Name{}
@@ -152,18 +142,16 @@ func ParseName(s string) Name {
 //
 // The returned Name will only be valid if dst is valid.
 func Fill(dst, src Name) Name {
-	return Name{
-		model:     cmp.Or(dst.model, src.model),
-		host:      cmp.Or(dst.host, src.host),
-		namespace: cmp.Or(dst.namespace, src.namespace),
-		tag:       cmp.Or(dst.tag, src.tag),
-		build:     cmp.Or(dst.build, src.build),
+	var r Name
+	for i := range r.parts {
+		r.parts[i] = cmp.Or(dst.parts[i], src.parts[i])
 	}
+	return r
 }
 
 // WithBuild returns a copy of r with the build set to the given string.
 func (r Name) WithBuild(build string) Name {
-	r.build = build
+	r.parts[Build] = build
 	return r
 }
 
@@ -188,9 +176,15 @@ func (r Name) MapHash() uint64 {
 	return h.Sum64()
 }
 
+func (r Name) slice(from, to NamePart) Name {
+	var v Name
+	copy(v.parts[from:to+1], r.parts[from:to+1])
+	return v
+}
+
 // DisplayModel returns a display string composed of the model only.
 func (r Name) DisplayModel() string {
-	return r.model
+	return r.parts[Model]
 }
 
 // DisplayFullest returns the fullest possible display string in form:
@@ -202,12 +196,7 @@ func (r Name) DisplayModel() string {
 // It does not include the build part. For the fullest possible display
 // string with the build, use [Name.String].
 func (r Name) DisplayFullest() string {
-	return (Name{
-		host:      r.host,
-		namespace: r.namespace,
-		model:     r.model,
-		tag:       r.tag,
-	}).String()
+	return r.slice(Host, Tag).String()
 }
 
 // DisplayShort returns the fullest possible display string in form:
@@ -216,10 +205,7 @@ func (r Name) DisplayFullest() string {
 //
 // If any part is missing, it is omitted from the display string.
 func (r Name) DisplayShort() string {
-	return (Name{
-		model: r.model,
-		tag:   r.tag,
-	}).String()
+	return r.slice(Model, Tag).String()
 }
 
 // DisplayLong returns the fullest possible display string in form:
@@ -228,11 +214,36 @@ func (r Name) DisplayShort() string {
 //
 // If any part is missing, it is omitted from the display string.
 func (r Name) DisplayLong() string {
-	return (Name{
-		namespace: r.namespace,
-		model:     r.model,
-		tag:       r.tag,
-	}).String()
+	return r.slice(Namespace, Tag).String()
+}
+
+var seps = [...]string{
+	Host:      "/",
+	Namespace: "/",
+	Model:     ":",
+	Tag:       "+",
+	Build:     "",
+}
+
+func (r Name) WriteTo(w io.Writer) (n int64, err error) {
+	for i := range r.parts {
+		if r.parts[i] == "" {
+			continue
+		}
+		if n > 0 {
+			n1, err := io.WriteString(w, seps[i-1])
+			n += int64(n1)
+			if err != nil {
+				return n, err
+			}
+		}
+		n1, err := io.WriteString(w, r.parts[i])
+		n += int64(n1)
+		if err != nil {
+			return n, err
+		}
+	}
+	return n, nil
 }
 
 var builderPool = sync.Pool{
@@ -241,6 +252,9 @@ var builderPool = sync.Pool{
 	},
 }
 
+// TODO(bmizerany): Add WriteTo and use in String and MarshalText with
+// strings.Builder and bytes.Buffer, respectively.
+
 // String returns the fullest possible display string in form:
 //
 //	<host>/<namespace>/<model>:<tag>+<build>
@@ -251,33 +265,10 @@ var builderPool = sync.Pool{
 // [Name.DisplayFullest].
 func (r Name) String() string {
 	b := builderPool.Get().(*strings.Builder)
-	b.Reset()
 	defer builderPool.Put(b)
-	b.Grow(0 +
-		len(r.host) +
-		len(r.namespace) +
-		len(r.model) +
-		len(r.tag) +
-		len(r.build) +
-		4, // 4 possible separators
-	)
-	if r.host != "" {
-		b.WriteString(r.host)
-		b.WriteString("/")
-	}
-	if r.namespace != "" {
-		b.WriteString(r.namespace)
-		b.WriteString("/")
-	}
-	b.WriteString(r.model)
-	if r.tag != "" {
-		b.WriteString(":")
-		b.WriteString(r.tag)
-	}
-	if r.build != "" {
-		b.WriteString("+")
-		b.WriteString(r.build)
-	}
+	b.Reset()
+	b.Grow(50) // arbitrarily long enough for most names
+	_, _ = r.WriteTo(b)
 	return b.String()
 }
 
@@ -286,13 +277,11 @@ func (r Name) String() string {
 // returns a string that includes all parts of the Name, with missing parts
 // replaced with a question mark ("?").
 func (r Name) GoString() string {
-	return (Name{
-		host:      cmp.Or(r.host, "?"),
-		namespace: cmp.Or(r.namespace, "?"),
-		model:     cmp.Or(r.model, "?"),
-		tag:       cmp.Or(r.tag, "?"),
-		build:     cmp.Or(r.build, "?"),
-	}).String()
+	var v Name
+	for i := range r.parts {
+		v.parts[i] = cmp.Or(r.parts[i], "?")
+	}
+	return v.String()
 }
 
 // LogValue implements slog.Valuer.
@@ -300,18 +289,25 @@ func (r Name) LogValue() slog.Value {
 	return slog.StringValue(r.GoString())
 }
 
-// MarshalText implements encoding.TextMarshaler.
-func (r Name) MarshalText() ([]byte, error) {
-	// unsafeBytes is safe here because we gurantee that the string is
-	// never used after this function returns.
-	//
-	// TODO: We can remove this if https://github.com/golang/go/issues/62384
-	// lands.
-	return unsafeBytes(r.String()), nil
+var bufPool = sync.Pool{
+	New: func() interface{} {
+		return new(bytes.Buffer)
+	},
 }
 
-func unsafeBytes(s string) []byte {
-	return *(*[]byte)(unsafe.Pointer(&s))
+// MarshalText implements encoding.TextMarshaler.
+func (r Name) MarshalText() ([]byte, error) {
+	b := bufPool.Get().(*bytes.Buffer)
+	b.Reset()
+	b.Grow(50) // arbitrarily long enough for most names
+	defer bufPool.Put(b)
+	_, err := r.WriteTo(b)
+	if err != nil {
+		return nil, err
+	}
+	// TODO: We can remove this alloc if/when
+	// https://github.com/golang/go/issues/62384 lands.
+	return b.Bytes(), nil
 }
 
 // UnmarshalText implements encoding.TextUnmarshaler.
@@ -329,13 +325,13 @@ func unsafeString(b []byte) string {
 // Complete reports whether the Name is fully qualified. That is, it has a
 // host, namespace, model, tag, and build.
 func (r Name) Complete() bool {
-	return !slices.Contains(r.Parts(), "")
+	return !slices.Contains(r.parts[:], "")
 }
 
 // CompleteNoBuild is like [Name.Complete] but it does not require the
 // build part to be present.
 func (r Name) CompleteNoBuild() bool {
-	return !slices.Contains(r.Parts()[:4], "")
+	return !slices.Contains(r.parts[:Tag], "")
 }
 
 // EqualFold reports whether r and o are equivalent model names, ignoring
@@ -350,27 +346,23 @@ func (r Name) EqualFold(o Name) bool {
 //
 // For simple equality checks, use [Name.EqualFold].
 func (r Name) CompareFold(o Name) int {
-	return cmp.Or(
-		compareFold(r.host, o.host),
-		compareFold(r.namespace, o.namespace),
-		compareFold(r.model, o.model),
-		compareFold(r.tag, o.tag),
-		compareFold(r.build, o.build),
-	)
+	for i := range r.parts {
+		if n := compareFold(r.parts[i], o.parts[i]); n != 0 {
+			return n
+		}
+	}
+	return 0
 }
 
 func compareFold(a, b string) int {
 	// compare case-folded bytes over the common prefix; length breaks ties
-	if n := cmp.Compare(len(a), len(b)); n != 0 {
-		return n
-	}
 	for i := 0; i < len(a) && i < len(b); i++ {
 		ca, cb := downcase(a[i]), downcase(b[i])
 		if n := cmp.Compare(ca, cb); n != 0 {
 			return n
 		}
 	}
-	return 0
+	return cmp.Compare(len(a), len(b))
 }
 
 func downcase(c byte) byte {
@@ -387,13 +379,7 @@ func downcase(c byte) byte {
 //
 // The length of the returned slice is always 5.
 func (r Name) Parts() []string {
-	return []string{
-		r.host,
-		r.namespace,
-		r.model,
-		r.tag,
-		r.build,
-	}
+	return slices.Clone(r.parts[:])
 }
 
 // Parts returns a sequence of the parts of a Name string from most specific
@@ -492,7 +478,7 @@ func NameParts(s string) iter.Seq2[NamePart, string] {
 func (r Name) Valid() bool {
 	// Parts ensures we only have valid parts, so no need to validate
 	// them here, only check if we have a name or not.
-	return r.model != ""
+	return r.parts[Model] != ""
 }
 
 // isValidPart returns true if given part is valid ascii [a-zA-Z0-9_\.-]
@@ -520,3 +506,10 @@ func isValidByte(kind NamePart, c byte) bool {
 	}
 	return false
 }
+
+func sumLens(a []string) (sum int) {
+	for _, n := range a {
+		sum += len(n)
+	}
+	return
+}

+ 35 - 38
x/model/name_test.go

@@ -11,7 +11,21 @@ import (
 	"testing"
 )
 
-var testNames = map[string]Name{
+type fields struct {
+	host, namespace, model, tag, build string
+}
+
+func fieldsFromName(p Name) fields {
+	return fields{
+		host:      p.parts[Host],
+		namespace: p.parts[Namespace],
+		model:     p.parts[Model],
+		tag:       p.parts[Tag],
+		build:     p.parts[Build],
+	}
+}
+
+var testNames = map[string]fields{
 	"mistral:latest":                 {model: "mistral", tag: "latest"},
 	"mistral":                        {model: "mistral"},
 	"mistral:30B":                    {model: "mistral", tag: "30B"},
@@ -23,7 +37,7 @@ var testNames = map[string]Name{
 	"llama2":                         {model: "llama2"},
 	"user/model":                     {namespace: "user", model: "model"},
 	"example.com/ns/mistral:7b+Q4_0": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "Q4_0"},
-	"example.com/ns/mistral:7b+x":    {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"},
+	"example.com/ns/mistral:7b+X":    {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"},
 
 	// preserves case for build
 	"x+b": {model: "x", build: "b"},
@@ -73,7 +87,7 @@ func TestNameParts(t *testing.T) {
 }
 
 func TestNamePartString(t *testing.T) {
-	if g := NamePart(-1).String(); g != "Unknown" {
+	if g := NamePart(-2).String(); g != "Unknown" {
 		t.Errorf("Unknown part = %q; want %q", g, "Unknown")
 	}
 	for kind, name := range kindNames {
@@ -83,34 +97,6 @@ func TestNamePartString(t *testing.T) {
 	}
 }
 
-func TestPartTooLong(t *testing.T) {
-	for i := Host; i <= Build; i++ {
-		t.Run(i.String(), func(t *testing.T) {
-			var p Name
-			switch i {
-			case Host:
-				p.host = strings.Repeat("a", MaxNamePartLen+1)
-			case Namespace:
-				p.namespace = strings.Repeat("a", MaxNamePartLen+1)
-			case Model:
-				p.model = strings.Repeat("a", MaxNamePartLen+1)
-			case Tag:
-				p.tag = strings.Repeat("a", MaxNamePartLen+1)
-			case Build:
-				p.build = strings.Repeat("a", MaxNamePartLen+1)
-			}
-			s := strings.Trim(p.String(), "+:/")
-			if len(s) != MaxNamePartLen+1 {
-				t.Errorf("len(String()) = %d; want %d", len(s), MaxNamePartLen+1)
-				t.Logf("String() = %q", s)
-			}
-			if ParseName(p.String()).Valid() {
-				t.Errorf("Valid(%q) = true; want false", p)
-			}
-		})
-	}
-}
-
 func TestParseName(t *testing.T) {
 	for baseName, want := range testNames {
 		for _, prefix := range []string{"", "https://", "http://"} {
@@ -119,19 +105,20 @@ func TestParseName(t *testing.T) {
 			s := prefix + baseName
 
 			t.Run(s, func(t *testing.T) {
-				got := ParseName(s)
-				if !got.EqualFold(want) {
+				name := ParseName(s)
+				got := fieldsFromName(name)
+				if got != want {
 					t.Errorf("ParseName(%q) = %q; want %q", s, got, want)
 				}
 
 				// test round-trip
-				if !ParseName(got.String()).EqualFold(got) {
-					t.Errorf("String() = %s; want %s", got.String(), baseName)
+				if !ParseName(name.String()).EqualFold(name) {
+					t.Errorf("String() = %s; want %s", name.String(), baseName)
 				}
 
-				if got.Valid() && got.model == "" {
+				if name.Valid() && name.DisplayModel() == "" {
 					t.Errorf("Valid() = true; Model() = %q; want non-empty name", got.model)
-				} else if !got.Valid() && got.DisplayModel() != "" {
+				} else if !name.Valid() && name.DisplayModel() != "" {
 					t.Errorf("Valid() = false; Model() = %q; want empty name", got.model)
 				}
 			})
@@ -405,7 +392,7 @@ func TestNameTextMarshal(t *testing.T) {
 			t.Fatal("MarshalText() = 0; want non-zero")
 		}
 	})
-	if allocs > 1 {
+	if allocs > 0 {
 		// TODO: Update when/if this lands:
 		// https://github.com/golang/go/issues/62384
 		//
@@ -414,6 +401,16 @@ func TestNameTextMarshal(t *testing.T) {
 	}
 }
 
+func TestNameStringAllocs(t *testing.T) {
+	name := ParseName("example.com/ns/mistral:latest+Q4_0")
+	allocs := testing.AllocsPerRun(1000, func() {
+		keep(name.String())
+	})
+	if allocs > 1 {
+		t.Errorf("String allocs = %v; want 0", allocs)
+	}
+}
+
 func ExampleFill() {
 	defaults := ParseName("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0")
 	r := Fill(ParseName("mistral"), defaults)