浏览代码

types/model: init with Name and Digest types (#3541)

Blake Mizerany 1 年之前
父节点
当前提交
6a1de23175

+ 120 - 0
types/model/digest.go

@@ -0,0 +1,120 @@
+package model
+
+import (
+	"database/sql"
+	"database/sql/driver"
+	"errors"
+	"fmt"
+	"log/slog"
+	"strings"
+	"unicode"
+)
+
+// Digest represents a digest of a model Manifest. It is a comparable value
+// type and is immutable.
+//
+// The zero Digest is not a valid digest.
+type Digest struct {
+	s string
+}
+
+// Type returns the digest type of the digest.
+//
+// Example:
+//
+//	ParseDigest("sha256-1234").Type() // returns "sha256"
+func (d Digest) Type() string {
+	typ, _, _ := strings.Cut(d.s, "-")
+	return typ
+}
+
+// String returns the digest in the form of "<digest-type>-<digest>", or the
+// empty string if the digest is invalid.
+func (d Digest) String() string { return d.s }
+
+// IsValid returns true if the digest is valid (not zero).
+//
+// A valid digest may be created only by ParseDigest, or
+// ParseName(name).Digest().
+func (d Digest) IsValid() bool { return d.s != "" }
+
+// MarshalText implements encoding.TextMarshaler.
+func (d Digest) MarshalText() ([]byte, error) {
+	return []byte(d.String()), nil
+}
+
+// UnmarshalText implements encoding.TextUnmarshaler.
+func (d *Digest) UnmarshalText(text []byte) error {
+	if d.IsValid() {
+		return errors.New("model.Digest: illegal UnmarshalText on valid Digest")
+	}
+	*d = ParseDigest(string(text))
+	return nil
+}
+
+// LogValue implements slog.LogValuer.
+func (d Digest) LogValue() slog.Value {
+	return slog.StringValue(d.String())
+}
+
+var (
+	_ driver.Valuer  = Digest{}
+	_ sql.Scanner    = (*Digest)(nil)
+	_ slog.LogValuer = Digest{}
+)
+
+// Scan implements the sql.Scanner interface.
+func (d *Digest) Scan(src any) error {
+	if d.IsValid() {
+		return errors.New("model.Digest: illegal Scan on valid Digest")
+	}
+	switch v := src.(type) {
+	case string:
+		*d = ParseDigest(v)
+		return nil
+	case []byte:
+		*d = ParseDigest(string(v))
+		return nil
+	}
+	return fmt.Errorf("model.Digest: invalid Scan source %T", src)
+}
+
+// Value implements the driver.Valuer interface.
+func (d Digest) Value() (driver.Value, error) {
+	return d.String(), nil
+}
+
+// ParseDigest parses a string in the form of "<digest-type>-<digest>" into a
+// Digest.
+func ParseDigest(s string) Digest {
+	typ, digest, ok := strings.Cut(s, "-")
+	if ok && isValidDigestType(typ) && isValidHex(digest) {
+		return Digest{s: s}
+	}
+	return Digest{}
+}
+
+func isValidDigestType(s string) bool {
+	if len(s) == 0 {
+		return false
+	}
+	for _, r := range s {
+		if !unicode.IsLower(r) && !unicode.IsDigit(r) {
+			return false
+		}
+	}
+	return true
+}
+
+func isValidHex(s string) bool {
+	if len(s) == 0 {
+		return false
+	}
+	for i := range s {
+		c := s[i]
+		if c < '0' || c > '9' && c < 'a' || c > 'f' {
+			return false
+		}
+	}
+	return true
+}

+ 46 - 0
types/model/digest_test.go

@@ -0,0 +1,46 @@
+package model
+
+import "testing"
+
+var testDigests = map[string]Digest{
+	"":                 {},
+	"sha256-1234":      {s: "sha256-1234"},
+	"sha256-5678":      {s: "sha256-5678"},
+	"blake2-9abc":      {s: "blake2-9abc"},
+	"-1234":            {},
+	"sha256-":          {},
+	"sha256-1234-5678": {},
+	"sha256-P":         {}, //         invalid  hex
+	"sha256-1234P":     {},
+	"---":              {},
+}
+
+func TestDigestParse(t *testing.T) {
+	// Test cases.
+	for s, want := range testDigests {
+		got := ParseDigest(s)
+		t.Logf("ParseDigest(%q) = %#v", s, got)
+		if got != want {
+			t.Errorf("ParseDigest(%q) = %q; want %q", s, got, want)
+		}
+	}
+}
+
+func TestDigestString(t *testing.T) {
+	// Test cases.
+	for s, d := range testDigests {
+		want := s
+		if !d.IsValid() {
+			want = ""
+		}
+		got := d.String()
+		if got != want {
+			t.Errorf("ParseDigest(%q).String() = %q; want %q", s, got, want)
+		}
+
+		got = ParseDigest(s).String()
+		if got != want {
+			t.Errorf("roundtrip ParseDigest(%q).String() = %q; want %q", s, got, want)
+		}
+	}
+}

+ 581 - 0
types/model/name.go

@@ -0,0 +1,581 @@
+package model
+
+import (
+	"cmp"
+	"errors"
+	"hash/maphash"
+	"io"
+	"log/slog"
+	"slices"
+	"strings"
+	"sync"
+
+	"github.com/ollama/ollama/types/structs"
+)
+
+// Errors
+var (
+	// ErrInvalidName, ErrIncompleteName, and ErrInvalidDigest are not
+	// used by this package, but are exported so that other packages can
+	// use them, instead of defining their own errors for them.
+	ErrInvalidName    = errors.New("invalid model name")
+	ErrIncompleteName = errors.New("incomplete model name")
+	ErrInvalidDigest  = errors.New("invalid digest")
+)
+
+// Defaults
+const (
+	// DefaultMask is the default mask used by [Name.DisplayShortest].
+	DefaultMask = "registry.ollama.ai/library/_:latest"
+
+	// DefaultFill is the default fill used by [ParseName].
+	DefaultFill = "registry.ollama.ai/library/_:latest"
+)
+
+const MaxNamePartLen = 128
+
+type PartKind int
+
+// Levels of concreteness
+const (
+	// Each value aligns with its index in the Name.parts array.
+
+	PartHost PartKind = iota
+	PartNamespace
+	PartModel
+	PartTag
+	PartBuild
+	PartDigest
+
+	// Invalid is a special part that is used to indicate that a part is
+	// invalid. It is not a valid part of a Name.
+	//
+	// It should be kept as the last part in the list.
+	PartInvalid
+)
+
+var kindNames = map[PartKind]string{
+	PartHost:      "Host",
+	PartNamespace: "Namespace",
+	PartModel:     "Name",
+	PartTag:       "Tag",
+	PartBuild:     "Build",
+	PartDigest:    "Digest",
+	PartInvalid:   "Invalid",
+}
+
+func (k PartKind) String() string {
+	return cmp.Or(kindNames[k], "Unknown")
+}
+
+// Name is an opaque reference to a model. It holds the parts of a model
+// with the case preserved, but is not directly comparable with other Names
+// since model names can be represented with different casing depending on
+// the use case. For instance, "Mistral" and "mistral" are the same model
+// but each version may have come from different sources (e.g. copied from a
+// Web page, or from a file path).
+//
+// Valid Names can ONLY be constructed by calling [ParseName].
+//
+// A Name is valid if and only if it has a valid Model part. The other parts
+// are optional.
+//
+// A Name is considered "complete" if it has all parts present. To check if a
+// Name is complete, use [Name.IsComplete].
+//
+// To compare two names in a case-insensitive manner, use [Name.EqualFold].
+//
+// The parts of a Name are:
+//
+//   - Host: the domain of the model (optional)
+//   - Namespace: the namespace of the model (optional)
+//   - Model: the name of the model (required)
+//   - Tag: the tag of the model (optional)
+//   - Build: the build of the model; usually the quantization or "file type" (optional)
+//
+// The parts can be obtained in their original form by calling [Name.Parts].
+//
+// To check if a Name has at minimum a valid model part, use [Name.IsValid].
+//
+// To make a Name by filling in missing parts from another Name, use [Fill].
+type Name struct {
+	_     structs.Incomparable
+	parts [6]string // host, namespace, model, tag, build, digest
+
+	// TODO(bmizerany): track offsets and hold s (raw string) here? We
+	// could pack the offsets all into a single uint64 since the first
+	// parts take less bits since their max offset is less than the max
+	// offset of the next part. This would save a ton of bytes per Name
+	// and mean zero allocations for String.
+}
+
+// ParseNameFill parses s into a Name, and returns the result of filling it with
+// defaults. The input string must be a valid string
+// representation of a model name in the form:
+//
+//	[host/][namespace/]<model>[:tag][+build][@<digest-type>-<digest>]
+//
+// The model part is required unless a digest is present (for example
+// "@sha256-1234"); all other parts are optional. If a part is missing and
+// defaults does not supply it, it is left empty in the returned Name. If a
+// part is invalid, the zero Name is returned.
+//
+// The case of every part, including the build, is preserved.
+//
+// Examples of valid paths:
+//
+//	"example.com/library/mistral:7b+x"
+//	"example.com/eva/mistral:7b+Q4_0"
+//	"mistral:7b+x"
+//	"example.com/mike/mistral:latest+Q4_0"
+//	"example.com/bruce/mistral:latest"
+//	"example.com/pdevine/thisisfine:7b+Q4_0@sha256-1234567890abcdef"
+//
+// Examples of invalid paths:
+//
+//	"example.com/mistral:7b+"
+//	"example.com/mistral:7b+Q4_0+"
+//	"x/y/z/z:8n+I"
+//	""
+//
+// It returns the zero value if any part is invalid.
+//
+// As a rule of thumb, a valid name is one that can be round-tripped with
+// the [Name.String] method. That means ("x+") is invalid because
+// [Name.String] will not print a "+" if the build is empty.
+//
+// For more about filling in missing parts, see [Fill].
+func ParseNameFill(s, defaults string) Name {
+	var r Name
+	parts(s)(func(kind PartKind, part string) bool {
+		if kind == PartInvalid {
+			r = Name{}
+			return false
+		}
+		if kind == PartDigest && !ParseDigest(part).IsValid() {
+			r = Name{}
+			return false
+		}
+		r.parts[kind] = part
+		return true
+	})
+	if r.IsValid() || r.IsResolved() {
+		if defaults == "" {
+			return r
+		}
+		return Fill(r, ParseNameFill(defaults, ""))
+	}
+	return Name{}
+}
+
+// ParseName is equal to ParseNameFill(s, DefaultFill).
+func ParseName(s string) Name {
+	return ParseNameFill(s, DefaultFill)
+}
+
+// MustParseNameFill is like [ParseNameFill], but it panics if the resulting
+// Name does not have a valid model part.
+//
+// The missing parts of s are filled from defaults exactly as ParseNameFill
+// does; pass an empty defaults to disable filling.
+func MustParseNameFill(s, defaults string) Name {
+	// Forward defaults so the result matches ParseNameFill(s, defaults).
+	r := ParseNameFill(s, defaults)
+	if !r.IsValid() {
+		panic("model.MustParseNameFill: invalid name: " + s)
+	}
+	return r
+}
+
+// Fill fills in the missing parts of dst with the parts of src.
+//
+// Parts already present in dst are kept; only empty parts are taken from
+// src. The returned Name is valid if either dst or src has a model part.
+func Fill(dst, src Name) Name {
+	var r Name
+	for i := range r.parts {
+		r.parts[i] = cmp.Or(dst.parts[i], src.parts[i])
+	}
+	return r
+}
+
+// WithBuild returns a copy of r with the build set to the given string.
+func (r Name) WithBuild(build string) Name {
+	r.parts[PartBuild] = build
+	return r
+}
+
+func (r Name) WithDigest(digest Digest) Name {
+	r.parts[PartDigest] = digest.String()
+	return r
+}
+
+var mapHashSeed = maphash.MakeSeed()
+
+// MapHash returns a case-insensitive hash of r for use as a map key. Names
+// that differ only in ASCII letter case produce the same hash. The hash is
+// seeded from mapHashSeed, which is created with maphash.MakeSeed, so values
+// are only meaningful within the current process and must not be persisted.
+// For a convenient way to compare names, use [Name.EqualFold].
+//
+// Each part is followed by a NUL separator so that names whose parts merely
+// concatenate to the same text (for example "ab:c" and "a:bc") do not feed
+// identical bytes to the hash.
+//
+//nolint:errcheck
+func (r Name) MapHash() uint64 {
+	var h maphash.Hash
+	h.SetSeed(mapHashSeed)
+	for _, part := range r.parts {
+		for i := 0; i < len(part); i++ {
+			// Fold ASCII upper case to lower case so the hash ignores case.
+			c := part[i]
+			if c >= 'A' && c <= 'Z' {
+				c = c - 'A' + 'a'
+			}
+			h.WriteByte(c)
+		}
+		// Mark the part boundary; NUL is rejected by isValidByteFor, so it
+		// cannot occur inside a parsed part.
+		h.WriteByte(0)
+	}
+	return h.Sum64()
+}
+
+func (r Name) slice(from, to PartKind) Name {
+	var v Name
+	copy(v.parts[from:to+1], r.parts[from:to+1])
+	return v
+}
+
+// DisplayShortest returns the shortest possible display string in form:
+//
+//	[host/][<namespace>/]<model>[:<tag>]
+//
+// The host is omitted if it is the same as the mask host.
+// The namespace is omitted if both the host and the namespace are the same
+// as those of the mask.
+// The tag is omitted if it is the same as the mask tag.
+// All comparisons ignore case.
+func (r Name) DisplayShortest(mask string) string {
+	mask = cmp.Or(mask, DefaultMask)
+	d := ParseName(mask)
+	if !d.IsValid() {
+		panic("mask is an invalid Name")
+	}
+	equalSlice := func(form, to PartKind) bool {
+		return r.slice(form, to).EqualFold(d.slice(form, to))
+	}
+	if equalSlice(PartHost, PartNamespace) {
+		r.parts[PartNamespace] = ""
+	}
+	if equalSlice(PartHost, PartHost) {
+		r.parts[PartHost] = ""
+	}
+	if equalSlice(PartTag, PartTag) {
+		r.parts[PartTag] = ""
+	}
+	return r.slice(PartHost, PartTag).String()
+}
+
+// DisplayLong returns the fullest possible display string in form:
+//
+//	<namespace>/<model>:<tag>
+//
+// If any part is missing, it is omitted from the display string.
+func (r Name) DisplayLong() string {
+	return r.slice(PartNamespace, PartTag).String()
+}
+
+var seps = [...]string{
+	PartHost:      "/",
+	PartNamespace: "/",
+	PartModel:     ":",
+	PartTag:       "+",
+	PartBuild:     "@",
+	PartDigest:    "",
+}
+
+// writeTo writes the fullest possible display string of r to w, in the
+// form:
+//
+//	<host>/<namespace>/<model>:<tag>+<build>@<digest-type>-<digest>
+//
+// Missing parts and their separators are not written.
+//
+// The full digest is always prefixed with "@". That is if [Name.IsValid]
+// reports false and [Name.IsResolved] reports true, then the string is
+// returned as "@<digest-type>-<digest>".
+func (r Name) writeTo(w io.StringWriter) error {
+	var partsWritten int
+	for i := range r.parts {
+		if r.parts[i] == "" {
+			continue
+		}
+		if partsWritten > 0 || i == int(PartDigest) {
+			if _, err := w.WriteString(seps[i-1]); err != nil {
+				return err
+			}
+		}
+		if _, err := w.WriteString(r.parts[i]); err != nil {
+			return err
+		}
+		partsWritten++
+	}
+	return nil
+}
+
+var builderPool = sync.Pool{
+	New: func() interface{} {
+		return &strings.Builder{}
+	},
+}
+
+// String returns the fullest possible display string in form:
+//
+//	<host>/<namespace>/<model>:<tag>+<build>@<digest-type>-<digest>
+//
+// If any part is missing, it is omitted from the display string.
+//
+// For a display string limited to the namespace, model, and tag, use
+// [Name.DisplayLong].
+func (r Name) String() string {
+	b := builderPool.Get().(*strings.Builder)
+	defer builderPool.Put(b)
+	b.Reset()
+	b.Grow(50) // arbitrarily long enough for most names
+	_ = r.writeTo(b)
+	return b.String()
+}
+
+// GoString implements fmt.GoStringer. It returns a string suitable for
+// debugging and logging. It is similar to [Name.String] but it always
+// returns a string that includes all parts of the Name, with missing parts
+// replaced with a ("?").
+func (r Name) GoString() string {
+	for i := range r.parts {
+		r.parts[i] = cmp.Or(r.parts[i], "?")
+	}
+	return r.String()
+}
+
+// LogValue implements slog.LogValuer.
+func (r Name) LogValue() slog.Value {
+	return slog.StringValue(r.GoString())
+}
+
+// IsComplete reports whether the Name is fully qualified. That is it has a
+// domain, namespace, name, tag, and build.
+func (r Name) IsComplete() bool {
+	return !slices.Contains(r.parts[:PartDigest], "")
+}
+
+// IsCompleteNoBuild is like [Name.IsComplete] but it does not require the
+// build part to be present.
+func (r Name) IsCompleteNoBuild() bool {
+	return !slices.Contains(r.parts[:PartBuild], "")
+}
+
+// IsResolved reports true if the Name has a valid digest.
+//
+// It is possible to have a valid Name, or a complete Name that is not
+// resolved.
+func (r Name) IsResolved() bool {
+	return r.Digest().IsValid()
+}
+
+// Digest returns the digest part of the Name, if any.
+//
+// If the returned Digest is valid (non-zero), then [Name.IsResolved] reports
+// true.
+func (r Name) Digest() Digest {
+	// This was already validated by ParseName, so we can just return it.
+	return Digest{r.parts[PartDigest]}
+}
+
+// EqualFold reports whether r and o are equivalent model names, ignoring
+// case.
+func (r Name) EqualFold(o Name) bool {
+	return r.CompareFold(o) == 0
+}
+
+// CompareFold performs a case-insensitive cmp.Compare on r and o.
+//
+// This can be used with [slices.SortFunc].
+//
+// For simple equality checks, use [Name.EqualFold].
+func (r Name) CompareFold(o Name) int {
+	return slices.CompareFunc(r.parts[:], o.parts[:], compareFold)
+}
+
+func compareFold(a, b string) int {
+	return slices.CompareFunc([]rune(a), []rune(b), func(a, b rune) int {
+		return cmp.Compare(downcase(a), downcase(b))
+	})
+}
+
+func downcase(r rune) rune {
+	if r >= 'A' && r <= 'Z' {
+		return r - 'A' + 'a'
+	}
+	return r
+}
+
+// TODO(bmizerany): driver.Value? (MarshalText etc should be enough)
+
+// Parts returns the parts of the Name in order of concreteness.
+//
+// The length of the returned slice is always 6, one entry per part
+// including the digest.
+func (r Name) Parts() []string {
+	return slices.Clone(r.parts[:])
+}
+
+// iter_Seq2 is a iter.Seq2 defined here to avoid the current build
+// restrictions in the go1.22 iter package requiring the
+// goexperiment.rangefunc tag to be set via the GOEXPERIMENT=rangefunc flag,
+// which we are not yet ready to support.
+//
+// Once we are ready to support rangefunc, this can be removed and replaced
+// with the iter.Seq2 type.
+type iter_Seq2[A, B any] func(func(A, B) bool)
+
+// parts returns a sequence of the parts found in the name string s, from
+// most specific (digest) to least specific (host).
+//
+// It normalizes the input string by removing "http://" and "https://" only.
+// No other normalizations are performed.
+func parts(s string) iter_Seq2[PartKind, string] {
+	return func(yield func(PartKind, string) bool) {
+		//nolint:gosimple
+		if strings.HasPrefix(s, "http://") {
+			s = s[len("http://"):]
+		}
+		//nolint:gosimple
+		if strings.HasPrefix(s, "https://") {
+			s = s[len("https://"):]
+		}
+
+		if len(s) > MaxNamePartLen || len(s) == 0 {
+			return
+		}
+
+		yieldValid := func(kind PartKind, part string) bool {
+			if !isValidPart(kind, part) {
+				yield(PartInvalid, "")
+				return false
+			}
+			return yield(kind, part)
+		}
+
+		numConsecutiveDots := 0
+		partLen := 0
+		state, j := PartDigest, len(s)
+		for i := len(s) - 1; i >= 0; i-- {
+			if partLen++; partLen > MaxNamePartLen {
+				// catch a part that is too long early, so
+				// we don't keep spinning on it, waiting for
+				// an isInValidPart check which would scan
+				// over it again.
+				yield(PartInvalid, "")
+				return
+			}
+
+			switch s[i] {
+			case '@':
+				switch state {
+				case PartDigest:
+					if !yieldValid(PartDigest, s[i+1:j]) {
+						return
+					}
+					if i == 0 {
+						// This is the form
+						// "@<digest>" which is valid.
+						//
+						// We're done.
+						return
+					}
+					state, j, partLen = PartBuild, i, 0
+				default:
+					yield(PartInvalid, "")
+					return
+				}
+			case '+':
+				switch state {
+				case PartBuild, PartDigest:
+					if !yieldValid(PartBuild, s[i+1:j]) {
+						return
+					}
+					state, j, partLen = PartTag, i, 0
+				default:
+					yield(PartInvalid, "")
+					return
+				}
+			case ':':
+				switch state {
+				case PartTag, PartBuild, PartDigest:
+					if !yieldValid(PartTag, s[i+1:j]) {
+						return
+					}
+					state, j, partLen = PartModel, i, 0
+				default:
+					yield(PartInvalid, "")
+					return
+				}
+			case '/':
+				switch state {
+				case PartModel, PartTag, PartBuild, PartDigest:
+					if !yieldValid(PartModel, s[i+1:j]) {
+						return
+					}
+					state, j = PartNamespace, i
+				case PartNamespace:
+					if !yieldValid(PartNamespace, s[i+1:j]) {
+						return
+					}
+					state, j, partLen = PartHost, i, 0
+				default:
+					yield(PartInvalid, "")
+					return
+				}
+			default:
+				if s[i] == '.' {
+					if numConsecutiveDots++; numConsecutiveDots > 1 {
+						yield(PartInvalid, "")
+						return
+					}
+				} else {
+					numConsecutiveDots = 0
+				}
+				if !isValidByteFor(state, s[i]) {
+					yield(PartInvalid, "")
+					return
+				}
+			}
+		}
+
+		if state <= PartNamespace {
+			yieldValid(state, s[:j])
+		} else {
+			yieldValid(PartModel, s[:j])
+		}
+	}
+}
+
+func (r Name) IsZero() bool {
+	return r.parts == [6]string{}
+}
+
+// IsValid reports if a model has at minimum a valid model part.
+func (r Name) IsValid() bool {
+	// Parts ensures we only have valid parts, so no need to validate
+	// them here, only check if we have a name or not.
+	return r.parts[PartModel] != ""
+}
+
+// isValidPart reports if s contains all valid characters for the given
+// part kind.
+func isValidPart(kind PartKind, s string) bool {
+	if s == "" {
+		return false
+	}
+	for _, c := range []byte(s) {
+		if !isValidByteFor(kind, c) {
+			return false
+		}
+	}
+	return true
+}
+
+func isValidByteFor(kind PartKind, c byte) bool {
+	if kind == PartNamespace && c == '.' {
+		return false
+	}
+	if c == '.' || c == '-' {
+		return true
+	}
+	if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' || c == '_' {
+		return true
+	}
+	return false
+}

+ 490 - 0
types/model/name_test.go

@@ -0,0 +1,490 @@
+package model
+
+import (
+	"bytes"
+	"cmp"
+	"fmt"
+	"log/slog"
+	"slices"
+	"strings"
+	"testing"
+)
+
+type fields struct {
+	host, namespace, model, tag, build string
+	digest                             string
+}
+
+func fieldsFromName(p Name) fields {
+	return fields{
+		host:      p.parts[PartHost],
+		namespace: p.parts[PartNamespace],
+		model:     p.parts[PartModel],
+		tag:       p.parts[PartTag],
+		build:     p.parts[PartBuild],
+		digest:    p.parts[PartDigest],
+	}
+}
+
+var testNames = map[string]fields{
+	"mistral:latest":                 {model: "mistral", tag: "latest"},
+	"mistral":                        {model: "mistral"},
+	"mistral:30B":                    {model: "mistral", tag: "30B"},
+	"mistral:7b":                     {model: "mistral", tag: "7b"},
+	"mistral:7b+Q4_0":                {model: "mistral", tag: "7b", build: "Q4_0"},
+	"mistral+KQED":                   {model: "mistral", build: "KQED"},
+	"mistral.x-3:7b+Q4_0":            {model: "mistral.x-3", tag: "7b", build: "Q4_0"},
+	"mistral:7b+q4_0":                {model: "mistral", tag: "7b", build: "q4_0"},
+	"llama2":                         {model: "llama2"},
+	"user/model":                     {namespace: "user", model: "model"},
+	"example.com/ns/mistral:7b+Q4_0": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "Q4_0"},
+	"example.com/ns/mistral:7b+X":    {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"},
+
+	// invalid digest
+	"mistral:latest@invalid256-": {},
+	"mistral:latest@-123":        {},
+	"mistral:latest@!-123":       {},
+	"mistral:latest@1-!":         {},
+	"mistral:latest@":            {},
+
+	// resolved
+	"x@sha123-1": {model: "x", digest: "sha123-1"},
+	"@sha456-2":  {digest: "sha456-2"},
+
+	"@@sha123-1": {},
+
+	// preserves case for build
+	"x+b": {model: "x", build: "b"},
+
+	// invalid (includes fuzzing trophies)
+	" / / : + ": {},
+	" / : + ":   {},
+	" : + ":     {},
+	" + ":       {},
+	" : ":       {},
+	" / ":       {},
+	" /":        {},
+	"/ ":        {},
+	"/":         {},
+	":":         {},
+	"+":         {},
+
+	// (".") in namespace is not allowed
+	"invalid.com/7b+x": {},
+
+	"invalid:7b+Q4_0:latest": {},
+	"in valid":               {},
+	"invalid/y/z/foo":        {},
+	"/0":                     {},
+	"0 /0":                   {},
+	"0 /":                    {},
+	"0/":                     {},
+	":/0":                    {},
+	"+0/00000":               {},
+	"0+.\xf2\x80\xf6\x9d00000\xe5\x99\xe6\xd900\xd90\xa60\x91\xdc0\xff\xbf\x99\xe800\xb9\xdc\xd6\xc300\x970\xfb\xfd0\xe0\x8a\xe1\xad\xd40\x9700\xa80\x980\xdd0000\xb00\x91000\xfe0\x89\x9b\x90\x93\x9f0\xe60\xf7\x84\xb0\x87\xa5\xff0\xa000\x9a\x85\xf6\x85\xfe\xa9\xf9\xe9\xde00\xf4\xe0\x8f\x81\xad\xde00\xd700\xaa\xe000000\xb1\xee0\x91": {},
+	"0//0":                        {},
+	"m+^^^":                       {},
+	"file:///etc/passwd":          {},
+	"file:///etc/passwd:latest":   {},
+	"file:///etc/passwd:latest+u": {},
+
+	":x": {},
+	"+x": {},
+	"x+": {},
+
+	// Disallow ("\.+") in any part to prevent path traversal anywhere
+	// we convert the name to a path.
+	"../etc/passwd":  {},
+	".../etc/passwd": {},
+	"./../passwd":    {},
+	"./0+..":         {},
+
+	strings.Repeat("a", MaxNamePartLen):   {model: strings.Repeat("a", MaxNamePartLen)},
+	strings.Repeat("a", MaxNamePartLen+1): {},
+}
+
+// TestNameConsecutiveDots tests that consecutive dots are not allowed in any
+// part, to avoid path traversal. There also are some tests in testNames, but
+// this test is more exhaustive and exists to emphasize the importance of
+// preventing path traversal.
+func TestNameConsecutiveDots(t *testing.T) {
+	for i := 1; i < 10; i++ {
+		s := strings.Repeat(".", i)
+		if i > 1 {
+			if g := ParseNameFill(s, "").String(); g != "" {
+				t.Errorf("ParseName(%q) = %q; want empty string", s, g)
+			}
+		} else {
+			if g := ParseNameFill(s, "").String(); g != s {
+				t.Errorf("ParseName(%q) = %q; want %q", s, g, s)
+			}
+		}
+	}
+}
+
+func TestNameParts(t *testing.T) {
+	var p Name
+	if w, g := int(PartDigest+1), len(p.Parts()); w != g {
+		t.Errorf("Parts() = %d; want %d", g, w)
+	}
+}
+
+func TestNamePartString(t *testing.T) {
+	if g := PartKind(-2).String(); g != "Unknown" {
+		t.Errorf("Unknown part = %q; want %q", g, "Unknown")
+	}
+	for kind, name := range kindNames {
+		if g := kind.String(); g != name {
+			t.Errorf("%s = %q; want %q", kind, g, name)
+		}
+	}
+}
+
+func TestParseName(t *testing.T) {
+	for baseName, want := range testNames {
+		for _, prefix := range []string{"", "https://", "http://"} {
+			// We should get the same results with or without the
+			// http(s) prefixes
+			s := prefix + baseName
+
+			t.Run(s, func(t *testing.T) {
+				name := ParseNameFill(s, "")
+				got := fieldsFromName(name)
+				if got != want {
+					t.Errorf("ParseName(%q) = %q; want %q", s, got, want)
+				}
+
+				// test round-trip
+				if !ParseNameFill(name.String(), "").EqualFold(name) {
+					t.Errorf("ParseName(%q).String() = %s; want %s", s, name.String(), baseName)
+				}
+			})
+		}
+	}
+}
+
+func TestCompleteWithAndWithoutBuild(t *testing.T) {
+	cases := []struct {
+		in              string
+		complete        bool
+		completeNoBuild bool
+	}{
+		{"", false, false},
+		{"incomplete/mistral:7b+x", false, false},
+		{"incomplete/mistral:7b+Q4_0", false, false},
+		{"incomplete:7b+x", false, false},
+		{"complete.com/x/mistral:latest+Q4_0", true, true},
+		{"complete.com/x/mistral:latest", false, true},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.in, func(t *testing.T) {
+			p := ParseNameFill(tt.in, "")
+			t.Logf("ParseName(%q) = %#v", tt.in, p)
+			if g := p.IsComplete(); g != tt.complete {
+				t.Errorf("Complete(%q) = %v; want %v", tt.in, g, tt.complete)
+			}
+			if g := p.IsCompleteNoBuild(); g != tt.completeNoBuild {
+				t.Errorf("CompleteNoBuild(%q) = %v; want %v", tt.in, g, tt.completeNoBuild)
+			}
+		})
+	}
+
+	// Complete uses Parts which returns a slice, but it should be
+	// inlined when used in Complete, preventing any allocations or
+	// escaping to the heap.
+	allocs := testing.AllocsPerRun(1000, func() {
+		keep(ParseNameFill("complete.com/x/mistral:latest+Q4_0", "").IsComplete())
+	})
+	if allocs > 0 {
+		t.Errorf("Complete allocs = %v; want 0", allocs)
+	}
+}
+
+func TestNameLogValue(t *testing.T) {
+	cases := []string{
+		"example.com/library/mistral:latest+Q4_0",
+		"mistral:latest",
+		"mistral:7b+Q4_0",
+	}
+	for _, s := range cases {
+		t.Run(s, func(t *testing.T) {
+			var b bytes.Buffer
+			log := slog.New(slog.NewTextHandler(&b, nil))
+			name := ParseNameFill(s, "")
+			log.Info("", "name", name)
+			want := fmt.Sprintf("name=%s", name.GoString())
+			got := b.String()
+			if !strings.Contains(got, want) {
+				t.Errorf("expected log output to contain %q; got %q", want, got)
+			}
+		})
+	}
+}
+
+func TestNameGoString(t *testing.T) {
+	cases := []struct {
+		name         string
+		in           string
+		wantString   string
+		wantGoString string // default is tt.in
+	}{
+		{
+			name:         "Complete Name",
+			in:           "example.com/library/mistral:latest+Q4_0",
+			wantGoString: "example.com/library/mistral:latest+Q4_0@?",
+		},
+		{
+			name:         "Short Name",
+			in:           "mistral:latest",
+			wantGoString: "?/?/mistral:latest+?@?",
+		},
+		{
+			name:         "Long Name",
+			in:           "library/mistral:latest",
+			wantGoString: "?/library/mistral:latest+?@?",
+		},
+		{
+			name:         "Case Preserved",
+			in:           "Library/Mistral:Latest",
+			wantGoString: "?/Library/Mistral:Latest+?@?",
+		},
+		{
+			name:         "With digest",
+			in:           "Library/Mistral:Latest@sha256-123456",
+			wantGoString: "?/Library/Mistral:Latest+?@sha256-123456",
+		},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.name, func(t *testing.T) {
+			p := ParseNameFill(tt.in, "")
+			tt.wantGoString = cmp.Or(tt.wantGoString, tt.in)
+			if g := fmt.Sprintf("%#v", p); g != tt.wantGoString {
+				t.Errorf("GoString() = %q; want %q", g, tt.wantGoString)
+			}
+		})
+	}
+}
+
+func TestDisplayShortest(t *testing.T) {
+	cases := []struct {
+		in        string
+		mask      string
+		want      string
+		wantPanic bool
+	}{
+		{"example.com/library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false},
+		{"example.com/library/mistral:latest+Q4_0", "example.com/_/_:latest", "library/mistral", false},
+		{"example.com/library/mistral:latest+Q4_0", "", "example.com/library/mistral", false},
+		{"example.com/library/mistral:latest+Q4_0", "", "example.com/library/mistral", false},
+
+		// case-insensitive
+		{"Example.com/library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false},
+		{"example.com/Library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false},
+		{"example.com/library/Mistral:latest+Q4_0", "example.com/library/_:latest", "Mistral", false},
+		{"example.com/library/mistral:Latest+Q4_0", "example.com/library/_:latest", "mistral", false},
+		{"example.com/library/mistral:Latest+q4_0", "example.com/library/_:latest", "mistral", false},
+
+		// invalid mask
+		{"example.com/library/mistral:latest+Q4_0", "example.com/mistral", "", true},
+
+		// DefaultMask
+		{"registry.ollama.ai/library/mistral:latest+Q4_0", DefaultMask, "mistral", false},
+
+		// Auto-Fill
+		{"x", "example.com/library/_:latest", "x", false},
+		{"x", "example.com/library/_:latest+Q4_0", "x", false},
+		{"x/y:z", "a.com/library/_:latest+Q4_0", "x/y:z", false},
+		{"x/y:z", "a.com/library/_:latest+Q4_0", "x/y:z", false},
+	}
+
+	for _, tt := range cases {
+		t.Run("", func(t *testing.T) {
+			defer func() {
+				if tt.wantPanic {
+					if recover() == nil {
+						t.Errorf("expected panic")
+					}
+				}
+			}()
+
+			p := ParseNameFill(tt.in, "")
+			t.Logf("ParseName(%q) = %#v", tt.in, p)
+			if g := p.DisplayShortest(tt.mask); g != tt.want {
+				t.Errorf("got = %q; want %q", g, tt.want)
+			}
+		})
+	}
+}
+
+func TestParseNameAllocs(t *testing.T) {
+	allocs := testing.AllocsPerRun(1000, func() {
+		keep(ParseNameFill("example.com/mistral:7b+Q4_0", ""))
+	})
+	if allocs > 0 {
+		t.Errorf("ParseName allocs = %v; want 0", allocs)
+	}
+}
+
+func BenchmarkParseName(b *testing.B) {
+	b.ReportAllocs()
+
+	for range b.N {
+		keep(ParseNameFill("example.com/mistral:7b+Q4_0", ""))
+	}
+}
+
+func FuzzParseName(f *testing.F) {
+	f.Add("example.com/mistral:7b+Q4_0")
+	f.Add("example.com/mistral:7b+q4_0")
+	f.Add("example.com/mistral:7b+x")
+	f.Add("x/y/z:8n+I")
+	f.Add(":x")
+	f.Add("@sha256-123456")
+	f.Add("example.com/mistral:latest+Q4_0@sha256-123456")
+	f.Add(":@!@")
+	f.Add("...")
+	f.Fuzz(func(t *testing.T, s string) {
+		r0 := ParseNameFill(s, "")
+
+		if strings.Contains(s, "..") && !r0.IsZero() {
+			t.Fatalf("non-zero value for path with '..': %q", s)
+		}
+
+		if !r0.IsValid() && !r0.IsResolved() {
+			if !r0.EqualFold(Name{}) {
+				t.Errorf("expected invalid path to be zero value; got %#v", r0)
+			}
+			t.Skipf("invalid path: %q", s)
+		}
+
+		for _, p := range r0.Parts() {
+			if len(p) > MaxNamePartLen {
+				t.Errorf("part too long: %q", p)
+			}
+		}
+
+		if !strings.EqualFold(r0.String(), s) {
+			t.Errorf("String() did not round-trip with case insensitivity: %q\ngot  = %q\nwant = %q", s, r0.String(), s)
+		}
+
+		r1 := ParseNameFill(r0.String(), "")
+		if !r0.EqualFold(r1) {
+			t.Errorf("round-trip mismatch: %+v != %+v", r0, r1)
+		}
+	})
+}
+
+func TestFill(t *testing.T) {
+	cases := []struct {
+		dst  string
+		src  string
+		want string
+	}{
+		{"mistral", "o.com/library/PLACEHOLDER:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
+		{"o.com/library/mistral", "PLACEHOLDER:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
+		{"", "o.com/library/mistral:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.dst, func(t *testing.T) {
+			r := Fill(ParseNameFill(tt.dst, ""), ParseNameFill(tt.src, ""))
+			if r.String() != tt.want {
+				t.Errorf("Fill(%q, %q) = %q; want %q", tt.dst, tt.src, r, tt.want)
+			}
+		})
+	}
+}
+
+// TestNameStringAllocs checks that [Name.String] performs at most one
+// allocation per call.
+func TestNameStringAllocs(t *testing.T) {
+	name := ParseNameFill("example.com/ns/mistral:latest+Q4_0", "")
+	allocs := testing.AllocsPerRun(1000, func() {
+		keep(name.String())
+	})
+	// One allocation is tolerated, so report the real limit on failure.
+	if allocs > 1 {
+		t.Errorf("String allocs = %v; want <= 1", allocs)
+	}
+}
+
+func ExampleFill() {
+	defaults := ParseNameFill("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0", "")
+	r := Fill(ParseNameFill("mistral", ""), defaults)
+	fmt.Println(r)
+
+	// Output:
+	// registry.ollama.com/library/mistral:latest+Q4_0
+}
+
+func ExampleName_MapHash() {
+	m := map[uint64]bool{}
+
+	// key 1
+	m[ParseNameFill("mistral:latest+q4", "").MapHash()] = true
+	m[ParseNameFill("miSTRal:latest+Q4", "").MapHash()] = true
+	m[ParseNameFill("mistral:LATest+Q4", "").MapHash()] = true
+
+	// key 2
+	m[ParseNameFill("mistral:LATest", "").MapHash()] = true
+
+	fmt.Println(len(m))
+	// Output:
+	// 2
+}
+
+func ExampleName_CompareFold_sort() {
+	names := []Name{
+		ParseNameFill("mistral:latest", ""),
+		ParseNameFill("mistRal:7b+q4", ""),
+		ParseNameFill("MIstral:7b", ""),
+	}
+
+	slices.SortFunc(names, Name.CompareFold)
+
+	for _, n := range names {
+		fmt.Println(n)
+	}
+
+	// Output:
+	// MIstral:7b
+	// mistRal:7b+q4
+	// mistral:latest
+}
+
+func ExampleName_completeAndResolved() {
+	for _, s := range []string{
+		"x/y/z:latest+q4_0@sha123-1",
+		"x/y/z:latest+q4_0",
+		"@sha123-1",
+	} {
+		name := ParseNameFill(s, "")
+		fmt.Printf("complete:%v resolved:%v  digest:%s\n", name.IsComplete(), name.IsResolved(), name.Digest())
+	}
+
+	// Output:
+	// complete:true resolved:true  digest:sha123-1
+	// complete:true resolved:false  digest:
+	// complete:false resolved:true  digest:sha123-1
+}
+
+func ExampleName_DisplayShortest() {
+	name := ParseNameFill("example.com/jmorganca/mistral:latest+Q4_0", "")
+
+	fmt.Println(name.DisplayShortest("example.com/jmorganca/_:latest"))
+	fmt.Println(name.DisplayShortest("example.com/_/_:latest"))
+	fmt.Println(name.DisplayShortest("example.com/_/_:_"))
+	fmt.Println(name.DisplayShortest("_/_/_:_"))
+
+	// Default
+	name = ParseNameFill("registry.ollama.ai/library/mistral:latest+Q4_0", "")
+	fmt.Println(name.DisplayShortest(""))
+
+	// Output:
+	// mistral
+	// jmorganca/mistral
+	// jmorganca/mistral:latest
+	// example.com/jmorganca/mistral:latest
+	// mistral
+}
+
+func keep[T any](v T) T { return v }

+ 2 - 0
types/model/testdata/fuzz/FuzzParseRef/1d43ee52085cb4aa

@@ -0,0 +1,2 @@
+go test fuzz v1
+string("/0")

+ 2 - 0
types/model/testdata/fuzz/FuzzParseRef/27fd759314f0e6d6

@@ -0,0 +1,2 @@
+go test fuzz v1
+string("0//0")

+ 2 - 0
types/model/testdata/fuzz/FuzzParseRef/3e3b70dba384074d

@@ -0,0 +1,2 @@
+go test fuzz v1
+string("0 /0")

+ 2 - 0
types/model/testdata/fuzz/FuzzParseRef/71f1fdff711b6dab

@@ -0,0 +1,2 @@
+go test fuzz v1
+string("+0/00000")

+ 2 - 0
types/model/testdata/fuzz/FuzzParseRef/82c2975c430ac608

@@ -0,0 +1,2 @@
+go test fuzz v1
+string(":")

+ 2 - 0
types/model/testdata/fuzz/FuzzParseRef/b51b1c875e61a948

@@ -0,0 +1,2 @@
+go test fuzz v1
+string("0+.\xf2\x80\xf6\x9d00000\xe5\x99\xe6\xd900\xd90\xa60\x91\xdc0\xff\xbf\x99\xe800\xb9\xdc\xd6\xc300\x970\xfb\xfd0\xe0\x8a\xe1\xad\xd40\x9700\xa80\x980\xdd0000\xb00\x91000\xfe0\x89\x9b\x90\x93\x9f0\xe60\xf7\x84\xb0\x87\xa5\xff0\xa000\x9a\x85\xf6\x85\xfe\xa9\xf9\xe9\xde00\xf4\xe0\x8f\x81\xad\xde00\xd700\xaa\xe000000\xb1\xee0\x91")

+ 15 - 0
types/structs/structs.go

@@ -0,0 +1,15 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package structs contains the Incomparable type.
+package structs
+
+// Incomparable is a zero-width incomparable type. If added as the
+// first field in a struct, it marks that struct as not comparable
+// (can't do == or be a map key) and usually doesn't add any width to
+// the struct (unless the struct has only small fields).
+//
+// By making a struct incomparable, you can prevent misuse (prevent
+// people from using ==), but also you can shrink generated binaries,
+// as the compiler can omit equality funcs from the binary.
+type Incomparable [0]func()