瀏覽代碼

Check image filetype in api handlers (#2467)

Jeffrey Morgan 1 年之前
父節點
當前提交
1f9078d6ae
共有 2 個文件被更改,包括 24 次插入1 次删除
  1. 1 1
      cmd/interactive.go
  2. 23 0
      server/routes.go

+ 1 - 1
cmd/interactive.go

@@ -625,7 +625,7 @@ func getImageData(filePath string) ([]byte, error) {
 	}
 
 	contentType := http.DetectContentType(buf)
-	allowedTypes := []string{"image/jpeg", "image/jpg", "image/svg+xml", "image/png"}
+	allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
 	if !slices.Contains(allowedTypes, contentType) {
 		return nil, fmt.Errorf("invalid image type: %s", contentType)
 	}

+ 23 - 0
server/routes.go

@@ -22,6 +22,7 @@ import (
 
 	"github.com/gin-contrib/cors"
 	"github.com/gin-gonic/gin"
+	"golang.org/x/exp/slices"
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/gpu"
@@ -136,6 +137,12 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
 	return opts, nil
 }
 
+func isSupportedImageType(image []byte) bool {
+	contentType := http.DetectContentType(image)
+	allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
+	return slices.Contains(allowedTypes, contentType)
+}
+
 func GenerateHandler(c *gin.Context) {
 	loaded.mu.Lock()
 	defer loaded.mu.Unlock()
@@ -166,6 +173,13 @@ func GenerateHandler(c *gin.Context) {
 		return
 	}
 
+	for _, img := range req.Images {
+		if !isSupportedImageType(img) {
+			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
+			return
+		}
+	}
+
 	model, err := GetModel(req.Model)
 	if err != nil {
 		var pErr *fs.PathError
@@ -1103,6 +1117,15 @@ func ChatHandler(c *gin.Context) {
 		return
 	}
 
+	for _, msg := range req.Messages {
+		for _, img := range msg.Images {
+			if !isSupportedImageType(img) {
+				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
+				return
+			}
+		}
+	}
+
 	model, err := GetModel(req.Model)
 	if err != nil {
 		var pErr *fs.PathError