Browse Source

x/model: add Digest type

Blake Mizerany 1 year ago
parent
commit
2100129e83
4 changed files with 174 additions and 88 deletions
  1. 120 0
      x/model/digest.go
  2. 53 0
      x/model/digest_test.go
  3. 1 41
      x/model/name.go
  4. 0 47
      x/model/name_test.go

+ 120 - 0
x/model/digest.go

@@ -0,0 +1,120 @@
+package model
+
+import (
+	"database/sql"
+	"database/sql/driver"
+	"errors"
+	"fmt"
+	"log/slog"
+	"strings"
+	"unicode"
+)
+
+// Digest is an opaque reference to a model digest. It holds the digest type
+// and the digest itself.
+//
+// It is comparable with other Digests and can be used as a map key.
+type Digest struct {
+	typ    string
+	digest string
+}
+
+func (d Digest) Type() string   { return d.typ }
+func (d Digest) Digest() string { return d.digest }
+func (d Digest) Valid() bool    { return d != Digest{} }
+
+func (d Digest) String() string {
+	if !d.Valid() {
+		return ""
+	}
+	return fmt.Sprintf("%s-%s", d.typ, d.digest)
+}
+
+func (d Digest) MarshalText() ([]byte, error) {
+	return []byte(d.String()), nil
+}
+
+func (d *Digest) UnmarshalText(text []byte) error {
+	if d.Valid() {
+		return errors.New("model.Digest: illegal UnmarshalText on valid Digest")
+	}
+	*d = ParseDigest(string(text))
+	return nil
+}
+
+func (d Digest) LogValue() slog.Value {
+	return slog.StringValue(d.String())
+}
+
+var (
+	_ driver.Valuer = Digest{}
+	_ sql.Scanner   = (*Digest)(nil)
+)
+
+func (d *Digest) Scan(src any) error {
+	if d.Valid() {
+		return errors.New("model.Digest: illegal Scan on valid Digest")
+	}
+	switch v := src.(type) {
+	case string:
+		*d = ParseDigest(v)
+		return nil
+	case []byte:
+		*d = ParseDigest(string(v))
+		return nil
+	}
+	return fmt.Errorf("model.Digest: invalid Scan source %T", src)
+}
+
+func (d Digest) Value() (driver.Value, error) {
+	return d.String(), nil
+}
+
+// ParseDigest parses a string in the form of "<digest-type>-<digest>" into a
+// Digest.
+func ParseDigest(s string) Digest {
+	typ, digest, ok := strings.Cut(s, "-")
+	if ok && isValidDigestType(typ) && isValidHex(digest) {
+		return Digest{typ: typ, digest: digest}
+	}
+	return Digest{}
+}
+
+// isValidDigest returns true if the given string in the form of
+// "<digest-type>-<digest>", and <digest-type> is in the form of [a-z0-9]+
+// and <digest> is a valid hex string.
+//
+// It does not check if the digest is a valid hash for the given digest
+// type, or restrict the digest type to a known set of types. This is left
+// up to ueers of this package.
+func isValidDigest(s string) bool {
+	typ, digest, ok := strings.Cut(s, "-")
+	res := ok && isValidDigestType(typ) && isValidHex(digest)
+	fmt.Printf("DEBUG: %q: typ: %s, digest: %s, ok: %v res: %v\n", s, typ, digest, ok, res)
+	return res
+}
+
+func isValidDigestType(s string) bool {
+	if len(s) == 0 {
+		return false
+	}
+	for _, r := range s {
+		if !unicode.IsLower(r) && !unicode.IsDigit(r) {
+			return false
+		}
+	}
+	return true
+}
+
+func isValidHex(s string) bool {
+	if len(s) == 0 {
+		return false
+	}
+	for i := range s {
+		c := s[i]
+		if c < '0' || c > '9' && c < 'a' || c > 'f' {
+			return false
+		}
+	}
+	return true
+}

+ 53 - 0
x/model/digest_test.go

@@ -0,0 +1,53 @@
+package model
+
+import "testing"
+
+// - test scan
+// - test marshal text
+// - test unmarshal text
+// - test log value
+// - test string
+// - test type
+// - test digest
+// - test valid
+// - test driver valuer
+// - test sql scanner
+// - test parse digest
+
+var testDigests = map[string]Digest{
+	"":                 {},
+	"sha256-1234":      {typ: "sha256", digest: "1234"},
+	"sha256-5678":      {typ: "sha256", digest: "5678"},
+	"blake2-9abc":      {typ: "blake2", digest: "9abc"},
+	"-1234":            {},
+	"sha256-":          {},
+	"sha256-1234-5678": {},
+	"sha256-P":         {}, //         invalid  hex
+	"sha256-1234P":     {},
+	"---":              {},
+}
+
+func TestDigestParse(t *testing.T) {
+	// Test cases.
+	for s, want := range testDigests {
+		got := ParseDigest(s)
+		t.Logf("ParseDigest(%q) = %#v", s, got)
+		if got != want {
+			t.Errorf("ParseDigest(%q) = %q; want %q", s, got, want)
+		}
+	}
+}
+
+func TestDigestString(t *testing.T) {
+	// Test cases.
+	for s, d := range testDigests {
+		want := s
+		if !d.Valid() {
+			want = ""
+		}
+		got := d.String()
+		if got != want {
+			t.Errorf("ParseDigest(%q).String() = %q; want %q", s, got, want)
+		}
+	}
+}

+ 1 - 41
x/model/name.go

@@ -6,7 +6,6 @@ import (
 	"database/sql"
 	"database/sql/driver"
 	"errors"
-	"fmt"
 	"hash/maphash"
 	"io"
 	"iter"
@@ -14,7 +13,6 @@ import (
 	"slices"
 	"strings"
 	"sync"
-	"unicode"
 
 	"github.com/ollama/ollama/x/types/structs"
 )
@@ -25,6 +23,7 @@ var (
 	// other packages do not need to invent their own error type when they
 	// need to return an error for an invalid name.
 	ErrIncompleteName = errors.New("incomplete model name")
+	ErrInvalidDigest  = errors.New("invalid digest")
 )
 
 const MaxNamePartLen = 128
@@ -592,42 +591,3 @@ func isValidByte(kind NamePart, c byte) bool {
 	}
 	return false
 }
-
-// isValidDigest returns true if the given string in the form of
-// "<digest-type>-<digest>", and <digest-type> is in the form of [a-z0-9]+
-// and <digest> is a valid hex string.
-//
-// It does not check if the digest is a valid hash for the given digest
-// type, or restrict the digest type to a known set of types. This is left
-// up to ueers of this package.
-func isValidDigest(s string) bool {
-	typ, digest, ok := strings.Cut(s, "-")
-	res := ok && isValidDigestType(typ) && isValidHex(digest)
-	fmt.Printf("DEBUG: %q: typ: %s, digest: %s, ok: %v res: %v\n", s, typ, digest, ok, res)
-	return res
-}
-
-func isValidDigestType(s string) bool {
-	if len(s) == 0 {
-		return false
-	}
-	for _, r := range s {
-		if !unicode.IsLower(r) && !unicode.IsDigit(r) {
-			return false
-		}
-	}
-	return true
-}
-
-func isValidHex(s string) bool {
-	if len(s) == 0 {
-		return false
-	}
-	for i := range s {
-		c := s[i]
-		if c < '0' || c > '9' && c < 'a' || c > 'f' {
-			return false
-		}
-	}
-	return true
-}

+ 0 - 47
x/model/name_test.go

@@ -117,53 +117,6 @@ func TestNamePartString(t *testing.T) {
 	}
 }
 
-func TestIsValidDigestType(t *testing.T) {
-	cases := []struct {
-		in   string
-		want bool
-	}{
-		{"sha256", true},
-		{"blake2", true},
-
-		{"", false},
-		{"-sha256", false},
-		{"sha256-", false},
-		{"Sha256", false},
-		{"sha256(", false},
-		{" sha256", false},
-	}
-	for _, tt := range cases {
-		t.Run(tt.in, func(t *testing.T) {
-			if g := isValidDigestType(tt.in); g != tt.want {
-				t.Errorf("isValidDigestType(%q) = %v; want %v", tt.in, g, tt.want)
-			}
-		})
-	}
-}
-
-func TestIsValidDigest(t *testing.T) {
-	cases := []struct {
-		in   string
-		want bool
-	}{
-		{"", false},
-		{"sha256-123", true},
-		{"sha256-1234567890abcdef", true},
-		{"sha256-1234567890abcdef1234567890abcdeffffffffffffffffffffffffffffffffffffffffff", true},
-		{"!sha256-123", false},
-		{"sha256-123!", false},
-		{"sha256-", false},
-		{"-123", false},
-	}
-	for _, tt := range cases {
-		t.Run(tt.in, func(t *testing.T) {
-			if g := isValidDigest(tt.in); g != tt.want {
-				t.Errorf("isValidDigest(%q) = %v; want %v", tt.in, g, tt.want)
-			}
-		})
-	}
-}
-
 func TestParseName(t *testing.T) {
 	for baseName, want := range testNames {
 		for _, prefix := range []string{"", "https://", "http://"} {