Explorar o código

x/model: make equality checks case-insensitive

Blake Mizerany hai 1 ano
pai
achega
bfe89d6fa0
Modificáronse 2 ficheiros con 37 adicións e 7 borrados
  1. 30 0
      x/model/name.go
  2. 7 7
      x/model/name_test.go

+ 30 - 0
x/model/name.go

@@ -2,9 +2,12 @@ package model
 
 import (
 	"cmp"
+	"hash/maphash"
 	"iter"
 	"slices"
 	"strings"
+
+	"github.com/ollama/ollama/x/types/structs"
 )
 
 const MaxNameLength = 255
@@ -36,6 +39,8 @@ var kindNames = map[NamePart]string{
 //
 // Users or Name must check Valid before using it.
 type Name struct {
+	_ structs.Incomparable
+
 	host      string
 	namespace string
 	model     string
@@ -43,6 +48,27 @@ type Name struct {
 	build     string
 }
 
+var mapHashSeed = maphash.MakeSeed()
+
+// MapHash returns a case insensitive hash for use in maps and equality
+// checks. For a convienent way to compare names, use [EqualFold].
+func (r Name) MapHash() uint64 {
+	// correctly hash the parts with case insensitive comparison
+	var h maphash.Hash
+	h.SetSeed(mapHashSeed)
+	for _, part := range r.Parts() {
+		// downcase the part for hashing
+		for i := range part {
+			c := part[i]
+			if c >= 'A' && c <= 'Z' {
+				c = c - 'A' + 'a'
+			}
+			h.WriteByte(c)
+		}
+	}
+	return h.Sum64()
+}
+
 // Format returns a string representation of the ref with the given
 // concreteness. If a part is missing, it is replaced with a loud
 // placeholder.
@@ -135,6 +161,10 @@ func (r Name) Model() string     { return r.model }
 func (r Name) Tag() string       { return r.tag }
 func (r Name) Build() string     { return r.build }
 
+func (r Name) EqualFold(o Name) bool {
+	return r.MapHash() == o.MapHash()
+}
+
 // ParseName parses s into a Name. The input string must be a valid form of
 // a model name in the form:
 //

+ 7 - 7
x/model/name_test.go

@@ -49,21 +49,21 @@ func TestNameParts(t *testing.T) {
 }
 
 func TestParseName(t *testing.T) {
-	for s, want := range testNames {
+	for baseName, want := range testNames {
 		for _, prefix := range []string{"", "https://", "http://"} {
 			// We should get the same results with or without the
 			// http(s) prefixes
-			s := prefix + s
+			s := prefix + baseName
 
 			t.Run(s, func(t *testing.T) {
 				got := ParseName(s)
-				if got != want {
+				if !got.EqualFold(want) {
 					t.Errorf("ParseName(%q) = %q; want %q", s, got, want)
 				}
 
 				// test round-trip
-				if ParseName(got.String()) != got {
-					t.Errorf("String() = %s; want %s", got.String(), s)
+				if !ParseName(got.String()).EqualFold(got) {
+					t.Errorf("String() = %s; want %s", got.String(), baseName)
 				}
 
 				if got.Valid() && got.Model() == "" {
@@ -190,7 +190,7 @@ func FuzzParseName(f *testing.F) {
 	f.Fuzz(func(t *testing.T, s string) {
 		r0 := ParseName(s)
 		if !r0.Valid() {
-			if r0 != (Name{}) {
+			if !r0.EqualFold(Name{}) {
 				t.Errorf("expected invalid path to be zero value; got %#v", r0)
 			}
 			t.Skipf("invalid path: %q", s)
@@ -207,7 +207,7 @@ func FuzzParseName(f *testing.F) {
 		}
 
 		r1 := ParseName(r0.String())
-		if r0 != r1 {
+		if !r0.EqualFold(r1) {
 			t.Errorf("round-trip mismatch: %+v != %+v", r0, r1)
 		}