瀏覽代碼

openai: support include_usage stream option to return final usage chunk

Anuraag Agrawal 7 月之前
父節點
當前提交
220108d3f4
共有 2 個文件被更改,包括 186 次插入33 次删除
  1. 98 33
      openai/openai.go
  2. 88 0
      openai/openai_test.go

+ 98 - 33
openai/openai.go

@@ -61,6 +61,21 @@ type Usage struct {
 	TotalTokens      int `json:"total_tokens"`
 }
 
+// ChunkUsage is an alias for Usage with the ability to marshal a marker
+// value as null. This is to allow omitting the field in chunks when usage
+// isn't requested, and otherwise return null on non-final chunks when it
+// is requested to follow OpenAI's behavior.
+type ChunkUsage = Usage
+
+var nullChunkUsage = ChunkUsage{}
+
+func (u *ChunkUsage) MarshalJSON() ([]byte, error) {
+	if u == &nullChunkUsage {
+		return []byte("null"), nil
+	}
+	return json.Marshal(*u)
+}
+
 type ResponseFormat struct {
 	Type string `json:"type"`
 }
@@ -70,10 +85,15 @@ type EmbedRequest struct {
 	Model string `json:"model"`
 }
 
+type StreamOptions struct {
+	IncludeUsage bool `json:"include_usage"`
+}
+
 type ChatCompletionRequest struct {
 	Model            string          `json:"model"`
 	Messages         []Message       `json:"messages"`
 	Stream           bool            `json:"stream"`
+	StreamOptions    *StreamOptions  `json:"stream_options"`
 	MaxTokens        *int            `json:"max_tokens"`
 	Seed             *int            `json:"seed"`
 	Stop             any             `json:"stop"`
@@ -102,21 +122,23 @@ type ChatCompletionChunk struct {
 	Model             string        `json:"model"`
 	SystemFingerprint string        `json:"system_fingerprint"`
 	Choices           []ChunkChoice `json:"choices"`
+	Usage             *ChunkUsage   `json:"usage,omitempty"`
 }
 
 // TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
 type CompletionRequest struct {
-	Model            string   `json:"model"`
-	Prompt           string   `json:"prompt"`
-	FrequencyPenalty float32  `json:"frequency_penalty"`
-	MaxTokens        *int     `json:"max_tokens"`
-	PresencePenalty  float32  `json:"presence_penalty"`
-	Seed             *int     `json:"seed"`
-	Stop             any      `json:"stop"`
-	Stream           bool     `json:"stream"`
-	Temperature      *float32 `json:"temperature"`
-	TopP             float32  `json:"top_p"`
-	Suffix           string   `json:"suffix"`
+	Model            string         `json:"model"`
+	Prompt           string         `json:"prompt"`
+	FrequencyPenalty float32        `json:"frequency_penalty"`
+	MaxTokens        *int           `json:"max_tokens"`
+	PresencePenalty  float32        `json:"presence_penalty"`
+	Seed             *int           `json:"seed"`
+	Stop             any            `json:"stop"`
+	Stream           bool           `json:"stream"`
+	StreamOptions    *StreamOptions `json:"stream_options"`
+	Temperature      *float32       `json:"temperature"`
+	TopP             float32        `json:"top_p"`
+	Suffix           string         `json:"suffix"`
 }
 
 type Completion struct {
@@ -136,6 +158,7 @@ type CompletionChunk struct {
 	Choices           []CompleteChunkChoice `json:"choices"`
 	Model             string                `json:"model"`
 	SystemFingerprint string                `json:"system_fingerprint"`
+	Usage             *ChunkUsage           `json:"usage,omitempty"`
 }
 
 type ToolCall struct {
@@ -200,6 +223,14 @@ func toolCallId() string {
 	return "call_" + strings.ToLower(string(b))
 }
 
+func toUsage(r api.ChatResponse) Usage {
+	return Usage{
+		PromptTokens:     r.PromptEvalCount,
+		CompletionTokens: r.EvalCount,
+		TotalTokens:      r.PromptEvalCount + r.EvalCount,
+	}
+}
+
 func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 	toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
 	for i, tc := range r.Message.ToolCalls {
@@ -235,11 +266,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 				return nil
 			}(r.DoneReason),
 		}},
-		Usage: Usage{
-			PromptTokens:     r.PromptEvalCount,
-			CompletionTokens: r.EvalCount,
-			TotalTokens:      r.PromptEvalCount + r.EvalCount,
-		},
+		Usage: toUsage(r),
 	}
 }
 
@@ -263,6 +290,14 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
 	}
 }
 
+func toUsageGenerate(r api.GenerateResponse) Usage {
+	return Usage{
+		PromptTokens:     r.PromptEvalCount,
+		CompletionTokens: r.EvalCount,
+		TotalTokens:      r.PromptEvalCount + r.EvalCount,
+	}
+}
+
 func toCompletion(id string, r api.GenerateResponse) Completion {
 	return Completion{
 		Id:                id,
@@ -280,11 +315,7 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
 				return nil
 			}(r.DoneReason),
 		}},
-		Usage: Usage{
-			PromptTokens:     r.PromptEvalCount,
-			CompletionTokens: r.EvalCount,
-			TotalTokens:      r.PromptEvalCount + r.EvalCount,
-		},
+		Usage: toUsageGenerate(r),
 	}
 }
 
@@ -546,14 +577,16 @@ type BaseWriter struct {
 }
 
 type ChatWriter struct {
-	stream bool
-	id     string
+	stream      bool
+	streamUsage bool
+	id          string
 	BaseWriter
 }
 
 type CompleteWriter struct {
-	stream bool
-	id     string
+	stream      bool
+	streamUsage bool
+	id          string
 	BaseWriter
 }
 
@@ -596,7 +629,11 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
 
 	// chat chunk
 	if w.stream {
-		d, err := json.Marshal(toChunk(w.id, chatResponse))
+		c := toChunk(w.id, chatResponse)
+		if w.streamUsage {
+			c.Usage = &nullChunkUsage
+		}
+		d, err := json.Marshal(c)
 		if err != nil {
 			return 0, err
 		}
@@ -608,6 +645,17 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
 		}
 
 		if chatResponse.Done {
+			if w.streamUsage {
+				u := toUsage(chatResponse)
+				d, err := json.Marshal(ChatCompletionChunk{Usage: &u})
+				if err != nil {
+					return 0, err
+				}
+				_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
+				if err != nil {
+					return 0, err
+				}
+			}
 			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
 			if err != nil {
 				return 0, err
@@ -645,7 +693,11 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
 
 	// completion chunk
 	if w.stream {
-		d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
+		c := toCompleteChunk(w.id, generateResponse)
+		if w.streamUsage {
+			c.Usage = &nullChunkUsage
+		}
+		d, err := json.Marshal(c)
 		if err != nil {
 			return 0, err
 		}
@@ -657,6 +709,17 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
 		}
 
 		if generateResponse.Done {
+			if w.streamUsage {
+				u := toUsageGenerate(generateResponse)
+				d, err := json.Marshal(CompletionChunk{Usage: &u})
+				if err != nil {
+					return 0, err
+				}
+				_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
+				if err != nil {
+					return 0, err
+				}
+			}
 			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
 			if err != nil {
 				return 0, err
@@ -819,9 +882,10 @@ func CompletionsMiddleware() gin.HandlerFunc {
 		c.Request.Body = io.NopCloser(&b)
 
 		w := &CompleteWriter{
-			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
-			stream:     req.Stream,
-			id:         fmt.Sprintf("cmpl-%d", rand.Intn(999)),
+			BaseWriter:  BaseWriter{ResponseWriter: c.Writer},
+			stream:      req.Stream,
+			id:          fmt.Sprintf("cmpl-%d", rand.Intn(999)),
+			streamUsage: req.StreamOptions != nil && req.StreamOptions.IncludeUsage,
 		}
 
 		c.Writer = w
@@ -901,9 +965,10 @@ func ChatMiddleware() gin.HandlerFunc {
 		c.Request.Body = io.NopCloser(&b)
 
 		w := &ChatWriter{
-			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
-			stream:     req.Stream,
-			id:         fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
+			BaseWriter:  BaseWriter{ResponseWriter: c.Writer},
+			stream:      req.Stream,
+			id:          fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
+			streamUsage: req.StreamOptions != nil && req.StreamOptions.IncludeUsage,
 		}
 
 		c.Writer = w

+ 88 - 0
openai/openai_test.go

@@ -111,6 +111,45 @@ func TestChatMiddleware(t *testing.T) {
 				Stream: &True,
 			},
 		},
+		{
+			name: "chat handler with streaming usage",
+			body: `{
+				"model": "test-model",
+				"messages": [
+					{"role": "user", "content": "Hello"}
+				],
+				"stream":            true,
+				"stream_options":    {"include_usage": true},
+				"max_tokens":        999,
+				"seed":              123,
+				"stop":              ["\n", "stop"],
+				"temperature":       3.0,
+				"frequency_penalty": 4.0,
+				"presence_penalty":  5.0,
+				"top_p":             6.0,
+				"response_format":   {"type": "json_object"}
+			}`,
+			req: api.ChatRequest{
+				Model: "test-model",
+				Messages: []api.Message{
+					{
+						Role:    "user",
+						Content: "Hello",
+					},
+				},
+				Options: map[string]any{
+					"num_predict":       999.0, // float because JSON doesn't distinguish between float and int
+					"seed":              123.0,
+					"stop":              []any{"\n", "stop"},
+					"temperature":       3.0,
+					"frequency_penalty": 4.0,
+					"presence_penalty":  5.0,
+					"top_p":             6.0,
+				},
+				Format: "json",
+				Stream: &True,
+			},
+		},
 		{
 			name: "chat handler with image content",
 			body: `{
@@ -283,6 +322,55 @@ func TestCompletionsMiddleware(t *testing.T) {
 				Stream: &False,
 			},
 		},
+		{
+			name: "completions handler stream",
+			body: `{
+				"model": "test-model",
+				"prompt": "Hello",
+				"stream": true,
+				"temperature": 0.8,
+				"stop": ["\n", "stop"],
+				"suffix": "suffix"
+			}`,
+			req: api.GenerateRequest{
+				Model:  "test-model",
+				Prompt: "Hello",
+				Options: map[string]any{
+					"frequency_penalty": 0.0,
+					"presence_penalty":  0.0,
+					"temperature":       0.8,
+					"top_p":             1.0,
+					"stop":              []any{"\n", "stop"},
+				},
+				Suffix: "suffix",
+				Stream: &True,
+			},
+		},
+		{
+			name: "completions handler stream with usage",
+			body: `{
+				"model": "test-model",
+				"prompt": "Hello",
+				"stream": true,
+				"stream_options": {"include_usage": true},
+				"temperature": 0.8,
+				"stop": ["\n", "stop"],
+				"suffix": "suffix"
+			}`,
+			req: api.GenerateRequest{
+				Model:  "test-model",
+				Prompt: "Hello",
+				Options: map[string]any{
+					"frequency_penalty": 0.0,
+					"presence_penalty":  0.0,
+					"temperature":       0.8,
+					"top_p":             1.0,
+					"stop":              []any{"\n", "stop"},
+				},
+				Suffix: "suffix",
+				Stream: &True,
+			},
+		},
 		{
 			name: "completions handler error forwarding",
 			body: `{