Jelajahi Sumber

Merge pull request #3715 from ollama/mxyng/modelname-2

update list handler to use model.Name
Michael Yang 1 tahun lalu
induk
melakukan
ffbd3d173f
6 mengubah file dengan 215 tambahan dan 87 penghapusan
  1. 0 11
      server/images.go
  2. 79 0
      server/manifest.go
  3. 0 34
      server/manifests.go
  4. 42 39
      server/routes.go
  5. 51 3
      types/model/name.go
  6. 43 0
      types/model/name_test.go

+ 0 - 11
server/images.go

@@ -52,7 +52,6 @@ type Model struct {
 	System         string
 	License        []string
 	Digest         string
-	Size           int64
 	Options        map[string]interface{}
 	Messages       []Message
 }
@@ -161,15 +160,6 @@ type RootFS struct {
 	DiffIDs []string `json:"diff_ids"`
 }
 
-func (m *ManifestV2) GetTotalSize() (total int64) {
-	for _, layer := range m.Layers {
-		total += layer.Size
-	}
-
-	total += m.Config.Size
-	return total
-}
-
 func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
 	fp, err := mp.GetManifestPath()
 	if err != nil {
@@ -210,7 +200,6 @@ func GetModel(name string) (*Model, error) {
 		Digest:    digest,
 		Template:  "{{ .Prompt }}",
 		License:   []string{},
-		Size:      manifest.GetTotalSize(),
 	}
 
 	filename, err := GetBlobsPath(manifest.Config.Digest)

+ 79 - 0
server/manifest.go

@@ -0,0 +1,79 @@
+package server
+
+import (
+	"bytes"
+	"crypto/sha256"
+	"encoding/json"
+	"fmt"
+	"io"
+	"os"
+	"path/filepath"
+
+	"github.com/ollama/ollama/types/model"
+)
+
+type Manifest struct {
+	ManifestV2
+	Digest string `json:"-"`
+}
+
+func (m *Manifest) Size() (size int64) {
+	for _, layer := range append(m.Layers, m.Config) {
+		size += layer.Size
+	}
+
+	return
+}
+
+func ParseNamedManifest(name model.Name) (*Manifest, error) {
+	if !name.IsFullyQualified() {
+		return nil, model.Unqualified(name)
+	}
+
+	manifests, err := GetManifestPath()
+	if err != nil {
+		return nil, err
+	}
+
+	var manifest ManifestV2
+	manifestfile, err := os.Open(filepath.Join(manifests, name.Filepath()))
+	if err != nil {
+		return nil, err
+	}
+
+	sha256sum := sha256.New()
+	if err := json.NewDecoder(io.TeeReader(manifestfile, sha256sum)).Decode(&manifest); err != nil {
+		return nil, err
+	}
+
+	return &Manifest{
+		ManifestV2: manifest,
+		Digest:     fmt.Sprintf("%x", sha256sum.Sum(nil)),
+	}, nil
+}
+
+func WriteManifest(name string, config *Layer, layers []*Layer) error {
+	manifest := ManifestV2{
+		SchemaVersion: 2,
+		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
+		Config:        config,
+		Layers:        layers,
+	}
+
+	var b bytes.Buffer
+	if err := json.NewEncoder(&b).Encode(manifest); err != nil {
+		return err
+	}
+
+	modelpath := ParseModelPath(name)
+	manifestPath, err := modelpath.GetManifestPath()
+	if err != nil {
+		return err
+	}
+
+	if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
+		return err
+	}
+
+	return os.WriteFile(manifestPath, b.Bytes(), 0o644)
+}

+ 0 - 34
server/manifests.go

@@ -1,34 +0,0 @@
-package server
-
-import (
-	"bytes"
-	"encoding/json"
-	"os"
-	"path/filepath"
-)
-
-func WriteManifest(name string, config *Layer, layers []*Layer) error {
-	manifest := ManifestV2{
-		SchemaVersion: 2,
-		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
-		Config:        config,
-		Layers:        layers,
-	}
-
-	var b bytes.Buffer
-	if err := json.NewEncoder(&b).Encode(manifest); err != nil {
-		return err
-	}
-
-	modelpath := ParseModelPath(name)
-	manifestPath, err := modelpath.GetManifestPath()
-	if err != nil {
-		return err
-	}
-
-	if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
-		return err
-	}
-
-	return os.WriteFile(manifestPath, b.Bytes(), 0o644)
-}

+ 42 - 39
server/routes.go

@@ -719,62 +719,65 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 }
 
 func (s *Server) ListModelsHandler(c *gin.Context) {
-	models := make([]api.ModelResponse, 0)
-	manifestsPath, err := GetManifestPath()
+	manifests, err := GetManifestPath()
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
-	modelResponse := func(modelName string) (api.ModelResponse, error) {
-		model, err := GetModel(modelName)
-		if err != nil {
-			return api.ModelResponse{}, err
-		}
-
-		modelDetails := api.ModelDetails{
-			Format:            model.Config.ModelFormat,
-			Family:            model.Config.ModelFamily,
-			Families:          model.Config.ModelFamilies,
-			ParameterSize:     model.Config.ModelType,
-			QuantizationLevel: model.Config.FileType,
-		}
-
-		return api.ModelResponse{
-			Model:   model.ShortName,
-			Name:    model.ShortName,
-			Size:    model.Size,
-			Digest:  model.Digest,
-			Details: modelDetails,
-		}, nil
-	}
-
-	walkFunc := func(path string, info os.FileInfo, _ error) error {
+	var models []api.ModelResponse
+	if err := filepath.Walk(manifests, func(path string, info os.FileInfo, _ error) error {
 		if !info.IsDir() {
-			path, tag := filepath.Split(path)
-			model := strings.Trim(strings.TrimPrefix(path, manifestsPath), string(os.PathSeparator))
-			modelPath := strings.Join([]string{model, tag}, ":")
-			canonicalModelPath := strings.ReplaceAll(modelPath, string(os.PathSeparator), "/")
+			rel, err := filepath.Rel(manifests, path)
+			if err != nil {
+				return err
+			}
 
-			resp, err := modelResponse(canonicalModelPath)
+			n := model.ParseNameFromFilepath(rel)
+			m, err := ParseNamedManifest(n)
 			if err != nil {
-				slog.Info(fmt.Sprintf("skipping file: %s", canonicalModelPath))
-				// nolint: nilerr
-				return nil
+				return err
 			}
 
-			resp.ModifiedAt = info.ModTime()
-			models = append(models, resp)
+			f, err := m.Config.Open()
+			if err != nil {
+				return err
+			}
+			defer f.Close()
+
+			var c ConfigV2
+			if err := json.NewDecoder(f).Decode(&c); err != nil {
+				return err
+			}
+
+			// tag should never be masked
+			models = append(models, api.ModelResponse{
+				Model:      n.DisplayShortest(),
+				Name:       n.DisplayShortest(),
+				Size:       m.Size(),
+				Digest:     m.Digest,
+				ModifiedAt: info.ModTime(),
+				Details: api.ModelDetails{
+					Format:            c.ModelFormat,
+					Family:            c.ModelFamily,
+					Families:          c.ModelFamilies,
+					ParameterSize:     c.ModelType,
+					QuantizationLevel: c.FileType,
+				},
+			})
 		}
 
 		return nil
-	}
-
-	if err := filepath.Walk(manifestsPath, walkFunc); err != nil {
+	}); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
+	slices.SortStableFunc(models, func(i, j api.ModelResponse) int {
+		// most recently modified first
+		return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix())
+	})
+
 	c.JSON(http.StatusOK, api.ListResponse{Models: models})
 }
 

+ 51 - 3
types/model/name.go

@@ -35,6 +35,12 @@ func Unqualified(n Name) error {
 // spot in logs.
 const MissingPart = "!MISSING!"
 
+const (
+	defaultHost      = "registry.ollama.ai"
+	defaultNamespace = "library"
+	defaultTag       = "latest"
+)
+
 // DefaultName returns a name with the default values for the host, namespace,
 // and tag parts. The model and digest parts are empty.
 //
@@ -43,9 +49,9 @@ const MissingPart = "!MISSING!"
 //   - The default tag is ("latest")
 func DefaultName() Name {
 	return Name{
-		Host:      "registry.ollama.ai",
-		Namespace: "library",
-		Tag:       "latest",
+		Host:      defaultHost,
+		Namespace: defaultNamespace,
+		Tag:       defaultTag,
 	}
 }
 
@@ -169,6 +175,27 @@ func ParseNameBare(s string) Name {
 	return n
 }
 
+// ParseNameFromFilepath parses a 4-part filepath as a Name. The parts are
+// expected to be in the form:
+//
+// { host } "/" { namespace } "/" { model } "/" { tag }
+func ParseNameFromFilepath(s string) (n Name) {
+	parts := strings.Split(s, string(filepath.Separator))
+	if len(parts) != 4 {
+		return Name{}
+	}
+
+	n.Host = parts[0]
+	n.Namespace = parts[1]
+	n.Model = parts[2]
+	n.Tag = parts[3]
+	if !n.IsFullyQualified() {
+		return Name{}
+	}
+
+	return n
+}
+
 // Merge merges the host, namespace, and tag parts of the two names,
 // preferring the non-empty parts of a.
 func Merge(a, b Name) Name {
@@ -203,6 +230,27 @@ func (n Name) String() string {
 	return b.String()
 }
 
+// DisplayShort returns a short string version of the name.
+func (n Name) DisplayShortest() string {
+	var sb strings.Builder
+
+	if n.Host != defaultHost {
+		sb.WriteString(n.Host)
+		sb.WriteByte('/')
+		sb.WriteString(n.Namespace)
+		sb.WriteByte('/')
+	} else if n.Namespace != defaultNamespace {
+		sb.WriteString(n.Namespace)
+		sb.WriteByte('/')
+	}
+
+	// always include model and tag
+	sb.WriteString(n.Model)
+	sb.WriteString(":")
+	sb.WriteString(n.Tag)
+	return sb.String()
+}
+
 // IsValid reports whether all parts of the name are present and valid. The
 // digest is a special case, and is checked for validity only if present.
 func (n Name) IsValid() bool {

+ 43 - 0
types/model/name_test.go

@@ -309,6 +309,49 @@ func TestParseDigest(t *testing.T) {
 	}
 }
 
+func TestParseNameFromFilepath(t *testing.T) {
+	cases := map[string]Name{
+		filepath.Join("host", "namespace", "model", "tag"):      {Host: "host", Namespace: "namespace", Model: "model", Tag: "tag"},
+		filepath.Join("host:port", "namespace", "model", "tag"): {Host: "host:port", Namespace: "namespace", Model: "model", Tag: "tag"},
+		filepath.Join("namespace", "model", "tag"):              {},
+		filepath.Join("model", "tag"):                           {},
+		filepath.Join("model"):                                  {},
+		filepath.Join("..", "..", "model", "tag"):               {},
+		filepath.Join("", "namespace", ".", "tag"):              {},
+		filepath.Join(".", ".", ".", "."):                       {},
+		filepath.Join("/", "path", "to", "random", "file"):      {},
+	}
+
+	for in, want := range cases {
+		t.Run(in, func(t *testing.T) {
+			got := ParseNameFromFilepath(in)
+
+			if !reflect.DeepEqual(got, want) {
+				t.Errorf("parseNameFromFilepath(%q) = %v; want %v", in, got, want)
+			}
+		})
+	}
+}
+
+func TestDisplayShortest(t *testing.T) {
+	cases := map[string]string{
+		"registry.ollama.ai/library/model:latest": "model:latest",
+		"registry.ollama.ai/library/model:tag":    "model:tag",
+		"registry.ollama.ai/namespace/model:tag":  "namespace/model:tag",
+		"host/namespace/model:tag":                "host/namespace/model:tag",
+		"host/library/model:tag":                  "host/library/model:tag",
+	}
+
+	for in, want := range cases {
+		t.Run(in, func(t *testing.T) {
+			got := ParseNameBare(in).DisplayShortest()
+			if got != want {
+				t.Errorf("parseName(%q).DisplayShortest() = %q; want %q", in, got, want)
+			}
+		})
+	}
+}
+
 func FuzzName(f *testing.F) {
 	for s := range testCases {
 		f.Add(s)