|
@@ -7,6 +7,7 @@ import (
|
|
|
"encoding/json"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
+ "log/slog"
|
|
|
"math/rand"
|
|
|
"net/http"
|
|
|
"strings"
|
|
@@ -29,8 +30,9 @@ type ErrorResponse struct {
|
|
|
}
|
|
|
|
|
|
type Message struct {
|
|
|
- Role string `json:"role"`
|
|
|
- Content any `json:"content"`
|
|
|
+ Role string `json:"role"`
|
|
|
+ Content any `json:"content"`
|
|
|
+ ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
|
|
}
|
|
|
|
|
|
type Choice struct {
|
|
@@ -78,6 +80,7 @@ type ChatCompletionRequest struct {
|
|
|
PresencePenalty *float64 `json:"presence_penalty_penalty"`
|
|
|
TopP *float64 `json:"top_p"`
|
|
|
ResponseFormat *ResponseFormat `json:"response_format"`
|
|
|
+ Tools []api.Tool `json:"tools"`
|
|
|
}
|
|
|
|
|
|
type ChatCompletion struct {
|
|
@@ -133,6 +136,15 @@ type CompletionChunk struct {
|
|
|
SystemFingerprint string `json:"system_fingerprint"`
|
|
|
}
|
|
|
|
|
|
+type ToolCall struct {
|
|
|
+ ID string `json:"id"`
|
|
|
+ Type string `json:"type"`
|
|
|
+ Function struct {
|
|
|
+ Name string `json:"name"`
|
|
|
+ Arguments string `json:"arguments"`
|
|
|
+ } `json:"function"`
|
|
|
+}
|
|
|
+
|
|
|
type Model struct {
|
|
|
Id string `json:"id"`
|
|
|
Object string `json:"object"`
|
|
@@ -171,7 +183,31 @@ func NewError(code int, message string) ErrorResponse {
|
|
|
return ErrorResponse{Error{Type: etype, Message: message}}
|
|
|
}
|
|
|
|
|
|
+func toolCallId() string {
|
|
|
+ const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
|
|
|
+ b := make([]byte, 8)
|
|
|
+ for i := range b {
|
|
|
+ b[i] = letterBytes[rand.Intn(len(letterBytes))]
|
|
|
+ }
|
|
|
+ return "call_" + strings.ToLower(string(b))
|
|
|
+}
|
|
|
+
|
|
|
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
|
|
+ toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
|
|
|
+ for i, tc := range r.Message.ToolCalls {
|
|
|
+ toolCalls[i].ID = toolCallId()
|
|
|
+ toolCalls[i].Type = "function"
|
|
|
+ toolCalls[i].Function.Name = tc.Function.Name
|
|
|
+
|
|
|
+ args, err := json.Marshal(tc.Function.Arguments)
|
|
|
+ if err != nil {
|
|
|
+ slog.Error("could not marshall function arguments to json", "error", err)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ toolCalls[i].Function.Arguments = string(args)
|
|
|
+ }
|
|
|
+
|
|
|
return ChatCompletion{
|
|
|
Id: id,
|
|
|
Object: "chat.completion",
|
|
@@ -180,7 +216,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
|
|
SystemFingerprint: "fp_ollama",
|
|
|
Choices: []Choice{{
|
|
|
Index: 0,
|
|
|
- Message: Message{Role: r.Message.Role, Content: r.Message.Content},
|
|
|
+ Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
|
|
|
FinishReason: func(reason string) *string {
|
|
|
if len(reason) > 0 {
|
|
|
return &reason
|
|
@@ -366,7 +402,19 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
|
|
}
|
|
|
messages = append(messages, message)
|
|
|
default:
|
|
|
- return nil, fmt.Errorf("invalid message content type: %T", content)
|
|
|
+ if msg.ToolCalls == nil {
|
|
|
+ return nil, fmt.Errorf("invalid message content type: %T", content)
|
|
|
+ }
|
|
|
+
|
|
|
+ toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
|
|
|
+ for i, tc := range msg.ToolCalls {
|
|
|
+ toolCalls[i].Function.Name = tc.Function.Name
|
|
|
+ err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
|
|
|
+ if err != nil {
|
|
|
+ return nil, fmt.Errorf("invalid tool call arguments")
|
|
|
+ }
|
|
|
+ }
|
|
|
+ messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -424,6 +472,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
|
|
Format: format,
|
|
|
Options: options,
|
|
|
Stream: &r.Stream,
|
|
|
+ Tools: r.Tools,
|
|
|
}, nil
|
|
|
}
|
|
|
|