Browse Source

openai: finish_reason as tool_calls for streaming with tools (#7963)

Anuraag (Rag) Agrawal 2 months ago
parent
commit
10d59d5f90
1 changed files with 11 additions and 2 deletions
  1. 11 2
      openai/openai.go

+ 11 - 2
openai/openai.go

@@ -20,6 +20,8 @@ import (
 	"github.com/ollama/ollama/types/model"
 )
 
+var finishReasonToolCalls = "tool_calls"
+
 type Error struct {
 	Message string      `json:"message"`
 	Type    string      `json:"type"`
@@ -266,7 +268,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 	}
 }
 
-func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
+func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
 	toolCalls := toToolCalls(r.Message.ToolCalls)
 	return ChatCompletionChunk{
 		Id:                id,
@@ -279,6 +281,9 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
 			Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
 			FinishReason: func(reason string) *string {
 				if len(reason) > 0 {
+					if toolCallSent {
+						return &finishReasonToolCalls
+					}
 					return &reason
 				}
 				return nil
@@ -585,6 +590,7 @@ type ChatWriter struct {
 	stream        bool
 	streamOptions *StreamOptions
 	id            string
+	toolCallSent  bool
 	BaseWriter
 }
 
@@ -634,11 +640,14 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
 
 	// chat chunk
 	if w.stream {
-		c := toChunk(w.id, chatResponse)
+		c := toChunk(w.id, chatResponse, w.toolCallSent)
 		d, err := json.Marshal(c)
 		if err != nil {
 			return 0, err
 		}
+		if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
+			w.toolCallSent = true
+		}
 
 		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
 		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))