Browse Source

x/model: disallow . in namespace

Blake Mizerany 1 year ago
parent
commit
a6b8bdf938
2 changed files with 53 additions and 47 deletions
  1. 25 36
      x/model/name.go
  2. 28 11
      x/model/name_test.go

+ 25 - 36
x/model/name.go

@@ -56,7 +56,7 @@ var kindNames = map[NamePart]string{
 // To check if a Name is fully qualified, use [Name.Complete]. A fully
 // qualified name has all parts present.
 //
-// To update parts of a Name with defaults, use [Merge].
+// To update parts of a Name with defaults, use [Fill].
 type Name struct {
 	_ structs.Incomparable
 
@@ -122,19 +122,15 @@ func ParseName(s string) Name {
 	return r
 }
 
-// Merge performs a partial merge of src into dst. Only the non-name parts
-// are merged. The name part is always left untouched. Other parts are
-// merged if and only if they are missing in dst.
+// Fill fills in the missing parts of dst with the parts of src.
 //
 // Use this for merging a fully qualified ref with a partial ref, such as
 // when filling in a missing parts with defaults.
 //
 // The returned Name will only be valid if dst is valid.
-func Merge(dst, src Name) Name {
+func Fill(dst, src Name) Name {
 	return Name{
-		// name is left untouched
-		model: dst.model,
-
+		model:     cmp.Or(dst.model, src.model),
 		host:      cmp.Or(dst.host, src.host),
 		namespace: cmp.Or(dst.namespace, src.namespace),
 		tag:       cmp.Or(dst.tag, src.tag),
@@ -223,24 +219,10 @@ func (r Name) Complete() bool {
 	return r.Valid() && !slices.Contains(r.Parts(), "")
 }
 
-// CompleteWithoutBuild reports whether the ref would be complete if it had a
-// valid build.
-func (r Name) CompleteWithoutBuild() bool {
-	r.build = "x"
-	return r.Valid() && r.Complete()
-}
-
-// Less returns true if r is less concrete than o; false otherwise.
-func (r Name) Less(o Name) bool {
-	rp := r.Parts()
-	op := o.Parts()
-	for i := range rp {
-		if rp[i] < op[i] {
-			return true
-		}
-	}
-	return false
-}
+// TODO(bmizerany): Compare
+// TODO(bmizerany): MarshalText/UnmarshalText
+// TODO(bmizerany): LogValue
+// TODO(bmizerany): driver.Value? (MarshalText etc should be enough)
 
 // Parts returns the parts of the ref in order of concreteness.
 //
@@ -289,7 +271,7 @@ func NameParts(s string) iter.Seq2[NamePart, string] {
 		}
 
 		yieldValid := func(kind NamePart, part string) bool {
-			if !isValidPart(part) {
+			if !isValidPart(kind, part) {
 				yield(Invalid, "")
 				return false
 			}
@@ -338,7 +320,7 @@ func NameParts(s string) iter.Seq2[NamePart, string] {
 					return
 				}
 			default:
-				if !isValidPart(s[i : i+1]) {
+				if !isValidPart(state, s[i:i+1]) {
 					yield(Invalid, "")
 					return
 				}
@@ -362,20 +344,27 @@ func (r Name) Valid() bool {
 }
 
 // isValidPart returns true if given part is valid ascii [a-zA-Z0-9_\.-]
-func isValidPart(s string) bool {
+func isValidPart(kind NamePart, s string) bool {
 	if s == "" {
 		return false
 	}
 	for _, c := range []byte(s) {
-		if c == '.' || c == '-' {
-			return true
-		}
-		if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' || c == '_' {
-			continue
-		} else {
+		if !isValidByte(kind, c) {
 			return false
-
 		}
 	}
 	return true
 }
+
+func isValidByte(kind NamePart, c byte) bool {
+	if kind == Namespace && c == '.' {
+		return false
+	}
+	if c == '.' || c == '-' {
+		return true
+	}
+	if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' || c == '_' {
+		return true
+	}
+	return false
+}

+ 28 - 11
x/model/name_test.go

@@ -92,7 +92,7 @@ func TestParseName(t *testing.T) {
 	}
 }
 
-func TestName(t *testing.T) {
+func TestComplete(t *testing.T) {
 	cases := []struct {
 		in                   string
 		complete             bool
@@ -113,9 +113,6 @@ func TestName(t *testing.T) {
 			if g := p.Complete(); g != tt.complete {
 				t.Errorf("Complete(%q) = %v; want %v", tt.in, g, tt.complete)
 			}
-			if g := p.CompleteWithoutBuild(); g != tt.completeWithoutBuild {
-				t.Errorf("CompleteWithoutBuild(%q) = %v; want %v", tt.in, g, tt.completeWithoutBuild)
-			}
 		})
 	}
 }
@@ -229,16 +226,36 @@ func FuzzParseName(f *testing.F) {
 	})
 }
 
-func ExampleMerge() {
-	src := ParseName("registry.ollama.com/mistral:latest")
-	dst := ParseName("mistral")
-	r := Merge(dst, src)
-	fmt.Println("src:", src)
-	fmt.Println("dst:", dst)
+func TestFill(t *testing.T) {
+	cases := []struct {
+		dst  string
+		src  string
+		want string
+	}{
+		{"mistral", "o.com/library/PLACEHOLDER:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
+		{"o.com/library/mistral", "PLACEHOLDER:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
+		{"", "o.com/library/mistral:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.dst, func(t *testing.T) {
+			r := Fill(ParseName(tt.dst), ParseName(tt.src))
+			if r.String() != tt.want {
+				t.Errorf("Fill(%q, %q) = %q; want %q", tt.dst, tt.src, r, tt.want)
+			}
+		})
+	}
+}
+
+func ExampleFill() {
+	r := Fill(
+		ParseName("mistral"),
+		ParseName("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0"),
+	)
 	fmt.Println(r)
 
 	// Output:
-	// registry.ollama.com/mistral:latest+Q4_0
+	// registry.ollama.com/library/mistral:latest+Q4_0
 }
 
 func ExampleName_MapHash() {