|
@@ -5,6 +5,7 @@ import (
|
|
|
"bytes"
|
|
|
"encoding/base64"
|
|
|
"encoding/json"
|
|
|
+ "errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"log/slog"
|
|
@@ -14,6 +15,7 @@ import (
|
|
|
"time"
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
+
|
|
|
"github.com/ollama/ollama/api"
|
|
|
"github.com/ollama/ollama/types/model"
|
|
|
)
|
|
@@ -367,24 +369,24 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
|
|
for _, c := range content {
|
|
|
data, ok := c.(map[string]any)
|
|
|
if !ok {
|
|
|
- return nil, fmt.Errorf("invalid message format")
|
|
|
+ return nil, errors.New("invalid message format")
|
|
|
}
|
|
|
switch data["type"] {
|
|
|
case "text":
|
|
|
text, ok := data["text"].(string)
|
|
|
if !ok {
|
|
|
- return nil, fmt.Errorf("invalid message format")
|
|
|
+ return nil, errors.New("invalid message format")
|
|
|
}
|
|
|
messages = append(messages, api.Message{Role: msg.Role, Content: text})
|
|
|
case "image_url":
|
|
|
var url string
|
|
|
if urlMap, ok := data["image_url"].(map[string]any); ok {
|
|
|
if url, ok = urlMap["url"].(string); !ok {
|
|
|
- return nil, fmt.Errorf("invalid message format")
|
|
|
+ return nil, errors.New("invalid message format")
|
|
|
}
|
|
|
} else {
|
|
|
if url, ok = data["image_url"].(string); !ok {
|
|
|
- return nil, fmt.Errorf("invalid message format")
|
|
|
+ return nil, errors.New("invalid message format")
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -400,17 +402,17 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
|
|
}
|
|
|
|
|
|
if !valid {
|
|
|
- return nil, fmt.Errorf("invalid image input")
|
|
|
+ return nil, errors.New("invalid image input")
|
|
|
}
|
|
|
|
|
|
img, err := base64.StdEncoding.DecodeString(url)
|
|
|
if err != nil {
|
|
|
- return nil, fmt.Errorf("invalid message format")
|
|
|
+ return nil, errors.New("invalid message format")
|
|
|
}
|
|
|
|
|
|
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
|
|
|
default:
|
|
|
- return nil, fmt.Errorf("invalid message format")
|
|
|
+ return nil, errors.New("invalid message format")
|
|
|
}
|
|
|
}
|
|
|
default:
|
|
@@ -423,7 +425,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
|
|
toolCalls[i].Function.Name = tc.Function.Name
|
|
|
err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
|
|
|
if err != nil {
|
|
|
- return nil, fmt.Errorf("invalid tool call arguments")
|
|
|
+ return nil, errors.New("invalid tool call arguments")
|
|
|
}
|
|
|
}
|
|
|
messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
|
|
@@ -737,14 +739,12 @@ func (w *RetrieveWriter) Write(data []byte) (int, error) {
|
|
|
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
|
|
|
var embedResponse api.EmbedResponse
|
|
|
err := json.Unmarshal(data, &embedResponse)
|
|
|
-
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
|
|
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
|
err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
|
|
|
-
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|