浏览代码

Merge pull request #843 from jmorganca/mxyng/request-validation

basic request validation
Michael Yang 1 年之前
父节点
当前提交
0a53da03fd
共有 4 个文件被更改,包括 109 次插入30 次删除
  1. 3 3
      examples/golang-simplegenerate/main.go
  2. 1 1
      format/time_test.go
  3. 1 1
      server/images.go
  4. 104 25
      server/routes.go

+ 3 - 3
examples/golang-simplegenerate/main.go

@@ -3,10 +3,10 @@ package main
 import (
 import (
 	"bytes"
 	"bytes"
 	"fmt"
 	"fmt"
-	"net/http"
-	"os"
 	"io"
 	"io"
 	"log"
 	"log"
+	"net/http"
+	"os"
 )
 )
 
 
 func main() {
 func main() {
@@ -16,7 +16,7 @@ func main() {
 	if err != nil {
 	if err != nil {
 		fmt.Print(err.Error())
 		fmt.Print(err.Error())
 		os.Exit(1)
 		os.Exit(1)
-	} 
+	}
 
 
 	responseData, err := io.ReadAll(resp.Body)
 	responseData, err := io.ReadAll(resp.Body)
 	if err != nil {
 	if err != nil {

+ 1 - 1
format/time_test.go

@@ -29,7 +29,7 @@ func TestHumanTime(t *testing.T) {
 	})
 	})
 
 
 	t.Run("soon", func(t *testing.T) {
 	t.Run("soon", func(t *testing.T) {
-		v := now.Add(800*time.Millisecond)
+		v := now.Add(800 * time.Millisecond)
 		assertEqual(t, HumanTime(v, ""), "Less than a second from now")
 		assertEqual(t, HumanTime(v, ""), "Less than a second from now")
 	})
 	})
 }
 }

+ 1 - 1
server/images.go

@@ -252,7 +252,7 @@ func filenameWithPath(path, f string) (string, error) {
 	return f, nil
 	return f, nil
 }
 }
 
 
-func CreateModel(ctx context.Context, workDir, name string, path string, fn func(resp api.ProgressResponse)) error {
+func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error {
 	mp := ParseModelPath(name)
 	mp := ParseModelPath(name)
 
 
 	var manifest *ManifestV2
 	var manifest *ManifestV2

+ 104 - 25
server/routes.go

@@ -137,8 +137,18 @@ func GenerateHandler(c *gin.Context) {
 	checkpointStart := time.Now()
 	checkpointStart := time.Now()
 
 
 	var req api.GenerateRequest
 	var req api.GenerateRequest
-	if err := c.ShouldBindJSON(&req); err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+	err := c.ShouldBindJSON(&req)
+	switch {
+	case errors.Is(err, io.EOF):
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	case err != nil:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	if req.Model == "" {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
 		return
 		return
 	}
 	}
 
 
@@ -177,6 +187,12 @@ func GenerateHandler(c *gin.Context) {
 	ch := make(chan any)
 	ch := make(chan any)
 	go func() {
 	go func() {
 		defer close(ch)
 		defer close(ch)
+		// an empty request loads the model
+		if req.Prompt == "" && req.Template == "" && req.System == "" {
+			ch <- api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}
+			return
+		}
+
 		fn := func(r api.GenerateResponse) {
 		fn := func(r api.GenerateResponse) {
 			loaded.expireAt = time.Now().Add(sessionDuration)
 			loaded.expireAt = time.Now().Add(sessionDuration)
 			loaded.expireTimer.Reset(sessionDuration)
 			loaded.expireTimer.Reset(sessionDuration)
@@ -191,13 +207,8 @@ func GenerateHandler(c *gin.Context) {
 			ch <- r
 			ch <- r
 		}
 		}
 
 
-		// an empty request loads the model
-		if req.Prompt == "" && req.Template == "" && req.System == "" {
-			ch <- api.GenerateResponse{Model: req.Model, Done: true}
-		} else {
-			if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil {
-				ch <- gin.H{"error": err.Error()}
-			}
+		if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil {
+			ch <- gin.H{"error": err.Error()}
 		}
 		}
 	}()
 	}()
 
 
@@ -226,8 +237,18 @@ func EmbeddingHandler(c *gin.Context) {
 	defer loaded.mu.Unlock()
 	defer loaded.mu.Unlock()
 
 
 	var req api.EmbeddingRequest
 	var req api.EmbeddingRequest
-	if err := c.ShouldBindJSON(&req); err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+	err := c.ShouldBindJSON(&req)
+	switch {
+	case errors.Is(err, io.EOF):
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	case err != nil:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	if req.Model == "" {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
 		return
 		return
 	}
 	}
 
 
@@ -263,8 +284,18 @@ func EmbeddingHandler(c *gin.Context) {
 
 
 func PullModelHandler(c *gin.Context) {
 func PullModelHandler(c *gin.Context) {
 	var req api.PullRequest
 	var req api.PullRequest
-	if err := c.ShouldBindJSON(&req); err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+	err := c.ShouldBindJSON(&req)
+	switch {
+	case errors.Is(err, io.EOF):
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	case err != nil:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	if req.Name == "" {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
 		return
 		return
 	}
 	}
 
 
@@ -297,8 +328,18 @@ func PullModelHandler(c *gin.Context) {
 
 
 func PushModelHandler(c *gin.Context) {
 func PushModelHandler(c *gin.Context) {
 	var req api.PushRequest
 	var req api.PushRequest
-	if err := c.ShouldBindJSON(&req); err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+	err := c.ShouldBindJSON(&req)
+	switch {
+	case errors.Is(err, io.EOF):
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	case err != nil:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	if req.Name == "" {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
 		return
 		return
 	}
 	}
 
 
@@ -329,12 +370,20 @@ func PushModelHandler(c *gin.Context) {
 
 
 func CreateModelHandler(c *gin.Context) {
 func CreateModelHandler(c *gin.Context) {
 	var req api.CreateRequest
 	var req api.CreateRequest
-	if err := c.ShouldBindJSON(&req); err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
+	err := c.ShouldBindJSON(&req)
+	switch {
+	case errors.Is(err, io.EOF):
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	case err != nil:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 		return
 	}
 	}
 
 
-	workDir := c.GetString("workDir")
+	if req.Name == "" || req.Path == "" {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name and path are required"})
+		return
+	}
 
 
 	ch := make(chan any)
 	ch := make(chan any)
 	go func() {
 	go func() {
@@ -346,7 +395,7 @@ func CreateModelHandler(c *gin.Context) {
 		ctx, cancel := context.WithCancel(c.Request.Context())
 		ctx, cancel := context.WithCancel(c.Request.Context())
 		defer cancel()
 		defer cancel()
 
 
-		if err := CreateModel(ctx, workDir, req.Name, req.Path, fn); err != nil {
+		if err := CreateModel(ctx, req.Name, req.Path, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 			ch <- gin.H{"error": err.Error()}
 		}
 		}
 	}()
 	}()
@@ -361,8 +410,18 @@ func CreateModelHandler(c *gin.Context) {
 
 
 func DeleteModelHandler(c *gin.Context) {
 func DeleteModelHandler(c *gin.Context) {
 	var req api.DeleteRequest
 	var req api.DeleteRequest
-	if err := c.ShouldBindJSON(&req); err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+	err := c.ShouldBindJSON(&req)
+	switch {
+	case errors.Is(err, io.EOF):
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	case err != nil:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	if req.Name == "" {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
 		return
 		return
 	}
 	}
 
 
@@ -391,8 +450,18 @@ func DeleteModelHandler(c *gin.Context) {
 
 
 func ShowModelHandler(c *gin.Context) {
 func ShowModelHandler(c *gin.Context) {
 	var req api.ShowRequest
 	var req api.ShowRequest
-	if err := c.ShouldBindJSON(&req); err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+	err := c.ShouldBindJSON(&req)
+	switch {
+	case errors.Is(err, io.EOF):
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	case err != nil:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	if req.Name == "" {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
 		return
 		return
 	}
 	}
 
 
@@ -502,8 +571,18 @@ func ListModelsHandler(c *gin.Context) {
 
 
 func CopyModelHandler(c *gin.Context) {
 func CopyModelHandler(c *gin.Context) {
 	var req api.CopyRequest
 	var req api.CopyRequest
-	if err := c.ShouldBindJSON(&req); err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+	err := c.ShouldBindJSON(&req)
+	switch {
+	case errors.Is(err, io.EOF):
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	case err != nil:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	if req.Source == "" || req.Destination == "" {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "source add destination are required"})
 		return
 		return
 	}
 	}