Przeglądaj źródła

Enable index tracking for tools - openai api support (#7888)

Parth Sareen 5 miesięcy temu
rodzic
commit
5f8051180e
4 zmienionych plików z 89 dodań i 4 usunięć
  1. 1 0
      api/types.go
  2. 2 0
      openai/openai.go
  3. 80 1
      openai/openai_test.go
  4. 6 3
      server/routes.go

+ 1 - 0
api/types.go

@@ -146,6 +146,7 @@ type ToolCall struct {
 }
 
 type ToolCallFunction struct {
+	Index     int                       `json:"index,omitempty"`
 	Name      string                    `json:"name"`
 	Arguments ToolCallFunctionArguments `json:"arguments"`
 }

+ 2 - 0
openai/openai.go

@@ -140,6 +140,7 @@ type CompletionChunk struct {
 
 type ToolCall struct {
 	ID       string `json:"id"`
+	Index    int    `json:"index"`
 	Type     string `json:"type"`
 	Function struct {
 		Name      string `json:"name"`
@@ -206,6 +207,7 @@ func toToolCalls(tc []api.ToolCall) []ToolCall {
 		toolCalls[i].ID = toolCallId()
 		toolCalls[i].Type = "function"
 		toolCalls[i].Function.Name = tc.Function.Name
+		toolCalls[i].Index = tc.Function.Index
 
 		args, err := json.Marshal(tc.Function.Arguments)
 		if err != nil {

+ 80 - 1
openai/openai_test.go

@@ -195,7 +195,86 @@ func TestChatMiddleware(t *testing.T) {
 				Stream: &False,
 			},
 		},
-
+		{
+			name: "chat handler with streaming tools",
+			body: `{
+				"model": "test-model",
+				"messages": [
+					{"role": "user", "content": "What's the weather like in Paris?"}
+				],
+				"stream": true,
+				"tools": [{
+					"type": "function",
+					"function": {
+						"name": "get_weather",
+						"description": "Get the current weather",
+						"parameters": {
+							"type": "object",
+							"required": ["location"],
+							"properties": {
+								"location": {
+									"type": "string",
+									"description": "The city and state"
+								},
+								"unit": {
+									"type": "string",
+									"enum": ["celsius", "fahrenheit"]
+								}
+							}
+						}
+					}
+				}]
+			}`,
+			req: api.ChatRequest{
+				Model: "test-model",
+				Messages: []api.Message{
+					{
+						Role:    "user",
+						Content: "What's the weather like in Paris?",
+					},
+				},
+				Tools: []api.Tool{
+					{
+						Type: "function",
+						Function: api.ToolFunction{
+							Name:        "get_weather",
+							Description: "Get the current weather",
+							Parameters: struct {
+								Type       string   `json:"type"`
+								Required   []string `json:"required"`
+								Properties map[string]struct {
+									Type        string   `json:"type"`
+									Description string   `json:"description"`
+									Enum        []string `json:"enum,omitempty"`
+								} `json:"properties"`
+							}{
+								Type:     "object",
+								Required: []string{"location"},
+								Properties: map[string]struct {
+									Type        string   `json:"type"`
+									Description string   `json:"description"`
+									Enum        []string `json:"enum,omitempty"`
+								}{
+									"location": {
+										Type:        "string",
+										Description: "The city and state",
+									},
+									"unit": {
+										Type: "string",
+										Enum: []string{"celsius", "fahrenheit"},
+									},
+								},
+							},
+						},
+					},
+				},
+				Options: map[string]any{
+					"temperature": 1.0,
+					"top_p":       1.0,
+				},
+				Stream: &True,
+			},
+		},
 		{
 			name: "chat handler error forwarding",
 			body: `{

+ 6 - 3
server/routes.go

@@ -1469,7 +1469,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	go func() {
 		defer close(ch)
 		var sb strings.Builder
-		var hasToolCalls bool
+		var toolCallIndex int = 0
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt:  prompt,
 			Images:  images,
@@ -1509,16 +1509,19 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			sb.WriteString(r.Content)
 			if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
 				res.Message.ToolCalls = toolCalls
+				for i := range toolCalls {
+					toolCalls[i].Function.Index = toolCallIndex
+					toolCallIndex++
+				}
 				res.Message.Content = ""
 				sb.Reset()
-				hasToolCalls = true
 				ch <- res
 				return
 			}
 
 			if r.Done {
 				// Send any remaining content if no tool calls were detected
-				if !hasToolCalls {
+				if toolCallIndex == 0 {
 					res.Message.Content = sb.String()
 				}
 				ch <- res