|
@@ -22,6 +22,7 @@ import (
|
|
|
|
|
|
"github.com/gin-contrib/cors"
|
|
|
"github.com/gin-gonic/gin"
|
|
|
+ "golang.org/x/exp/slices"
|
|
|
|
|
|
"github.com/jmorganca/ollama/api"
|
|
|
"github.com/jmorganca/ollama/gpu"
|
|
@@ -136,6 +137,12 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
|
|
|
return opts, nil
|
|
|
}
|
|
|
|
|
|
+func isSupportedImageType(image []byte) bool {
|
|
|
+ contentType := http.DetectContentType(image)
|
|
|
+ allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
|
|
|
+ return slices.Contains(allowedTypes, contentType)
|
|
|
+}
|
|
|
+
|
|
|
func GenerateHandler(c *gin.Context) {
|
|
|
loaded.mu.Lock()
|
|
|
defer loaded.mu.Unlock()
|
|
@@ -166,6 +173,13 @@ func GenerateHandler(c *gin.Context) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
+ for _, img := range req.Images {
|
|
|
+ if !isSupportedImageType(img) {
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
model, err := GetModel(req.Model)
|
|
|
if err != nil {
|
|
|
var pErr *fs.PathError
|
|
@@ -1103,6 +1117,15 @@ func ChatHandler(c *gin.Context) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
+ for _, msg := range req.Messages {
|
|
|
+ for _, img := range msg.Images {
|
|
|
+ if !isSupportedImageType(img) {
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
model, err := GetModel(req.Model)
|
|
|
if err != nil {
|
|
|
var pErr *fs.PathError
|