Sfoglia il codice sorgente

Merge pull request #58 from jmorganca/generate-errors

return error in generate response
Michael Yang 1 anno fa
parent
commit
0859d50942
4 ha cambiato i file con 38 aggiunte e 7 eliminazioni
  1. 27 1
      api/client.go
  2. 1 0
      api/types.go
  3. 7 2
      cmd/cmd.go
  4. 3 4
      server/routes.go

+ 27 - 1
api/client.go

@@ -5,6 +5,7 @@ import (
 	"bytes"
 	"bytes"
 	"context"
 	"context"
 	"encoding/json"
 	"encoding/json"
+	"fmt"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
 	"net/url"
 	"net/url"
@@ -25,6 +26,18 @@ func NewClient(hosts ...string) *Client {
 	}
 	}
 }
 }
 
 
+func StatusError(status int, message ...string) error {
+	if status < 400 {
+		return nil
+	}
+
+	if len(message) > 0 && len(message[0]) > 0 {
+		return fmt.Errorf("%d %s: %s", status, http.StatusText(status), message[0])
+	}
+
+	return fmt.Errorf("%d %s", status, http.StatusText(status))
+}
+
 type options struct {
 type options struct {
 	requestBody  io.Reader
 	requestBody  io.Reader
 	responseFunc func(bts []byte) error
 	responseFunc func(bts []byte) error
@@ -70,7 +83,20 @@ func (c *Client) stream(ctx context.Context, method, path string, fns ...func(*o
 	if opts.responseFunc != nil {
 	if opts.responseFunc != nil {
 		scanner := bufio.NewScanner(response.Body)
 		scanner := bufio.NewScanner(response.Body)
 		for scanner.Scan() {
 		for scanner.Scan() {
-			if err := opts.responseFunc(scanner.Bytes()); err != nil {
+			var errorResponse struct {
+				Error string `json:"error"`
+			}
+
+			bts := scanner.Bytes()
+			if err := json.Unmarshal(bts, &errorResponse); err != nil {
+				return err
+			}
+
+			if err := StatusError(response.StatusCode, errorResponse.Error); err != nil {
+				return err
+			}
+
+			if err := opts.responseFunc(bts); err != nil {
 				return err
 				return err
 			}
 			}
 		}
 		}

+ 1 - 0
api/types.go

@@ -15,6 +15,7 @@ func (e Error) Error() string {
 	if e.Message == "" {
 	if e.Message == "" {
 		return fmt.Sprintf("%d %v", e.Code, strings.ToLower(http.StatusText(int(e.Code))))
 		return fmt.Sprintf("%d %v", e.Code, strings.ToLower(http.StatusText(int(e.Code))))
 	}
 	}
+
 	return e.Message
 	return e.Message
 }
 }
 
 

+ 7 - 2
cmd/cmd.go

@@ -100,14 +100,19 @@ func generate(model, prompt string) error {
 			}
 			}
 		}()
 		}()
 
 
-		client.Generate(context.Background(), &api.GenerateRequest{Model: model, Prompt: prompt}, func(resp api.GenerateResponse) error {
+		request := api.GenerateRequest{Model: model, Prompt: prompt}
+		fn := func(resp api.GenerateResponse) error {
 			if !spinner.IsFinished() {
 			if !spinner.IsFinished() {
 				spinner.Finish()
 				spinner.Finish()
 			}
 			}
 
 
 			fmt.Print(resp.Response)
 			fmt.Print(resp.Response)
 			return nil
 			return nil
-		})
+		}
+
+		if err := client.Generate(context.Background(), &request, fn); err != nil {
+			return err
+		}
 
 
 		fmt.Println()
 		fmt.Println()
 		fmt.Println()
 		fmt.Println()

+ 3 - 4
server/routes.go

@@ -4,7 +4,6 @@ import (
 	"embed"
 	"embed"
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
-	"fmt"
 	"io"
 	"io"
 	"log"
 	"log"
 	"math"
 	"math"
@@ -46,7 +45,7 @@ func generate(c *gin.Context) {
 		req.PredictOptions = &api.DefaultPredictOptions
 		req.PredictOptions = &api.DefaultPredictOptions
 	}
 	}
 	if err := c.ShouldBindJSON(&req); err != nil {
 	if err := c.ShouldBindJSON(&req); err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 		return
 	}
 	}
 
 
@@ -66,7 +65,7 @@ func generate(c *gin.Context) {
 
 
 	model, err := llama.New(req.Model, modelOpts)
 	model, err := llama.New(req.Model, modelOpts)
 	if err != nil {
 	if err != nil {
-		fmt.Println("Loading the model failed:", err.Error())
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 		return
 	}
 	}
 	defer model.Free()
 	defer model.Free()
@@ -80,7 +79,7 @@ func generate(c *gin.Context) {
 	if template := templates.Lookup(match); template != nil {
 	if template := templates.Lookup(match); template != nil {
 		var sb strings.Builder
 		var sb strings.Builder
 		if err := template.Execute(&sb, req); err != nil {
 		if err := template.Execute(&sb, req); err != nil {
-			fmt.Println("Prompt template failed:", err.Error())
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
 			return
 		}
 		}