Browse Source

get api models

Michael Yang 8 months ago
parent
commit
2fe945412a
2 changed files with 153 additions and 1 deletions
  1. 3 1
      server/manifest.go
  2. 150 0
      server/routes.go

+ 3 - 1
server/manifest.go

@@ -19,6 +19,7 @@ type Manifest struct {
 	Config        *Layer   `json:"config"`
 	Layers        []*Layer `json:"layers"`
 
+	name     model.Name
 	filepath string
 	fi       os.FileInfo
 	digest   string
@@ -69,7 +70,6 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
 
 	p := filepath.Join(manifests, n.Filepath())
 
-	var m Manifest
 	f, err := os.Open(p)
 	if err != nil {
 		return nil, err
@@ -81,11 +81,13 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
 		return nil, err
 	}
 
+	var m Manifest
 	sha256sum := sha256.New()
 	if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil {
 		return nil, err
 	}
 
+	m.name = n
 	m.filepath = p
 	m.fi = fi
 	m.digest = hex.EncodeToString(sha256sum.Sum(nil))

+ 150 - 0
server/routes.go

@@ -703,6 +703,153 @@ func (s *Server) ShowModelHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, resp)
 }
 
+func manifestLayers(m *Manifest, exclude []string) (map[string]any, error) {
+	r := map[string]any{
+		"name":        m.name.DisplayShortest(),
+		"digest":      m.digest,
+		"size":        m.Size(),
+		"modified_at": m.fi.ModTime(),
+	}
+
+	excludeAll := slices.Contains(exclude, "all")
+	excludeDetails := slices.Contains(exclude, "details")
+
+	for _, layer := range m.Layers {
+		var errExcludeKey = errors.New("exclude key")
+		key, content, err := func() (string, any, error) {
+			key := strings.TrimPrefix(layer.MediaType, "application/vnd.ollama.image.")
+			if slices.Contains(exclude, key) || excludeAll {
+				return "", nil, errExcludeKey
+			}
+
+			f, err := layer.Open()
+			if err != nil {
+				return "", nil, err
+			}
+			defer f.Close()
+
+			switch key {
+			case "model", "projector", "adapter":
+				ggml, _, err := llm.DecodeGGML(f, 0)
+				if err != nil {
+					return "", nil, err
+				}
+
+				content := map[string]any{
+					"architecture":    ggml.KV().Architecture(),
+					"file_type":       ggml.KV().FileType().String(),
+					"parameter_count": ggml.KV().ParameterCount(),
+				}
+
+				if !slices.Contains(exclude, key+".details") && !excludeAll && !excludeDetails {
+					// exclude any extraneous or redundant fields
+					delete(ggml.KV(), "general.basename")
+					delete(ggml.KV(), "general.description")
+					delete(ggml.KV(), "general.filename")
+					delete(ggml.KV(), "general.finetune")
+					delete(ggml.KV(), "general.languages")
+					delete(ggml.KV(), "general.license")
+					delete(ggml.KV(), "general.license.link")
+					delete(ggml.KV(), "general.name")
+					delete(ggml.KV(), "general.paramter_count")
+					delete(ggml.KV(), "general.size_label")
+					delete(ggml.KV(), "general.tags")
+					delete(ggml.KV(), "general.type")
+					delete(ggml.KV(), "general.quantization_version")
+					delete(ggml.KV(), "tokenizer.chat_template")
+					content["details"] = ggml.KV()
+				}
+
+				return key, content, nil
+			case "params", "messages":
+				var content any
+				if err := json.NewDecoder(f).Decode(&content); err != nil {
+					return "", nil, err
+				}
+
+				return key, content, nil
+			case "template", "system", "license":
+				bts, err := io.ReadAll(f)
+				if err != nil {
+					return "", nil, err
+				}
+
+				if key == "license" {
+					return key, []any{string(bts)}, nil
+				}
+
+				return key, string(bts), nil
+			}
+
+			return layer.MediaType, nil, nil
+		}()
+		if errors.Is(err, errExcludeKey) {
+			continue
+		} else if err != nil {
+			return nil, err
+		}
+
+		if s, ok := r[key].([]any); ok {
+			r[key] = append(s, content)
+		} else {
+			r[key] = content
+		}
+	}
+
+	return r, nil
+}
+
+func (s *Server) GetModelsHandler(c *gin.Context) {
+	ms, err := Manifests()
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	var rs []map[string]any
+	for _, m := range ms {
+		r, err := manifestLayers(m, c.QueryArray("exclude"))
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		rs = append(rs, r)
+	}
+
+	slices.SortStableFunc(rs, func(i, j map[string]any) int {
+		// most recently modified first
+		return cmp.Compare(
+			j["modified_at"].(time.Time).Unix(),
+			i["modified_at"].(time.Time).Unix(),
+		)
+	})
+
+	c.JSON(http.StatusOK, rs)
+}
+
+func (s *Server) GetModelHandler(c *gin.Context) {
+	n := model.ParseName(strings.TrimPrefix(c.Param("model"), "/"))
+	if !n.IsValid() {
+		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
+		return
+	}
+
+	m, err := ParseNamedManifest(n)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	r, err := manifestLayers(m, c.QueryArray("exclude"))
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	c.JSON(http.StatusOK, r)
+}
+
 func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	m, err := GetModel(req.Model)
 	if err != nil {
@@ -1090,6 +1237,9 @@ func (s *Server) GenerateRoutes() http.Handler {
 			c.String(http.StatusOK, "Ollama is running")
 		})
 
+		r.Handle(method, "/api/models", s.GetModelsHandler)
+		r.Handle(method, "/api/models/*model", s.GetModelHandler)
+
 		r.Handle(method, "/api/tags", s.ListModelsHandler)
 		r.Handle(method, "/api/version", func(c *gin.Context) {
 			c.JSON(http.StatusOK, gin.H{"version": version.Version})