Browse Source

x/model: add MarshalText and UnmarshalText to Name

Blake Mizerany 1 year ago
parent
commit
45d8d22785
2 changed files with 107 additions and 1 deletions
  1. 45 1
      x/model/name.go
  2. 62 0
      x/model/name_test.go

+ 45 - 1
x/model/name.go

@@ -8,6 +8,8 @@ import (
 	"log/slog"
 	"slices"
 	"strings"
+	"sync"
+	"unsafe"
 
 	"github.com/ollama/ollama/x/types/structs"
 )
@@ -233,6 +235,12 @@ func (r Name) DisplayLong() string {
 	}).String()
 }
 
+var builderPool = sync.Pool{
+	New: func() interface{} {
+		return &strings.Builder{}
+	},
+}
+
 // String returns the fullest possible display string in form:
 //
 //	<host>/<namespace>/<model>:<tag>+<build>
@@ -242,7 +250,17 @@ func (r Name) DisplayLong() string {
 // For the fullest possible display string without the build, use
 // [Name.DisplayFullest].
 func (r Name) String() string {
-	var b strings.Builder
+	b := builderPool.Get().(*strings.Builder)
+	b.Reset()
+	defer builderPool.Put(b)
+	b.Grow(0 +
+		len(r.host) +
+		len(r.namespace) +
+		len(r.model) +
+		len(r.tag) +
+		len(r.build) +
+		4, // 4 possible separators
+	)
 	if r.host != "" {
 		b.WriteString(r.host)
 		b.WriteString("/")
@@ -282,6 +300,32 @@ func (r Name) LogValue() slog.Value {
 	return slog.StringValue(r.GoString())
 }
 
+// MarshalText implements encoding.TextMarshaler.
+func (r Name) MarshalText() ([]byte, error) {
+	// unsafeBytes is safe here because we gurantee that the string is
+	// never used after this function returns.
+	//
+	// TODO: We can remove this if https://github.com/golang/go/issues/62384
+	// lands.
+	return unsafeBytes(r.String()), nil
+}
+
+func unsafeBytes(s string) []byte {
+	return *(*[]byte)(unsafe.Pointer(&s))
+}
+
+// UnmarshalText implements encoding.TextUnmarshaler.
+func (r *Name) UnmarshalText(text []byte) error {
+	// unsafeString is safe here because the contract of UnmarshalText
+	// that text belongs to us for the duration of the call.
+	*r = ParseName(unsafeString(text))
+	return nil
+}
+
+func unsafeString(b []byte) string {
+	return *(*string)(unsafe.Pointer(&b))
+}
+
 // Complete reports whether the Name is fully qualified. That is it has a
 // domain, namespace, name, tag, and build.
 func (r Name) Complete() bool {

+ 62 - 0
x/model/name_test.go

@@ -3,6 +3,7 @@ package model
 import (
 	"bytes"
 	"cmp"
+	"errors"
 	"fmt"
 	"log/slog"
 	"slices"
@@ -352,6 +353,67 @@ func TestFill(t *testing.T) {
 	}
 }
 
+func TestNameTextMarshal(t *testing.T) {
+	cases := []struct {
+		in      string
+		want    string
+		wantErr error
+	}{
+		{"example.com/mistral:latest+Q4_0", "", nil},
+		{"mistral:latest+Q4_0", "mistral:latest+Q4_0", nil},
+		{"mistral:latest", "mistral:latest", nil},
+		{"mistral", "mistral", nil},
+		{"mistral:7b", "mistral:7b", nil},
+		{"example.com/library/mistral:latest+Q4_0", "example.com/library/mistral:latest+Q4_0", nil},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.in, func(t *testing.T) {
+			p := ParseName(tt.in)
+			got, err := p.MarshalText()
+			if !errors.Is(err, tt.wantErr) {
+				t.Fatalf("MarshalText() error = %v; want %v", err, tt.wantErr)
+			}
+			if string(got) != tt.want {
+				t.Errorf("MarshalText() = %q; want %q", got, tt.want)
+			}
+
+			var r Name
+			if err := r.UnmarshalText(got); err != nil {
+				t.Fatalf("UnmarshalText() error = %v; want nil", err)
+			}
+			if !r.EqualFold(p) {
+				t.Errorf("UnmarshalText() = %q; want %q", r, p)
+			}
+		})
+	}
+
+	var data []byte
+	name := ParseName("example.com/ns/mistral:latest+Q4_0")
+	if !name.Complete() {
+		// sanity check
+		t.Fatal("name is not complete")
+	}
+
+	allocs := testing.AllocsPerRun(1000, func() {
+		var err error
+		data, err = name.MarshalText()
+		if err != nil {
+			t.Fatal(err)
+		}
+		if len(data) == 0 {
+			t.Fatal("MarshalText() = 0; want non-zero")
+		}
+	})
+	if allocs > 1 {
+		// TODO: Update when/if this lands:
+		// https://github.com/golang/go/issues/62384
+		//
+		// Currently, the best we can do is 1 alloc.
+		t.Errorf("MarshalText allocs = %v; want <= 1", allocs)
+	}
+}
+
 func ExampleFill() {
 	defaults := ParseName("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0")
 	r := Fill(ParseName("mistral"), defaults)