Sfoglia il codice sorgente

openai: fix follow-on messages having "role": "assistant"

jmorganca 5 mesi fa
parent
commit
32c48ddad6
1 ha cambiato i file con 26 aggiunte e 6 eliminazioni
  1. 26 6
      openai/openai.go

+ 26 - 6
openai/openai.go

@@ -32,7 +32,7 @@ type ErrorResponse struct {
 }
 
 type Message struct {
-	Role      string     `json:"role"`
+	Role      string     `json:"role,omitempty"`
 	Content   any        `json:"content"`
 	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
 }
@@ -252,7 +252,7 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
 		SystemFingerprint: "fp_ollama",
 		Choices: []ChunkChoice{{
 			Index: 0,
-			Delta: Message{Role: "assistant", Content: r.Message.Content},
+			Delta: Message{Content: r.Message.Content},
 			FinishReason: func(reason string) *string {
 				if len(reason) > 0 {
 					return &reason
@@ -546,8 +546,9 @@ type BaseWriter struct {
 }
 
 type ChatWriter struct {
-	stream bool
-	id     string
+	stream  bool
+	started bool
+	id      string
 	BaseWriter
 }
 
@@ -594,8 +595,28 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
 		return 0, err
 	}
 
-	// chat chunk
 	if w.stream {
+		// The first chunk always has empty content so we
+		// copy the first chunk and set the content to
+		// empty, and send it first.
+		if !w.started {
+			first := chatResponse
+			first.Message = api.Message{Role: "assistant"}
+
+			d, err := json.Marshal(toChunk(w.id, first))
+			if err != nil {
+				return 0, err
+			}
+
+			w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
+			_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
+			if err != nil {
+				return 0, err
+			}
+
+			w.started = true
+		}
+
 		d, err := json.Marshal(toChunk(w.id, chatResponse))
 		if err != nil {
 			return 0, err
@@ -617,7 +638,6 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
 		return len(data), nil
 	}
 
-	// chat completion
 	w.ResponseWriter.Header().Set("Content-Type", "application/json")
 	err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
 	if err != nil {