Browse Source

types/model: make Name.Filepath substitute colons in host with ("%")

This makes the filepath legal on all supported platforms.

Fixes #4088
Blake Mizerany 1 year ago
parent
commit
61b287cf25
5 changed files with 112 additions and 5 deletions
  1. 42 0
      server/fixblobs.go
  2. 49 0
      server/fixblobs_test.go
  3. 7 0
      server/routes.go
  4. 2 2
      types/model/name.go
  5. 12 3
      types/model/name_test.go

+ 42 - 0
server/fixblobs.go

@@ -3,6 +3,7 @@ package server
 import (
 	"os"
 	"path/filepath"
+	"runtime"
 	"strings"
 )
 
@@ -24,3 +25,44 @@ func fixBlobs(dir string) error {
 		return nil
 	})
 }
+
+// fixManifests walks the provided dir and replaces (":") to ("%") for all
+// manifest files on non-Windows systems.
+func fixManifests(dir string) error {
+	if runtime.GOOS == "windows" {
+		return nil
+	}
+	return filepath.Walk(dir, func(oldPath string, info os.FileInfo, err error) error {
+		if err != nil {
+			return err
+		}
+		if info.IsDir() {
+			return nil
+		}
+
+		var partNum int
+		newPath := []byte(oldPath)
+		for i := len(newPath) - 1; i >= 0; i-- {
+			if partNum > 3 {
+				break
+			}
+			if partNum == 3 {
+				if newPath[i] == ':' {
+					newPath[i] = '%'
+					break
+				}
+				continue
+			}
+			if newPath[i] == '/' {
+				partNum++
+			}
+		}
+
+		newDir, _ := filepath.Split(string(newPath))
+		if err := os.MkdirAll(newDir, 0o755); err != nil {
+			return err
+		}
+
+		return os.Rename(oldPath, string(newPath))
+	})
+}

+ 49 - 0
server/fixblobs_test.go

@@ -64,6 +64,55 @@ func TestFixBlobs(t *testing.T) {
 	}
 }
 
+func TestFixManifests(t *testing.T) {
+	cases := []struct {
+		path []string
+		want []string
+	}{
+		{path: []string{}, want: []string{}},
+		{path: []string{"h/n/m/t"}, want: []string{"h/n/m/t"}},
+		{path: []string{"h:p/n/m/t"}, want: []string{"h%p/n/m/t"}},
+		{path: []string{"x:y/h:p/n/m/t"}, want: []string{"x:y/h%p/n/m/t"}},
+	}
+
+	for _, tt := range cases {
+		t.Run(strings.Join(tt.path, "|"), func(t *testing.T) {
+			hasColon := slices.ContainsFunc(tt.path, func(s string) bool { return strings.Contains(s, ":") })
+			if hasColon && runtime.GOOS == "windows" {
+				t.Skip("skipping test on windows")
+			}
+
+			rootDir := t.TempDir()
+			for _, path := range tt.path {
+				fullPath := filepath.Join(rootDir, path)
+				fullDir, _ := filepath.Split(fullPath)
+
+				t.Logf("creating dir %s", fullDir)
+				if err := os.MkdirAll(fullDir, 0o755); err != nil {
+					t.Fatal(err)
+				}
+
+				t.Logf("writing file %s", fullPath)
+				if err := os.WriteFile(fullPath, nil, 0o644); err != nil {
+					t.Fatal(err)
+				}
+			}
+
+			if err := fixManifests(rootDir); err != nil {
+				t.Fatal(err)
+			}
+
+			got := slurpFiles(os.DirFS(rootDir))
+
+			slices.Sort(tt.want)
+			slices.Sort(got)
+			if !slices.Equal(got, tt.want) {
+				t.Fatalf("got = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
 func slurpFiles(fsys fs.FS) []string {
 	var sfs []string
 	fn := func(path string, d fs.DirEntry, err error) error {

+ 7 - 0
server/routes.go

@@ -1046,6 +1046,13 @@ func Serve(ln net.Listener) error {
 	if err := fixBlobs(blobsDir); err != nil {
 		return err
 	}
+	manifestsDir, err := GetManifestPath()
+	if err != nil {
+		return err
+	}
+	if err := fixManifests(manifestsDir); err != nil {
+		return err
+	}
 
 	if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
 		// clean up unused layers and manifests

+ 2 - 2
types/model/name.go

@@ -161,7 +161,7 @@ func ParseNameBare(s string) Name {
 	}
 
 	scheme, host, ok := strings.Cut(s, "://")
-	if ! ok {
+	if !ok {
 		host = scheme
 	}
 	n.Host = host
@@ -243,7 +243,7 @@ func (n Name) Filepath() string {
 		panic("illegal attempt to get filepath of invalid name")
 	}
 	return strings.ToLower(filepath.Join(
-		n.Host,
+		strings.Replace(n.Host, ":", "%", 1),
 		n.Namespace,
 		n.Model,
 		n.Tag,

+ 12 - 3
types/model/name_test.go

@@ -27,7 +27,7 @@ func TestParseNameParts(t *testing.T) {
 				Model:     "model",
 				Tag:       "tag",
 			},
-			wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"),
+			wantFilepath: filepath.Join("host%port", "namespace", "model", "tag"),
 		},
 		{
 			in: "host/namespace/model:tag",
@@ -47,7 +47,7 @@ func TestParseNameParts(t *testing.T) {
 				Model:     "model",
 				Tag:       "tag",
 			},
-			wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"),
+			wantFilepath: filepath.Join("host%port", "namespace", "model", "tag"),
 		},
 		{
 			in: "host/namespace/model",
@@ -65,7 +65,7 @@ func TestParseNameParts(t *testing.T) {
 				Namespace: "namespace",
 				Model:     "model",
 			},
-			wantFilepath: filepath.Join("host:port", "namespace", "model", "latest"),
+			wantFilepath: filepath.Join("host%port", "namespace", "model", "latest"),
 		},
 		{
 			in: "namespace/model",
@@ -127,6 +127,15 @@ func TestParseNameParts(t *testing.T) {
 			},
 			wantValidDigest: true,
 		},
+		{
+			in: "y.com:443/n/model",
+			want: Name{
+				Host:      "y.com:443",
+				Namespace: "n",
+				Model:     "model",
+			},
+			wantFilepath: filepath.Join("y.com%443", "n", "model", "latest"),
+		},
 	}
 
 	for _, tt := range cases {