Pārlūkot izejas kodu

OpenAI: Support Tools (#5614)

* reopen pr

* tools

* remove tc from stream for now

* ID and Function

* openai expects arguments to be a string (#5739)

* mutually exclusive content and tool calls

* clean up

---------

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
royjhan 9 mēneši atpakaļ
vecāks
revīzija
154f6f45d4
1 mainītis faili ar 53 papildinājumiem un 4 dzēšanām
  1. 53 4
      openai/openai.go

+ 53 - 4
openai/openai.go

@@ -7,6 +7,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"log/slog"
 	"math/rand"
 	"net/http"
 	"strings"
@@ -29,8 +30,9 @@ type ErrorResponse struct {
 }
 
 type Message struct {
-	Role    string `json:"role"`
-	Content any    `json:"content"`
+	Role      string     `json:"role"`
+	Content   any        `json:"content"`
+	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
 }
 
 type Choice struct {
@@ -78,6 +80,7 @@ type ChatCompletionRequest struct {
 	PresencePenalty  *float64        `json:"presence_penalty_penalty"`
 	TopP             *float64        `json:"top_p"`
 	ResponseFormat   *ResponseFormat `json:"response_format"`
+	Tools            []api.Tool      `json:"tools"`
 }
 
 type ChatCompletion struct {
@@ -133,6 +136,15 @@ type CompletionChunk struct {
 	SystemFingerprint string                `json:"system_fingerprint"`
 }
 
+type ToolCall struct {
+	ID       string `json:"id"`
+	Type     string `json:"type"`
+	Function struct {
+		Name      string `json:"name"`
+		Arguments string `json:"arguments"`
+	} `json:"function"`
+}
+
 type Model struct {
 	Id      string `json:"id"`
 	Object  string `json:"object"`
@@ -171,7 +183,31 @@ func NewError(code int, message string) ErrorResponse {
 	return ErrorResponse{Error{Type: etype, Message: message}}
 }
 
+func toolCallId() string {
+	const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
+	b := make([]byte, 8)
+	for i := range b {
+		b[i] = letterBytes[rand.Intn(len(letterBytes))]
+	}
+	return "call_" + strings.ToLower(string(b))
+}
+
 func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
+	toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
+	for i, tc := range r.Message.ToolCalls {
+		toolCalls[i].ID = toolCallId()
+		toolCalls[i].Type = "function"
+		toolCalls[i].Function.Name = tc.Function.Name
+
+		args, err := json.Marshal(tc.Function.Arguments)
+		if err != nil {
+			slog.Error("could not marshall function arguments to json", "error", err)
+			continue
+		}
+
+		toolCalls[i].Function.Arguments = string(args)
+	}
+
 	return ChatCompletion{
 		Id:                id,
 		Object:            "chat.completion",
@@ -180,7 +216,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 		SystemFingerprint: "fp_ollama",
 		Choices: []Choice{{
 			Index:   0,
-			Message: Message{Role: r.Message.Role, Content: r.Message.Content},
+			Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
 			FinishReason: func(reason string) *string {
 				if len(reason) > 0 {
 					return &reason
@@ -366,7 +402,19 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 			}
 			messages = append(messages, message)
 		default:
-			return nil, fmt.Errorf("invalid message content type: %T", content)
+			if msg.ToolCalls == nil {
+				return nil, fmt.Errorf("invalid message content type: %T", content)
+			}
+
+			toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
+			for i, tc := range msg.ToolCalls {
+				toolCalls[i].Function.Name = tc.Function.Name
+				err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
+				if err != nil {
+					return nil, fmt.Errorf("invalid tool call arguments")
+				}
+			}
+			messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
 		}
 	}
 
@@ -424,6 +472,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 		Format:   format,
 		Options:  options,
 		Stream:   &r.Stream,
+		Tools:    r.Tools,
 	}, nil
 }