Jelajahi Sumber

x/model: Name: implement sql.Scanner and driver.Valuer

Blake Mizerany 1 tahun lalu
induk
melakukan
1407fd3d4a
2 mengubah file dengan 59 tambahan dan 1 penghapusan
  1. 33 1
      x/model/name.go
  2. 26 0
      x/model/name_test.go

+ 33 - 1
x/model/name.go

@@ -3,6 +3,8 @@ package model
 import (
 	"bytes"
 	"cmp"
+	"database/sql"
+	"database/sql/driver"
 	"errors"
 	"hash/maphash"
 	"io"
@@ -330,7 +332,7 @@ func (r *Name) UnmarshalText(text []byte) error {
 		// called on an invalid/zero Name. If we allow UnmarshalText
 		// on a valid Name, then the Name will be mutated, breaking
 		// the immutability of the Name.
-		return errors.New("model.Name: UnmarshalText on valid Name")
+		return errors.New("model.Name: illegal UnmarshalText on valid Name")
 	}
 
 	// The contract of UnmarshalText  is that we copy to keep the text.
@@ -338,6 +340,36 @@ func (r *Name) UnmarshalText(text []byte) error {
 	return nil
 }
 
+var (
+	_ driver.Valuer = Name{}
+	_ sql.Scanner   = (*Name)(nil)
+)
+
+// Scan implements [database/sql.Scanner].
+func (r *Name) Scan(src any) error {
+	if r.Valid() {
+		// The invariant of Scan is that it should only be called on an
+		// invalid/zero Name. If we allow Scan on a valid Name, then the
+		// Name will be mutated, breaking the immutability of the Name.
+		return errors.New("model.Name: illegal Scan on valid Name")
+
+	}
+	switch v := src.(type) {
+	case string:
+		*r = ParseName(v)
+		return nil
+	case []byte:
+		*r = ParseName(string(v))
+		return nil
+	}
+	return errors.New("model.Name: invalid Scan source")
+}
+
+// Value implements [database/sql/driver.Valuer].
+func (r Name) Value() (driver.Value, error) {
+	return r.String(), nil
+}
+
 // Complete reports whether the Name is fully qualified. That is it has a
 // domain, namespace, name, tag, and build.
 func (r Name) Complete() bool {

+ 26 - 0
x/model/name_test.go

@@ -406,6 +406,32 @@ func TestNameTextUnmarshalCallOnValidName(t *testing.T) {
 	}
 }
 
+func TestSQL(t *testing.T) {
+	t.Run("Scan for already valid Name", func(t *testing.T) {
+		p := mustParse("x")
+		if err := p.Scan("mistral:latest+Q4_0"); err == nil {
+			t.Error("Scan() = nil; want error")
+		}
+	})
+	t.Run("Scan for invalid Name", func(t *testing.T) {
+		p := Name{}
+		if err := p.Scan("mistral:latest+Q4_0"); err != nil {
+			t.Errorf("Scan() = %v; want nil", err)
+		}
+		if p.String() != "mistral:latest+Q4_0" {
+			t.Errorf("String() = %q; want %q", p, "mistral:latest+Q4_0")
+		}
+	})
+	t.Run("Value", func(t *testing.T) {
+		p := mustParse("x")
+		if g, err := p.Value(); err != nil {
+			t.Errorf("Value() error = %v; want nil", err)
+		} else if g != "x" {
+			t.Errorf("Value() = %q; want %q", g, "x")
+		}
+	})
+}
+
 func TestNameTextMarshalAllocs(t *testing.T) {
 	var data []byte
 	name := ParseName("example.com/ns/mistral:latest+Q4_0")