Przeglądaj źródła

server: allow mixed-case model names on push, pull, cp, and create (#7676)

This change allows for mixed-case model names to be pushed, pulled,
copied, and created, which was previously disallowed because the Ollama
registry was backed by a Docker registry that enforced a naming
convention that disallowed mixed-case names, which is no longer the
case.

This does not break existing, intended, behaviors.

Also, make TestCase test a story of creating, updating, pulling, and
copying a model with case variations, ensuring the model's manifest is
updated correctly, and not duplicated across different files with
different case variations.
Blake Mizerany 5 miesięcy temu
rodzic
commit
4b8a2e341a
4 zmienionych plików z 163 dodań i 78 usunięć
  1. 19 0
      server/images.go
  2. 21 13
      server/routes.go
  3. 116 65
      server/routes_test.go
  4. 7 0
      types/model/name.go

+ 19 - 0
server/images.go

@@ -13,6 +13,7 @@ import (
 	"io"
 	"log"
 	"log/slog"
+	"net"
 	"net/http"
 	"net/url"
 	"os"
@@ -1071,6 +1072,21 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
 	return nil, errUnauthorized
 }
 
+// testMakeRequestDialContext specifies the dial function for the http client in
+// makeRequest. It can be used to resolve hosts in model names to local
+// addresses for testing. For example, the model name ("example.com/my/model")
+// can be directed to push/pull from "127.0.0.1:1234".
+//
+// This is not safe to set across goroutines. It should be set in
+// the main test goroutine, and not by tests marked to run in parallel with
+// t.Parallel().
+//
+// It should be cleared after use, otherwise it will affect other tests.
+//
+// Ideally we would have some set this up the stack, but the code is not
+// structured in a way that makes this easy, so this will have to do for now.
+var testMakeRequestDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
+
 func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *registryOptions) (*http.Response, error) {
 	if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
 		requestURL.Scheme = "http"
@@ -1105,6 +1121,9 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header
 	}
 
 	resp, err := (&http.Client{
+		Transport: &http.Transport{
+			DialContext: testMakeRequestDialContext,
+		},
 		CheckRedirect: regOpts.CheckRedirect,
 	}).Do(req)
 	if err != nil {

+ 21 - 13
server/routes.go

@@ -540,7 +540,8 @@ func (s *Server) PullHandler(c *gin.Context) {
 		return
 	}
 
-	if err := checkNameExists(name); err != nil {
+	name, err = getExistingName(name)
+	if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
@@ -621,19 +622,20 @@ func (s *Server) PushHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
-func checkNameExists(name model.Name) error {
-	names, err := Manifests(true)
+// getExistingName returns the original, on disk name if the input name is a
+// case-insensitive match, otherwise it returns the input name.
+func getExistingName(n model.Name) (model.Name, error) {
+	var zero model.Name
+	existing, err := Manifests(true)
 	if err != nil {
-		return err
+		return zero, err
 	}
-
-	for n := range names {
-		if strings.EqualFold(n.Filepath(), name.Filepath()) && n != name {
-			return errors.New("a model with that name already exists")
+	for e := range existing {
+		if n.EqualFold(e) {
+			return e, nil
 		}
 	}
-
-	return nil
+	return n, nil
 }
 
 func (s *Server) CreateHandler(c *gin.Context) {
@@ -652,7 +654,8 @@ func (s *Server) CreateHandler(c *gin.Context) {
 		return
 	}
 
-	if err := checkNameExists(name); err != nil {
+	name, err := getExistingName(name)
+	if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
@@ -958,14 +961,19 @@ func (s *Server) CopyHandler(c *gin.Context) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("source %q is invalid", r.Source)})
 		return
 	}
+	src, err := getExistingName(src)
+	if err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
 
 	dst := model.ParseName(r.Destination)
 	if !dst.IsValid() {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("destination %q is invalid", r.Destination)})
 		return
 	}
-
-	if err := checkNameExists(dst); err != nil {
+	dst, err = getExistingName(dst)
+	if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}

+ 116 - 65
server/routes_test.go

@@ -7,13 +7,18 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"io/fs"
 	"math"
+	"math/rand/v2"
+	"net"
 	"net/http"
 	"net/http/httptest"
 	"os"
+	"path/filepath"
 	"sort"
 	"strings"
 	"testing"
+	"unicode"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llm"
@@ -473,83 +478,129 @@ func Test_Routes(t *testing.T) {
 	}
 }
 
-func TestCase(t *testing.T) {
-	t.Setenv("OLLAMA_MODELS", t.TempDir())
-
-	cases := []string{
-		"mistral",
-		"llama3:latest",
-		"library/phi3:q4_0",
-		"registry.ollama.ai/library/gemma:q5_K_M",
-		// TODO: host:port currently fails on windows (#4107)
-		// "localhost:5000/alice/bob:latest",
+func casingShuffle(s string) string {
+	rr := []rune(s)
+	for i := range rr {
+		if rand.N(2) == 0 {
+			rr[i] = unicode.ToUpper(rr[i])
+		} else {
+			rr[i] = unicode.ToLower(rr[i])
+		}
 	}
+	return string(rr)
+}
 
-	var s Server
-	for _, tt := range cases {
-		t.Run(tt, func(t *testing.T) {
-			w := createRequest(t, s.CreateHandler, api.CreateRequest{
-				Name:      tt,
-				Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
-				Stream:    &stream,
-			})
-
-			if w.Code != http.StatusOK {
-				t.Fatalf("expected status 200 got %d", w.Code)
-			}
+func TestManifestCaseSensitivity(t *testing.T) {
+	t.Setenv("OLLAMA_MODELS", t.TempDir())
 
-			expect, err := json.Marshal(map[string]string{"error": "a model with that name already exists"})
-			if err != nil {
-				t.Fatal(err)
+	r := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		io.WriteString(w, `{}`) //nolint:errcheck
+	}))
+	defer r.Close()
+
+	nameUsed := make(map[string]bool)
+	name := func() string {
+		const fqmn = "example/namespace/model:tag"
+		for {
+			v := casingShuffle(fqmn)
+			if nameUsed[v] {
+				continue
 			}
+			nameUsed[v] = true
+			return v
+		}
+	}
 
-			t.Run("create", func(t *testing.T) {
-				w = createRequest(t, s.CreateHandler, api.CreateRequest{
-					Name:      strings.ToUpper(tt),
-					Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
-					Stream:    &stream,
-				})
-
-				if w.Code != http.StatusBadRequest {
-					t.Fatalf("expected status 500 got %d", w.Code)
-				}
-
-				if !bytes.Equal(w.Body.Bytes(), expect) {
-					t.Fatalf("expected error %s got %s", expect, w.Body.String())
-				}
-			})
+	wantStableName := name()
 
-			t.Run("pull", func(t *testing.T) {
-				w := createRequest(t, s.PullHandler, api.PullRequest{
-					Name:   strings.ToUpper(tt),
-					Stream: &stream,
-				})
+	// checkManifestList tests that there is strictly one manifest in the
+	// models directory, and that the manifest is for the model under test.
+	checkManifestList := func() {
+		t.Helper()
 
-				if w.Code != http.StatusBadRequest {
-					t.Fatalf("expected status 500 got %d", w.Code)
-				}
+		mandir := filepath.Join(os.Getenv("OLLAMA_MODELS"), "manifests/")
+		var entries []string
+		t.Logf("dir entries:")
+		fsys := os.DirFS(mandir)
+		err := fs.WalkDir(fsys, ".", func(path string, info fs.DirEntry, err error) error {
+			if err != nil {
+				return err
+			}
+			t.Logf("    %s", fs.FormatDirEntry(info))
+			if info.IsDir() {
+				return nil
+			}
+			path = strings.TrimPrefix(path, mandir)
+			entries = append(entries, path)
+			return nil
+		})
+		if err != nil {
+			t.Fatalf("failed to walk directory: %v", err)
+		}
 
-				if !bytes.Equal(w.Body.Bytes(), expect) {
-					t.Fatalf("expected error %s got %s", expect, w.Body.String())
-				}
-			})
+		if len(entries) != 1 {
+			t.Errorf("len(got) = %d, want 1", len(entries))
+			return // do not use Fatal so following steps run
+		}
 
-			t.Run("copy", func(t *testing.T) {
-				w := createRequest(t, s.CopyHandler, api.CopyRequest{
-					Source:      tt,
-					Destination: strings.ToUpper(tt),
-				})
+		g := entries[0] // raw path
+		g = filepath.ToSlash(g)
+		w := model.ParseName(wantStableName).Filepath()
+		w = filepath.ToSlash(w)
+		if g != w {
+			t.Errorf("\ngot:  %s\nwant: %s", g, w)
+		}
+	}
 
-				if w.Code != http.StatusBadRequest {
-					t.Fatalf("expected status 500 got %d", w.Code)
-				}
+	checkOK := func(w *httptest.ResponseRecorder) {
+		t.Helper()
+		if w.Code != http.StatusOK {
+			t.Errorf("code = %d, want 200", w.Code)
+			t.Logf("body: %s", w.Body.String())
+		}
+	}
 
-				if !bytes.Equal(w.Body.Bytes(), expect) {
-					t.Fatalf("expected error %s got %s", expect, w.Body.String())
-				}
-			})
-		})
+	var s Server
+	testMakeRequestDialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
+		var d net.Dialer
+		return d.DialContext(ctx, "tcp", r.Listener.Addr().String())
 	}
+	t.Cleanup(func() { testMakeRequestDialContext = nil })
+
+	t.Logf("creating")
+	checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
+		// Start with the stable name, and later use a case-shuffled
+		// version.
+		Name: wantStableName,
+
+		Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
+		Stream:    &stream,
+	}))
+	checkManifestList()
+
+	t.Logf("creating (again)")
+	checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
+		Name:      name(),
+		Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
+		Stream:    &stream,
+	}))
+	checkManifestList()
+
+	t.Logf("pulling")
+	checkOK(createRequest(t, s.PullHandler, api.PullRequest{
+		Name:     name(),
+		Stream:   &stream,
+		Insecure: true,
+	}))
+	checkManifestList()
+
+	t.Logf("copying")
+	checkOK(createRequest(t, s.CopyHandler, api.CopyRequest{
+		Source:      name(),
+		Destination: name(),
+	}))
+	checkManifestList()
 }
 
 func TestShow(t *testing.T) {

+ 7 - 0
types/model/name.go

@@ -298,6 +298,13 @@ func (n Name) LogValue() slog.Value {
 	return slog.StringValue(n.String())
 }
 
+func (n Name) EqualFold(o Name) bool {
+	return strings.EqualFold(n.Host, o.Host) &&
+		strings.EqualFold(n.Namespace, o.Namespace) &&
+		strings.EqualFold(n.Model, o.Model) &&
+		strings.EqualFold(n.Tag, o.Tag)
+}
+
 func isValidLen(kind partKind, s string) bool {
 	switch kind {
 	case kindHost: