Browse Source

routes: use Manifests for ListHandler

Michael Yang 1 year ago
parent
commit
c2714fcbfd
3 changed files with 127 additions and 58 deletions
  1. 10 1
      server/manifest.go
  2. 90 0
      server/manifest_test.go
  3. 27 57
      server/routes.go

+ 10 - 1
server/manifest.go

@@ -6,6 +6,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"log/slog"
 	"os"
 	"path/filepath"
 
@@ -16,6 +17,7 @@ type Manifest struct {
 	ManifestV2
 
 	filepath string
+	fi       os.FileInfo
 	digest   string
 }
 
@@ -65,6 +67,11 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
 	}
 	defer f.Close()
 
+	fi, err := f.Stat()
+	if err != nil {
+		return nil, err
+	}
+
 	sha256sum := sha256.New()
 	if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil {
 		return nil, err
@@ -73,6 +80,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
 	return &Manifest{
 		ManifestV2: m,
 		filepath:   p,
+		fi:         fi,
 		digest:     fmt.Sprintf("%x", sha256sum.Sum(nil)),
 	}, nil
 }
@@ -126,7 +134,8 @@ func Manifests() (map[model.Name]*Manifest, error) {
 		if n.IsValid() {
 			m, err := ParseNamedManifest(n)
 			if err != nil {
-				return nil, err
+				slog.Warn("bad manifest", "name", n, "error", err)
+				continue
 			}
 
 			ms[n] = m

+ 90 - 0
server/manifest_test.go

@@ -0,0 +1,90 @@
+package server
+
+import (
+	"encoding/json"
+	"os"
+	"path/filepath"
+	"slices"
+	"testing"
+
+	"github.com/ollama/ollama/types/model"
+)
+
+func createManifest(t *testing.T, path, name string) {
+	t.Helper()
+
+	p := filepath.Join(path, "manifests", name)
+	if err := os.MkdirAll(filepath.Dir(p), 0755); err != nil {
+		t.Fatal(err)
+	}
+
+	f, err := os.Create(p)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer f.Close()
+
+	if err := json.NewEncoder(f).Encode(ManifestV2{}); err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestManifests(t *testing.T) {
+	cases := map[string][]string{
+		"empty": {},
+		"single": {
+			filepath.Join("host", "namespace", "model", "tag"),
+		},
+		"multiple": {
+			filepath.Join("registry.ollama.ai", "library", "llama3", "latest"),
+			filepath.Join("registry.ollama.ai", "library", "llama3", "q4_0"),
+			filepath.Join("registry.ollama.ai", "library", "llama3", "q4_1"),
+			filepath.Join("registry.ollama.ai", "library", "llama3", "q8_0"),
+			filepath.Join("registry.ollama.ai", "library", "llama3", "q5_0"),
+			filepath.Join("registry.ollama.ai", "library", "llama3", "q5_1"),
+			filepath.Join("registry.ollama.ai", "library", "llama3", "q2_K"),
+			filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_S"),
+			filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_M"),
+			filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_L"),
+			filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_S"),
+			filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_M"),
+			filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_S"),
+			filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_M"),
+			filepath.Join("registry.ollama.ai", "library", "llama3", "q6_K"),
+		},
+		"hidden": {
+			filepath.Join("host", "namespace", "model", "tag"),
+			filepath.Join("host", "namespace", "model", ".hidden"),
+		},
+	}
+
+	for n, wants := range cases {
+		t.Run(n, func(t *testing.T) {
+			d := t.TempDir()
+			t.Setenv("OLLAMA_MODELS", d)
+
+			for _, want := range wants {
+				createManifest(t, d, want)
+			}
+
+			ms, err := Manifests()
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			var ns []model.Name
+			for k := range ms {
+				ns = append(ns, k)
+			}
+
+			for _, want := range wants {
+				n := model.ParseNameFromFilepath(want)
+				if !n.IsValid() && slices.Contains(ns, n) {
+					t.Errorf("unexpected invalid name: %s", want)
+				} else if n.IsValid() && !slices.Contains(ns, n) {
+					t.Errorf("missing valid name: %s", want)
+				}
+			}
+		})
+	}
+}

+ 27 - 57
server/routes.go

@@ -702,72 +702,42 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 }
 
 func (s *Server) ListModelsHandler(c *gin.Context) {
-	manifests, err := GetManifestPath()
+	ms, err := Manifests()
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
 	models := []api.ModelResponse{}
-	if err := filepath.Walk(manifests, func(path string, info os.FileInfo, _ error) error {
-		if !info.IsDir() {
-			rel, err := filepath.Rel(manifests, path)
-			if err != nil {
-				return err
-			}
-
-			if hidden, err := filepath.Match(".*", filepath.Base(rel)); err != nil {
-				return err
-			} else if hidden {
-				return nil
-			}
-
-			n := model.ParseNameFromFilepath(rel)
-			if !n.IsValid() {
-				slog.Warn("bad manifest filepath", "path", rel)
-				return nil
-			}
-
-			m, err := ParseNamedManifest(n)
-			if err != nil {
-				slog.Warn("bad manifest", "name", n, "error", err)
-				return nil
-			}
-
-			f, err := m.Config.Open()
-			if err != nil {
-				slog.Warn("bad manifest config filepath", "name", n, "error", err)
-				return nil
-			}
-			defer f.Close()
-
-			var c ConfigV2
-			if err := json.NewDecoder(f).Decode(&c); err != nil {
-				slog.Warn("bad manifest config", "name", n, "error", err)
-				return nil
-			}
+	for n, m := range ms {
+		f, err := m.Config.Open()
+		if err != nil {
+			slog.Warn("bad manifest filepath", "name", n, "error", err)
+			continue
+		}
+		defer f.Close()
 
-			// tag should never be masked
-			models = append(models, api.ModelResponse{
-				Model:      n.DisplayShortest(),
-				Name:       n.DisplayShortest(),
-				Size:       m.Size(),
-				Digest:     m.digest,
-				ModifiedAt: info.ModTime(),
-				Details: api.ModelDetails{
-					Format:            c.ModelFormat,
-					Family:            c.ModelFamily,
-					Families:          c.ModelFamilies,
-					ParameterSize:     c.ModelType,
-					QuantizationLevel: c.FileType,
-				},
-			})
+		var cf ConfigV2
+		if err := json.NewDecoder(f).Decode(&cf); err != nil {
+			slog.Warn("bad manifest config", "name", n, "error", err)
+			continue
 		}
 
-		return nil
-	}); err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
+		// tag should never be masked
+		models = append(models, api.ModelResponse{
+			Model:      n.DisplayShortest(),
+			Name:       n.DisplayShortest(),
+			Size:       m.Size(),
+			Digest:     m.digest,
+			ModifiedAt: m.fi.ModTime(),
+			Details: api.ModelDetails{
+				Format:            cf.ModelFormat,
+				Family:            cf.ModelFamily,
+				Families:          cf.ModelFamilies,
+				ParameterSize:     cf.ModelType,
+				QuantizationLevel: cf.FileType,
+			},
+		})
 	}
 
 	slices.SortStableFunc(models, func(i, j api.ModelResponse) int {