浏览代码

Support image input for OpenAI chat compatibility (#5208)

* OpenAI v1 models

* Refactor Writers

* Add Test

Co-Authored-By: Attila Kerekes

* Credit Co-Author

Co-Authored-By: Attila Kerekes <439392+keriati@users.noreply.github.com>

* Empty List Testing

* Use Namespace for Ownedby

* Update Test

* Add back envconfig

* v1/models docs

* Use ModelName Parser

* Test Names

* Remove Docs

* Clean Up

* Test name

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>

* Add Middleware for Chat and List

* Testing Cleanup

* Test with Fatal

* Add functionality to chat test

* Support image input for OpenAI chat

* Decoding

* Fix message processing logic

* openai vision test

* type errors

* clean up

* redundant check

* merge conflicts

* merge conflicts

* merge conflicts

* flattening and smaller image

* add test

* support python and js SDKs and mandate prefixing

* clean up

---------

Co-authored-by: Attila Kerekes <439392+keriati@users.noreply.github.com>
Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
royjhan 9 月之前
父节点
当前提交
e9f7f36029
共有 2 个文件被更改,包括 119 次插入6 次删除
  1. 70 6
      openai/openai.go
  2. 49 0
      openai/openai_test.go

+ 70 - 6
openai/openai.go

@@ -3,11 +3,13 @@ package openai
 
 import (
 	"bytes"
+	"encoding/base64"
 	"encoding/json"
 	"fmt"
 	"io"
 	"math/rand"
 	"net/http"
+	"strings"
 	"time"
 
 	"github.com/gin-gonic/gin"
@@ -28,7 +30,7 @@ type ErrorResponse struct {
 
 type Message struct {
 	Role    string `json:"role"`
-	Content string `json:"content"`
+	Content any    `json:"content"`
 }
 
 type Choice struct {
@@ -269,10 +271,66 @@ func toModel(r api.ShowResponse, m string) Model {
 	}
 }
 
-func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
+func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 	var messages []api.Message
 	for _, msg := range r.Messages {
-		messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
+		switch content := msg.Content.(type) {
+		case string:
+			messages = append(messages, api.Message{Role: msg.Role, Content: content})
+		case []any:
+			message := api.Message{Role: msg.Role}
+			for _, c := range content {
+				data, ok := c.(map[string]any)
+				if !ok {
+					return nil, fmt.Errorf("invalid message format")
+				}
+				switch data["type"] {
+				case "text":
+					text, ok := data["text"].(string)
+					if !ok {
+						return nil, fmt.Errorf("invalid message format")
+					}
+					message.Content = text
+				case "image_url":
+					var url string
+					if urlMap, ok := data["image_url"].(map[string]any); ok {
+						if url, ok = urlMap["url"].(string); !ok {
+							return nil, fmt.Errorf("invalid message format")
+						}
+					} else {
+						if url, ok = data["image_url"].(string); !ok {
+							return nil, fmt.Errorf("invalid message format")
+						}
+					}
+
+					types := []string{"jpeg", "jpg", "png"}
+					valid := false
+					for _, t := range types {
+						prefix := "data:image/" + t + ";base64,"
+						if strings.HasPrefix(url, prefix) {
+							url = strings.TrimPrefix(url, prefix)
+							valid = true
+							break
+						}
+					}
+
+					if !valid {
+						return nil, fmt.Errorf("invalid image input")
+					}
+
+					img, err := base64.StdEncoding.DecodeString(url)
+					if err != nil {
+						return nil, fmt.Errorf("invalid message format")
+					}
+					message.Images = append(message.Images, img)
+				default:
+					return nil, fmt.Errorf("invalid message format")
+				}
+			}
+			messages = append(messages, message)
+		default:
+			return nil, fmt.Errorf("invalid message content type: %T", content)
+		}
 	}
 
 	options := make(map[string]interface{})
@@ -323,13 +381,13 @@ func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
 		format = "json"
 	}
 
-	return api.ChatRequest{
+	return &api.ChatRequest{
 		Model:    r.Model,
 		Messages: messages,
 		Format:   format,
 		Options:  options,
 		Stream:   &r.Stream,
-	}
+	}, nil
 }
 
 func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
@@ -656,7 +714,13 @@ func ChatMiddleware() gin.HandlerFunc {
 		}
 
 		var b bytes.Buffer
-		if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil {
+
+		chatReq, err := fromChatRequest(req)
+		if err != nil {
+			// fromChatRequest returns a nil request on error. AbortWithStatusJSON
+			// writes the 400 response and prevents later handlers from running,
+			// but it does not stop this function, so return explicitly instead of
+			// falling through and encoding the nil request below.
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
 			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
 			return
 		}

+ 49 - 0
openai/openai_test.go

@@ -2,6 +2,7 @@ package openai
 
 import (
 	"bytes"
+	"encoding/base64"
 	"encoding/json"
 	"io"
 	"net/http"
@@ -15,6 +16,10 @@ import (
 	"github.com/stretchr/testify/assert"
 )
 
+const prefix = `data:image/jpeg;base64,`
+const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
+const imageURL = prefix + image
+
 func TestMiddlewareRequests(t *testing.T) {
 	type testCase struct {
 		Name     string
@@ -112,6 +117,50 @@ func TestMiddlewareRequests(t *testing.T) {
 				}
 			},
 		},
+		{
+			Name:    "chat handler with image content",
+			Method:  http.MethodPost,
+			Path:    "/api/chat",
+			Handler: ChatMiddleware,
+			Setup: func(t *testing.T, req *http.Request) {
+				body := ChatCompletionRequest{
+					Model: "test-model",
+					Messages: []Message{
+						{
+							Role: "user", Content: []map[string]any{
+								{"type": "text", "text": "Hello"},
+								{"type": "image_url", "image_url": map[string]string{"url": imageURL}},
+							},
+						},
+					},
+				}
+
+				bodyBytes, _ := json.Marshal(body)
+
+				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+				req.Header.Set("Content-Type", "application/json")
+			},
+			Expected: func(t *testing.T, req *http.Request) {
+				var chatReq api.ChatRequest
+				if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
+					t.Fatal(err)
+				}
+
+				if chatReq.Messages[0].Role != "user" {
+					t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
+				}
+
+				if chatReq.Messages[0].Content != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
+				}
+
+				img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
+
+				if !bytes.Equal(chatReq.Messages[0].Images[0], img) {
+					t.Fatalf("expected image encoding, got %s", chatReq.Messages[0].Images[0])
+				}
+			},
+		},
 	}
 
 	gin.SetMode(gin.TestMode)