|
@@ -3,11 +3,13 @@ package openai
|
|
|
|
|
|
import (
|
|
import (
|
|
"bytes"
|
|
"bytes"
|
|
|
|
+ "encoding/base64"
|
|
"encoding/json"
|
|
"encoding/json"
|
|
"fmt"
|
|
"fmt"
|
|
"io"
|
|
"io"
|
|
"math/rand"
|
|
"math/rand"
|
|
"net/http"
|
|
"net/http"
|
|
|
|
+ "strings"
|
|
"time"
|
|
"time"
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/gin-gonic/gin"
|
|
@@ -28,7 +30,7 @@ type ErrorResponse struct {
|
|
|
|
|
|
type Message struct {
|
|
type Message struct {
|
|
Role string `json:"role"`
|
|
Role string `json:"role"`
|
|
- Content string `json:"content"`
|
|
|
|
|
|
+ Content any `json:"content"`
|
|
}
|
|
}
|
|
|
|
|
|
type Choice struct {
|
|
type Choice struct {
|
|
@@ -269,10 +271,66 @@ func toModel(r api.ShowResponse, m string) Model {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
-func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
|
|
|
|
|
|
+func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
|
var messages []api.Message
|
|
var messages []api.Message
|
|
for _, msg := range r.Messages {
|
|
for _, msg := range r.Messages {
|
|
- messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
|
|
|
|
|
|
+ switch content := msg.Content.(type) {
|
|
|
|
+ case string:
|
|
|
|
+ messages = append(messages, api.Message{Role: msg.Role, Content: content})
|
|
|
|
+ case []any:
|
|
|
|
+ message := api.Message{Role: msg.Role}
|
|
|
|
+ for _, c := range content {
|
|
|
|
+ data, ok := c.(map[string]any)
|
|
|
|
+ if !ok {
|
|
|
|
+ return nil, fmt.Errorf("invalid message format")
|
|
|
|
+ }
|
|
|
|
+ switch data["type"] {
|
|
|
|
+ case "text":
|
|
|
|
+ text, ok := data["text"].(string)
|
|
|
|
+ if !ok {
|
|
|
|
+ return nil, fmt.Errorf("invalid message format")
|
|
|
|
+ }
|
|
|
|
+ message.Content = text
|
|
|
|
+ case "image_url":
|
|
|
|
+ var url string
|
|
|
|
+ if urlMap, ok := data["image_url"].(map[string]any); ok {
|
|
|
|
+ if url, ok = urlMap["url"].(string); !ok {
|
|
|
|
+ return nil, fmt.Errorf("invalid message format")
|
|
|
|
+ }
|
|
|
|
+ } else {
|
|
|
|
+ if url, ok = data["image_url"].(string); !ok {
|
|
|
|
+ return nil, fmt.Errorf("invalid message format")
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ types := []string{"jpeg", "jpg", "png"}
|
|
|
|
+ valid := false
|
|
|
|
+ for _, t := range types {
|
|
|
|
+ prefix := "data:image/" + t + ";base64,"
|
|
|
|
+ if strings.HasPrefix(url, prefix) {
|
|
|
|
+ url = strings.TrimPrefix(url, prefix)
|
|
|
|
+ valid = true
|
|
|
|
+ break
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if !valid {
|
|
|
|
+ return nil, fmt.Errorf("invalid image input")
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ img, err := base64.StdEncoding.DecodeString(url)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return nil, fmt.Errorf("invalid message format")
|
|
|
|
+ }
|
|
|
|
+ message.Images = append(message.Images, img)
|
|
|
|
+ default:
|
|
|
|
+ return nil, fmt.Errorf("invalid message format")
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ messages = append(messages, message)
|
|
|
|
+ default:
|
|
|
|
+ return nil, fmt.Errorf("invalid message content type: %T", content)
|
|
|
|
+ }
|
|
}
|
|
}
|
|
|
|
|
|
options := make(map[string]interface{})
|
|
options := make(map[string]interface{})
|
|
@@ -323,13 +381,13 @@ func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
|
|
format = "json"
|
|
format = "json"
|
|
}
|
|
}
|
|
|
|
|
|
- return api.ChatRequest{
|
|
|
|
|
|
+ return &api.ChatRequest{
|
|
Model: r.Model,
|
|
Model: r.Model,
|
|
Messages: messages,
|
|
Messages: messages,
|
|
Format: format,
|
|
Format: format,
|
|
Options: options,
|
|
Options: options,
|
|
Stream: &r.Stream,
|
|
Stream: &r.Stream,
|
|
- }
|
|
|
|
|
|
+ }, nil
|
|
}
|
|
}
|
|
|
|
|
|
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
|
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
|
@@ -656,7 +714,13 @@ func ChatMiddleware() gin.HandlerFunc {
|
|
}
|
|
}
|
|
|
|
|
|
var b bytes.Buffer
|
|
var b bytes.Buffer
|
|
- if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil {
|
|
|
|
|
|
+
|
|
|
|
+ chatReq, err := fromChatRequest(req)
|
|
|
|
+ if err != nil {
|
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
|
return
|
|
return
|
|
}
|
|
}
|