Browse Source

Merge pull request #4087 from ollama/mxyng/fix-host-port

types/model: fix name for hostport
Michael Yang 1 year ago
parent
commit
cb1e072643
2 changed files with 55 additions and 2 deletions
  1. 12 2
      types/model/name.go
  2. 43 0
      types/model/name_test.go

+ 12 - 2
types/model/name.go

@@ -143,18 +143,28 @@ func ParseNameBare(s string) Name {
 		n.RawDigest = MissingPart
 	}
 
-	s, n.Tag, _ = cutPromised(s, ":")
+	// "/" is an illegal tag character, so we can use it to split the host
+	if strings.LastIndex(s, ":") > strings.LastIndex(s, "/") {
+		s, n.Tag, _ = cutPromised(s, ":")
+	}
+
 	s, n.Model, promised = cutPromised(s, "/")
 	if !promised {
 		n.Model = s
 		return n
 	}
+
 	s, n.Namespace, promised = cutPromised(s, "/")
 	if !promised {
 		n.Namespace = s
 		return n
 	}
-	n.Host = s
+
+	scheme, host, ok := strings.Cut(s, "://")
+	if ! ok {
+		host = scheme
+	}
+	n.Host = host
 
 	return n
 }

+ 43 - 0
types/model/name_test.go

@@ -1,6 +1,7 @@
 package model
 
 import (
+	"path/filepath"
 	"reflect"
 	"runtime"
 	"testing"
@@ -15,8 +16,19 @@ func TestParseNameParts(t *testing.T) {
 	cases := []struct {
 		in              string
 		want            Name
+		wantFilepath    string
 		wantValidDigest bool
 	}{
+		{
+			in: "scheme://host:port/namespace/model:tag",
+			want: Name{
+				Host:      "host:port",
+				Namespace: "namespace",
+				Model:     "model",
+				Tag:       "tag",
+			},
+			wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"),
+		},
 		{
 			in: "host/namespace/model:tag",
 			want: Name{
@@ -25,6 +37,17 @@ func TestParseNameParts(t *testing.T) {
 				Model:     "model",
 				Tag:       "tag",
 			},
+			wantFilepath: filepath.Join("host", "namespace", "model", "tag"),
+		},
+		{
+			in: "host:port/namespace/model:tag",
+			want: Name{
+				Host:      "host:port",
+				Namespace: "namespace",
+				Model:     "model",
+				Tag:       "tag",
+			},
+			wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"),
 		},
 		{
 			in: "host/namespace/model",
@@ -33,6 +56,16 @@ func TestParseNameParts(t *testing.T) {
 				Namespace: "namespace",
 				Model:     "model",
 			},
+			wantFilepath: filepath.Join("host", "namespace", "model", "latest"),
+		},
+		{
+			in: "host:port/namespace/model",
+			want: Name{
+				Host:      "host:port",
+				Namespace: "namespace",
+				Model:     "model",
+			},
+			wantFilepath: filepath.Join("host:port", "namespace", "model", "latest"),
 		},
 		{
 			in: "namespace/model",
@@ -40,12 +73,14 @@ func TestParseNameParts(t *testing.T) {
 				Namespace: "namespace",
 				Model:     "model",
 			},
+			wantFilepath: filepath.Join("registry.ollama.ai", "namespace", "model", "latest"),
 		},
 		{
 			in: "model",
 			want: Name{
 				Model: "model",
 			},
+			wantFilepath: filepath.Join("registry.ollama.ai", "library", "model", "latest"),
 		},
 		{
 			in: "h/nn/mm:t",
@@ -55,6 +90,7 @@ func TestParseNameParts(t *testing.T) {
 				Model:     "mm",
 				Tag:       "t",
 			},
+			wantFilepath: filepath.Join("h", "nn", "mm", "t"),
 		},
 		{
 			in: part80 + "/" + part80 + "/" + part80 + ":" + part80,
@@ -64,6 +100,7 @@ func TestParseNameParts(t *testing.T) {
 				Model:     part80,
 				Tag:       part80,
 			},
+			wantFilepath: filepath.Join(part80, part80, part80, part80),
 		},
 		{
 			in: part350 + "/" + part80 + "/" + part80 + ":" + part80,
@@ -73,6 +110,7 @@ func TestParseNameParts(t *testing.T) {
 				Model:     part80,
 				Tag:       part80,
 			},
+			wantFilepath: filepath.Join(part350, part80, part80, part80),
 		},
 		{
 			in: "@digest",
@@ -97,6 +135,11 @@ func TestParseNameParts(t *testing.T) {
 			if !reflect.DeepEqual(got, tt.want) {
 				t.Errorf("parseName(%q) = %v; want %v", tt.in, got, tt.want)
 			}
+
+			got = ParseName(tt.in)
+			if tt.wantFilepath != "" && got.Filepath() != tt.wantFilepath {
+				t.Errorf("parseName(%q).Filepath() = %q; want %q", tt.in, got.Filepath(), tt.wantFilepath)
+			}
 		})
 	}
 }