Просмотр исходного кода

openai: finish streaming tool calls as tool_calls

Anuraag Agrawal 4 месяцев назад
Родитель
Сommit
120785dbb6
1 измененных файлов с 18 добавлено и 10 удалено
  1. 18 10
      openai/openai.go

+ 18 - 10
openai/openai.go

@@ -266,6 +266,9 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
 			Index: 0,
 			Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
 			FinishReason: func(reason string) *string {
+				if len(toolCalls) > 0 {
+					reason = "tool_calls"
+				}
 				if len(reason) > 0 {
 					return &reason
 				}
@@ -570,8 +573,9 @@ type BaseWriter struct {
 }
 
 type ChatWriter struct {
-	stream bool
-	id     string
+	stream   bool
+	finished bool
+	id       string
 	BaseWriter
 }
 
@@ -620,15 +624,19 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
 
 	// chat chunk
 	if w.stream {
-		d, err := json.Marshal(toChunk(w.id, chatResponse))
-		if err != nil {
-			return 0, err
-		}
+		// If we've already finished, don't send any more chunks with choices.
+		if !w.finished {
+			chunk := toChunk(w.id, chatResponse)
+			d, err := json.Marshal(chunk)
+			if err != nil {
+				return 0, err
+			}
 
-		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
-		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
-		if err != nil {
-			return 0, err
+			w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
+			_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
+			if err != nil {
+				return 0, err
+			}
 		}
 
 		if chatResponse.Done {