Jelajahi Sumber

check api status

Michael Yang 1 tahun lalu
induk
melakukan
e243329e2e
3 mengubah file dengan 38 tambahan dan 11 penghapusan
  1. 23 5
      api/client.go
  2. 14 5
      cmd/cmd.go
  3. 1 1
      server/routes.go

+ 23 - 5
api/client.go

@@ -10,6 +10,20 @@ import (
 	"net/url"
 )
 
+type StatusError struct {
+	StatusCode int
+	Status     string
+	Message    string
+}
+
+func (e StatusError) Error() string {
+	if e.Message != "" {
+		return fmt.Sprintf("%s: %s", e.Status, e.Message)
+	}
+
+	return e.Status
+}
+
 type Client struct {
 	base url.URL
 }
@@ -25,7 +39,7 @@ func NewClient(hosts ...string) *Client {
 	}
 }
 
-func (c *Client) stream(ctx context.Context, method, path string, data any, callback func([]byte) error) error {
+func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
 	var buf *bytes.Buffer
 	if data != nil {
 		bts, err := json.Marshal(data)
@@ -53,7 +67,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, call
 	scanner := bufio.NewScanner(response.Body)
 	for scanner.Scan() {
 		var errorResponse struct {
-			Error string `json:"error"`
+			Error string `json:"error,omitempty"`
 		}
 
 		bts := scanner.Bytes()
@@ -61,11 +75,15 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, call
 			return fmt.Errorf("unmarshal: %w", err)
 		}
 
-		if len(errorResponse.Error) > 0 {
-			return fmt.Errorf("stream: %s", errorResponse.Error)
+		if response.StatusCode >= 400 {
+			return StatusError{
+				StatusCode: response.StatusCode,
+				Status:     response.Status,
+				Message:    errorResponse.Error,
+			}
 		}
 
-		if err := callback(bts); err != nil {
+		if err := fn(bts); err != nil {
 			return err
 		}
 	}

+ 14 - 5
cmd/cmd.go

@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"log"
 	"net"
+	"net/http"
 	"os"
 	"path"
 	"strings"
@@ -34,7 +35,14 @@ func RunRun(cmd *cobra.Command, args []string) error {
 	switch {
 	case errors.Is(err, os.ErrNotExist):
 		if err := pull(args[0]); err != nil {
-			return err
+			var apiStatusError api.StatusError
+			if !errors.As(err, &apiStatusError) {
+				return err
+			}
+
+			if apiStatusError.StatusCode != http.StatusBadGateway {
+				return err
+			}
 		}
 	case err != nil:
 		return err
@@ -50,11 +58,12 @@ func pull(model string) error {
 		context.Background(),
 		&api.PullRequest{Model: model},
 		func(progress api.PullProgress) error {
-			if bar == nil && progress.Percent == 100 {
-				// already downloaded
-				return nil
-			}
 			if bar == nil {
+				if progress.Percent == 100 {
+					// already downloaded
+					return nil
+				}
+
 				bar = progressbar.DefaultBytes(progress.Total)
 			}
 

+ 1 - 1
server/routes.go

@@ -108,7 +108,7 @@ func pull(c *gin.Context) {
 
 	remote, err := getRemote(req.Model)
 	if err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
 		return
 	}