Browse Source

Separate ListResponse and ModelResponse for api/tags vs api/ps (#4842)

* Remove false time fields

* Struct Separation for List and Process

* Remove Marshaler
royjhan 10 tháng trước cách đây
mục cha
commit
4bf1da4944
4 tập tin đã thay đổi với 30 bổ sung14 xóa
  1. 2 2
      api/client.go
  2. 20 6
      api/types.go
  3. 6 6
      server/routes.go
  4. 2 0
      server/routes_test.go

+ 2 - 2
api/client.go

@@ -355,8 +355,8 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) {
 }
 
 // List running models.
-func (c *Client) ListRunning(ctx context.Context) (*ListResponse, error) {
-	var lr ListResponse
+func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) {
+	var lr ProcessResponse
 	if err := c.do(ctx, http.MethodGet, "/api/ps", nil, &lr); err != nil {
 		return nil, err
 	}

+ 20 - 6
api/types.go

@@ -282,19 +282,33 @@ type PushRequest struct {
 
 // ListResponse is the response from [Client.List].
 type ListResponse struct {
-	Models []ModelResponse `json:"models"`
+	Models []ListModelResponse `json:"models"`
 }
 
-// ModelResponse is a single model description in [ListResponse].
-type ModelResponse struct {
+// ProcessResponse is the response from [Client.Process].
+type ProcessResponse struct {
+	Models []ProcessModelResponse `json:"models"`
+}
+
+// ListModelResponse is a single model description in [ListResponse].
+type ListModelResponse struct {
 	Name       string       `json:"name"`
 	Model      string       `json:"model"`
-	ModifiedAt time.Time    `json:"modified_at,omitempty"`
+	ModifiedAt time.Time    `json:"modified_at"`
 	Size       int64        `json:"size"`
 	Digest     string       `json:"digest"`
 	Details    ModelDetails `json:"details,omitempty"`
-	ExpiresAt  time.Time    `json:"expires_at,omitempty"`
-	SizeVRAM   int64        `json:"size_vram,omitempty"`
+}
+
+// ProcessModelResponse is a single model description in [ProcessResponse].
+type ProcessModelResponse struct {
+	Name      string       `json:"name"`
+	Model     string       `json:"model"`
+	Size      int64        `json:"size"`
+	Digest    string       `json:"digest"`
+	Details   ModelDetails `json:"details,omitempty"`
+	ExpiresAt time.Time    `json:"expires_at"`
+	SizeVRAM  int64        `json:"size_vram"`
 }
 
 type TokenResponse struct {

+ 6 - 6
server/routes.go

@@ -730,7 +730,7 @@ func (s *Server) ListModelsHandler(c *gin.Context) {
 		return
 	}
 
-	models := []api.ModelResponse{}
+	models := []api.ListModelResponse{}
 	for n, m := range ms {
 		f, err := m.Config.Open()
 		if err != nil {
@@ -746,7 +746,7 @@ func (s *Server) ListModelsHandler(c *gin.Context) {
 		}
 
 		// tag should never be masked
-		models = append(models, api.ModelResponse{
+		models = append(models, api.ListModelResponse{
 			Model:      n.DisplayShortest(),
 			Name:       n.DisplayShortest(),
 			Size:       m.Size(),
@@ -762,7 +762,7 @@ func (s *Server) ListModelsHandler(c *gin.Context) {
 		})
 	}
 
-	slices.SortStableFunc(models, func(i, j api.ModelResponse) int {
+	slices.SortStableFunc(models, func(i, j api.ListModelResponse) int {
 		// most recently modified first
 		return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix())
 	})
@@ -1139,7 +1139,7 @@ func streamResponse(c *gin.Context, ch chan any) {
 }
 
 func (s *Server) ProcessHandler(c *gin.Context) {
-	models := []api.ModelResponse{}
+	models := []api.ProcessModelResponse{}
 
 	for _, v := range s.sched.loaded {
 		model := v.model
@@ -1151,7 +1151,7 @@ func (s *Server) ProcessHandler(c *gin.Context) {
 			QuantizationLevel: model.Config.FileType,
 		}
 
-		mr := api.ModelResponse{
+		mr := api.ProcessModelResponse{
 			Model:     model.ShortName,
 			Name:      model.ShortName,
 			Size:      int64(v.estimatedTotal),
@@ -1171,7 +1171,7 @@ func (s *Server) ProcessHandler(c *gin.Context) {
 		models = append(models, mr)
 	}
 
-	c.JSON(http.StatusOK, api.ListResponse{Models: models})
+	c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
 }
 
 // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model

+ 2 - 0
server/routes_test.go

@@ -116,6 +116,8 @@ func Test_Routes(t *testing.T) {
 				body, err := io.ReadAll(resp.Body)
 				require.NoError(t, err)
 
+				assert.NotContains(t, string(body), "expires_at")
+
 				var modelList api.ListResponse
 				err = json.Unmarshal(body, &modelList)
 				require.NoError(t, err)