|
@@ -7,13 +7,18 @@ import (
|
|
|
"encoding/json"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
+ "io/fs"
|
|
|
"math"
|
|
|
+ "math/rand/v2"
|
|
|
+ "net"
|
|
|
"net/http"
|
|
|
"net/http/httptest"
|
|
|
"os"
|
|
|
+ "path/filepath"
|
|
|
"sort"
|
|
|
"strings"
|
|
|
"testing"
|
|
|
+ "unicode"
|
|
|
|
|
|
"github.com/ollama/ollama/api"
|
|
|
"github.com/ollama/ollama/llm"
|
|
@@ -473,83 +478,129 @@ func Test_Routes(t *testing.T) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func TestCase(t *testing.T) {
|
|
|
- t.Setenv("OLLAMA_MODELS", t.TempDir())
|
|
|
-
|
|
|
- cases := []string{
|
|
|
- "mistral",
|
|
|
- "llama3:latest",
|
|
|
- "library/phi3:q4_0",
|
|
|
- "registry.ollama.ai/library/gemma:q5_K_M",
|
|
|
- // TODO: host:port currently fails on windows (#4107)
|
|
|
- // "localhost:5000/alice/bob:latest",
|
|
|
+func casingShuffle(s string) string {
|
|
|
+ rr := []rune(s)
|
|
|
+ for i := range rr {
|
|
|
+ if rand.N(2) == 0 {
|
|
|
+ rr[i] = unicode.ToUpper(rr[i])
|
|
|
+ } else {
|
|
|
+ rr[i] = unicode.ToLower(rr[i])
|
|
|
+ }
|
|
|
}
|
|
|
+ return string(rr)
|
|
|
+}
|
|
|
|
|
|
- var s Server
|
|
|
- for _, tt := range cases {
|
|
|
- t.Run(tt, func(t *testing.T) {
|
|
|
- w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
|
|
- Name: tt,
|
|
|
- Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
|
|
- Stream: &stream,
|
|
|
- })
|
|
|
-
|
|
|
- if w.Code != http.StatusOK {
|
|
|
- t.Fatalf("expected status 200 got %d", w.Code)
|
|
|
- }
|
|
|
+func TestManifestCaseSensitivity(t *testing.T) {
|
|
|
+ t.Setenv("OLLAMA_MODELS", t.TempDir())
|
|
|
|
|
|
- expect, err := json.Marshal(map[string]string{"error": "a model with that name already exists"})
|
|
|
- if err != nil {
|
|
|
- t.Fatal(err)
|
|
|
+ r := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
+ w.WriteHeader(http.StatusOK)
|
|
|
+ io.WriteString(w, `{}`) //nolint:errcheck
|
|
|
+ }))
|
|
|
+ defer r.Close()
|
|
|
+
|
|
|
+ nameUsed := make(map[string]bool)
|
|
|
+ name := func() string {
|
|
|
+ const fqmn = "example/namespace/model:tag"
|
|
|
+ for {
|
|
|
+ v := casingShuffle(fqmn)
|
|
|
+ if nameUsed[v] {
|
|
|
+ continue
|
|
|
}
|
|
|
+ nameUsed[v] = true
|
|
|
+ return v
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
- t.Run("create", func(t *testing.T) {
|
|
|
- w = createRequest(t, s.CreateHandler, api.CreateRequest{
|
|
|
- Name: strings.ToUpper(tt),
|
|
|
- Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
|
|
- Stream: &stream,
|
|
|
- })
|
|
|
-
|
|
|
- if w.Code != http.StatusBadRequest {
|
|
|
- t.Fatalf("expected status 500 got %d", w.Code)
|
|
|
- }
|
|
|
-
|
|
|
- if !bytes.Equal(w.Body.Bytes(), expect) {
|
|
|
- t.Fatalf("expected error %s got %s", expect, w.Body.String())
|
|
|
- }
|
|
|
- })
|
|
|
+ wantStableName := name()
|
|
|
|
|
|
- t.Run("pull", func(t *testing.T) {
|
|
|
- w := createRequest(t, s.PullHandler, api.PullRequest{
|
|
|
- Name: strings.ToUpper(tt),
|
|
|
- Stream: &stream,
|
|
|
- })
|
|
|
+ // checkManifestList tests that there is strictly one manifest in the
|
|
|
+ // models directory, and that the manifest is for the model under test.
|
|
|
+ checkManifestList := func() {
|
|
|
+ t.Helper()
|
|
|
|
|
|
- if w.Code != http.StatusBadRequest {
|
|
|
- t.Fatalf("expected status 500 got %d", w.Code)
|
|
|
- }
|
|
|
+ mandir := filepath.Join(os.Getenv("OLLAMA_MODELS"), "manifests/")
|
|
|
+ var entries []string
|
|
|
+ t.Logf("dir entries:")
|
|
|
+ fsys := os.DirFS(mandir)
|
|
|
+ err := fs.WalkDir(fsys, ".", func(path string, info fs.DirEntry, err error) error {
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ t.Logf(" %s", fs.FormatDirEntry(info))
|
|
|
+ if info.IsDir() {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ path = strings.TrimPrefix(path, mandir)
|
|
|
+ entries = append(entries, path)
|
|
|
+ return nil
|
|
|
+ })
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("failed to walk directory: %v", err)
|
|
|
+ }
|
|
|
|
|
|
- if !bytes.Equal(w.Body.Bytes(), expect) {
|
|
|
- t.Fatalf("expected error %s got %s", expect, w.Body.String())
|
|
|
- }
|
|
|
- })
|
|
|
+ if len(entries) != 1 {
|
|
|
+ t.Errorf("len(got) = %d, want 1", len(entries))
|
|
|
+ return // do not use Fatal so following steps run
|
|
|
+ }
|
|
|
|
|
|
- t.Run("copy", func(t *testing.T) {
|
|
|
- w := createRequest(t, s.CopyHandler, api.CopyRequest{
|
|
|
- Source: tt,
|
|
|
- Destination: strings.ToUpper(tt),
|
|
|
- })
|
|
|
+ g := entries[0] // raw path
|
|
|
+ g = filepath.ToSlash(g)
|
|
|
+ w := model.ParseName(wantStableName).Filepath()
|
|
|
+ w = filepath.ToSlash(w)
|
|
|
+ if g != w {
|
|
|
+ t.Errorf("\ngot: %s\nwant: %s", g, w)
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
- if w.Code != http.StatusBadRequest {
|
|
|
- t.Fatalf("expected status 500 got %d", w.Code)
|
|
|
- }
|
|
|
+ checkOK := func(w *httptest.ResponseRecorder) {
|
|
|
+ t.Helper()
|
|
|
+ if w.Code != http.StatusOK {
|
|
|
+ t.Errorf("code = %d, want 200", w.Code)
|
|
|
+ t.Logf("body: %s", w.Body.String())
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
- if !bytes.Equal(w.Body.Bytes(), expect) {
|
|
|
- t.Fatalf("expected error %s got %s", expect, w.Body.String())
|
|
|
- }
|
|
|
- })
|
|
|
- })
|
|
|
+ var s Server
|
|
|
+ testMakeRequestDialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
|
|
+ var d net.Dialer
|
|
|
+ return d.DialContext(ctx, "tcp", r.Listener.Addr().String())
|
|
|
}
|
|
|
+ t.Cleanup(func() { testMakeRequestDialContext = nil })
|
|
|
+
|
|
|
+ t.Logf("creating")
|
|
|
+ checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
|
|
|
+ // Start with the stable name, and later use a case-shuffled
|
|
|
+ // version.
|
|
|
+ Name: wantStableName,
|
|
|
+
|
|
|
+ Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
|
|
+ Stream: &stream,
|
|
|
+ }))
|
|
|
+ checkManifestList()
|
|
|
+
|
|
|
+ t.Logf("creating (again)")
|
|
|
+ checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
|
|
|
+ Name: name(),
|
|
|
+ Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
|
|
+ Stream: &stream,
|
|
|
+ }))
|
|
|
+ checkManifestList()
|
|
|
+
|
|
|
+ t.Logf("pulling")
|
|
|
+ checkOK(createRequest(t, s.PullHandler, api.PullRequest{
|
|
|
+ Name: name(),
|
|
|
+ Stream: &stream,
|
|
|
+ Insecure: true,
|
|
|
+ }))
|
|
|
+ checkManifestList()
|
|
|
+
|
|
|
+ t.Logf("copying")
|
|
|
+ checkOK(createRequest(t, s.CopyHandler, api.CopyRequest{
|
|
|
+ Source: name(),
|
|
|
+ Destination: name(),
|
|
|
+ }))
|
|
|
+ checkManifestList()
|
|
|
}
|
|
|
|
|
|
func TestShow(t *testing.T) {
|