
openai: return usage as final chunk for streams (#6784)

Co-authored-by: ParthSareen <parth.sareen@ollama.com>
Anuraag (Rag) Agrawal, 4 months ago
parent commit e28f2d4900
3 changed files with 176 additions and 33 deletions
  1. docs/openai.md (+4 −0)
  2. openai/openai.go (+84 −33)
  3. openai/openai_test.go (+88 −0)

docs/openai.md (+4 −0)

@@ -233,6 +233,8 @@ curl http://localhost:11434/v1/embeddings \
 - [x] `seed`
 - [x] `stop`
 - [x] `stream`
+- [x] `stream_options`
+  - [x] `include_usage`
 - [x] `temperature`
 - [x] `top_p`
 - [x] `max_tokens`
@@ -261,6 +263,8 @@ curl http://localhost:11434/v1/embeddings \
 - [x] `seed`
 - [x] `stop`
 - [x] `stream`
+- [x] `stream_options`
+  - [x] `include_usage`
 - [x] `temperature`
 - [x] `top_p`
 - [x] `max_tokens`
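
As a quick illustration of the new options, below is a minimal streaming client sketch against the OpenAI-compatible endpoint. The endpoint path, the `stream_options` shape, and the usage field names follow this change; the model name and the rest of the scaffolding are illustrative only:

```go
// Sketch: request a streamed chat completion with include_usage, print the
// streamed content, then report the token counts from the final chunk.
package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	body, _ := json.Marshal(map[string]any{
		"model":          "test-model", // illustrative model name
		"messages":       []map[string]string{{"role": "user", "content": "Hello"}},
		"stream":         true,
		"stream_options": map[string]bool{"include_usage": true},
	})

	resp, err := http.Post("http://localhost:11434/v1/chat/completions",
		"application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := strings.TrimPrefix(scanner.Text(), "data: ")
		if line == "" || line == "[DONE]" {
			continue
		}
		var chunk struct {
			Choices []struct {
				Delta struct {
					Content string `json:"content"`
				} `json:"delta"`
			} `json:"choices"`
			Usage *struct {
				PromptTokens     int `json:"prompt_tokens"`
				CompletionTokens int `json:"completion_tokens"`
				TotalTokens      int `json:"total_tokens"`
			} `json:"usage"`
		}
		if err := json.Unmarshal([]byte(line), &chunk); err != nil {
			continue
		}
		for _, choice := range chunk.Choices {
			fmt.Print(choice.Delta.Content)
		}
		// With include_usage, the final chunk has an empty choices array
		// and carries the usage object.
		if chunk.Usage != nil {
			fmt.Printf("\nprompt=%d completion=%d total=%d\n",
				chunk.Usage.PromptTokens, chunk.Usage.CompletionTokens, chunk.Usage.TotalTokens)
		}
	}
}
```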

openai/openai.go (+84 −33)

@@ -75,10 +75,15 @@ type EmbedRequest struct {
 	Model string `json:"model"`
 }
 
+type StreamOptions struct {
+	IncludeUsage bool `json:"include_usage"`
+}
+
 type ChatCompletionRequest struct {
 	Model            string          `json:"model"`
 	Messages         []Message       `json:"messages"`
 	Stream           bool            `json:"stream"`
+	StreamOptions    *StreamOptions  `json:"stream_options"`
 	MaxTokens        *int            `json:"max_tokens"`
 	Seed             *int            `json:"seed"`
 	Stop             any             `json:"stop"`
@@ -107,21 +112,23 @@ type ChatCompletionChunk struct {
 	Model             string        `json:"model"`
 	SystemFingerprint string        `json:"system_fingerprint"`
 	Choices           []ChunkChoice `json:"choices"`
+	Usage             *Usage        `json:"usage,omitempty"`
 }
 
 // TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
 type CompletionRequest struct {
-	Model            string   `json:"model"`
-	Prompt           string   `json:"prompt"`
-	FrequencyPenalty float32  `json:"frequency_penalty"`
-	MaxTokens        *int     `json:"max_tokens"`
-	PresencePenalty  float32  `json:"presence_penalty"`
-	Seed             *int     `json:"seed"`
-	Stop             any      `json:"stop"`
-	Stream           bool     `json:"stream"`
-	Temperature      *float32 `json:"temperature"`
-	TopP             float32  `json:"top_p"`
-	Suffix           string   `json:"suffix"`
+	Model            string         `json:"model"`
+	Prompt           string         `json:"prompt"`
+	FrequencyPenalty float32        `json:"frequency_penalty"`
+	MaxTokens        *int           `json:"max_tokens"`
+	PresencePenalty  float32        `json:"presence_penalty"`
+	Seed             *int           `json:"seed"`
+	Stop             any            `json:"stop"`
+	Stream           bool           `json:"stream"`
+	StreamOptions    *StreamOptions `json:"stream_options"`
+	Temperature      *float32       `json:"temperature"`
+	TopP             float32        `json:"top_p"`
+	Suffix           string         `json:"suffix"`
 }
 
 type Completion struct {
@@ -141,6 +148,7 @@ type CompletionChunk struct {
 	Choices           []CompleteChunkChoice `json:"choices"`
 	Model             string                `json:"model"`
 	SystemFingerprint string                `json:"system_fingerprint"`
+	Usage             *Usage                `json:"usage,omitempty"`
 }
 
 type ToolCall struct {
@@ -197,6 +205,14 @@ func NewError(code int, message string) ErrorResponse {
 	return ErrorResponse{Error{Type: etype, Message: message}}
 }
 
+func toUsage(r api.ChatResponse) Usage {
+	return Usage{
+		PromptTokens:     r.PromptEvalCount,
+		CompletionTokens: r.EvalCount,
+		TotalTokens:      r.PromptEvalCount + r.EvalCount,
+	}
+}
+
 func toolCallId() string {
 	const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
 	b := make([]byte, 8)
@@ -246,11 +262,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 				return nil
 			}(r.DoneReason),
 		}},
-		Usage: Usage{
-			PromptTokens:     r.PromptEvalCount,
-			CompletionTokens: r.EvalCount,
-			TotalTokens:      r.PromptEvalCount + r.EvalCount,
-		},
+		Usage: toUsage(r),
 	}
 }
 
@@ -275,6 +287,14 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
 	}
 }
 
+func toUsageGenerate(r api.GenerateResponse) Usage {
+	return Usage{
+		PromptTokens:     r.PromptEvalCount,
+		CompletionTokens: r.EvalCount,
+		TotalTokens:      r.PromptEvalCount + r.EvalCount,
+	}
+}
+
 func toCompletion(id string, r api.GenerateResponse) Completion {
 	return Completion{
 		Id:                id,
@@ -292,11 +312,7 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
 				return nil
 			}(r.DoneReason),
 		}},
-		Usage: Usage{
-			PromptTokens:     r.PromptEvalCount,
-			CompletionTokens: r.EvalCount,
-			TotalTokens:      r.PromptEvalCount + r.EvalCount,
-		},
+		Usage: toUsageGenerate(r),
 	}
 }
 
@@ -566,14 +582,16 @@ type BaseWriter struct {
 }
 
 type ChatWriter struct {
-	stream bool
-	id     string
+	stream        bool
+	streamOptions *StreamOptions
+	id            string
 	BaseWriter
 }
 
 type CompleteWriter struct {
-	stream bool
-	id     string
+	stream        bool
+	streamOptions *StreamOptions
+	id            string
 	BaseWriter
 }
 
@@ -616,7 +634,8 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
 
 	// chat chunk
 	if w.stream {
-		d, err := json.Marshal(toChunk(w.id, chatResponse))
+		c := toChunk(w.id, chatResponse)
+		d, err := json.Marshal(c)
 		if err != nil {
 			return 0, err
 		}
@@ -628,6 +647,19 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
 		}
 
 		if chatResponse.Done {
+			if w.streamOptions != nil && w.streamOptions.IncludeUsage {
+				u := toUsage(chatResponse)
+				c.Usage = &u
+				c.Choices = []ChunkChoice{}
+				d, err := json.Marshal(c)
+				if err != nil {
+					return 0, err
+				}
+				_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
+				if err != nil {
+					return 0, err
+				}
+			}
 			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
 			if err != nil {
 				return 0, err
@@ -665,7 +697,11 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
 
 	// completion chunk
 	if w.stream {
-		d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
+		c := toCompleteChunk(w.id, generateResponse)
+		if w.streamOptions != nil && w.streamOptions.IncludeUsage {
+			c.Usage = &Usage{}
+		}
+		d, err := json.Marshal(c)
 		if err != nil {
 			return 0, err
 		}
@@ -677,6 +713,19 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
 		}
 
 		if generateResponse.Done {
+			if w.streamOptions != nil && w.streamOptions.IncludeUsage {
+				u := toUsageGenerate(generateResponse)
+				c.Usage = &u
+				c.Choices = []CompleteChunkChoice{}
+				d, err := json.Marshal(c)
+				if err != nil {
+					return 0, err
+				}
+				_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
+				if err != nil {
+					return 0, err
+				}
+			}
 			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
 			if err != nil {
 				return 0, err
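
Both writers now share the same final-frame contract: when `include_usage` is set, the last data frame carries the usage object with an empty `choices` array, followed by the `[DONE]` sentinel. A self-contained sketch of that framing, with simplified stand-in types and illustrative token counts:

```go
// Sketch of the terminal SSE frames emitted when include_usage is requested.
// The chunk and usage types here are simplified stand-ins for the real ones.
package main

import (
	"encoding/json"
	"fmt"
)

type usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type chunk struct {
	Id      string `json:"id"`
	Choices []any  `json:"choices"`
	Usage   *usage `json:"usage,omitempty"`
}

func main() {
	// Mirror the Done branch above: attach usage, clear choices, write the
	// chunk as its own frame, then terminate the stream.
	u := usage{PromptTokens: 26, CompletionTokens: 290, TotalTokens: 316} // illustrative counts
	c := chunk{Id: "chatcmpl-42", Choices: []any{}, Usage: &u}
	d, err := json.Marshal(c)
	if err != nil {
		panic(err)
	}
	fmt.Printf("data: %s\n\n", d)
	fmt.Print("data: [DONE]\n\n")
}
```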
@@ -839,9 +888,10 @@ func CompletionsMiddleware() gin.HandlerFunc {
 		c.Request.Body = io.NopCloser(&b)
 
 		w := &CompleteWriter{
-			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
-			stream:     req.Stream,
-			id:         fmt.Sprintf("cmpl-%d", rand.Intn(999)),
+			BaseWriter:    BaseWriter{ResponseWriter: c.Writer},
+			stream:        req.Stream,
+			id:            fmt.Sprintf("cmpl-%d", rand.Intn(999)),
+			streamOptions: req.StreamOptions,
 		}
 
 		c.Writer = w
@@ -921,9 +971,10 @@ func ChatMiddleware() gin.HandlerFunc {
 		c.Request.Body = io.NopCloser(&b)
 
 		w := &ChatWriter{
-			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
-			stream:     req.Stream,
-			id:         fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
+			BaseWriter:    BaseWriter{ResponseWriter: c.Writer},
+			stream:        req.Stream,
+			id:            fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
+			streamOptions: req.StreamOptions,
 		}
 
 		c.Writer = w

openai/openai_test.go (+88 −0)

@@ -112,6 +112,45 @@ func TestChatMiddleware(t *testing.T) {
 				Stream: &True,
 			},
 		},
+		{
+			name: "chat handler with streaming usage",
+			body: `{
+				"model": "test-model",
+				"messages": [
+					{"role": "user", "content": "Hello"}
+				],
+				"stream":            true,
+				"stream_options":    {"include_usage": true},
+				"max_tokens":        999,
+				"seed":              123,
+				"stop":              ["\n", "stop"],
+				"temperature":       3.0,
+				"frequency_penalty": 4.0,
+				"presence_penalty":  5.0,
+				"top_p":             6.0,
+				"response_format":   {"type": "json_object"}
+			}`,
+			req: api.ChatRequest{
+				Model: "test-model",
+				Messages: []api.Message{
+					{
+						Role:    "user",
+						Content: "Hello",
+					},
+				},
+				Options: map[string]any{
+					"num_predict":       999.0, // float because JSON doesn't distinguish between float and int
+					"seed":              123.0,
+					"stop":              []any{"\n", "stop"},
+					"temperature":       3.0,
+					"frequency_penalty": 4.0,
+					"presence_penalty":  5.0,
+					"top_p":             6.0,
+				},
+				Format: json.RawMessage(`"json"`),
+				Stream: &True,
+			},
+		},
 		{
 			name: "chat handler with image content",
 			body: `{
@@ -363,6 +402,55 @@ func TestCompletionsMiddleware(t *testing.T) {
 				Stream: &False,
 			},
 		},
+		{
+			name: "completions handler stream",
+			body: `{
+				"model": "test-model",
+				"prompt": "Hello",
+				"stream": true,
+				"temperature": 0.8,
+				"stop": ["\n", "stop"],
+				"suffix": "suffix"
+			}`,
+			req: api.GenerateRequest{
+				Model:  "test-model",
+				Prompt: "Hello",
+				Options: map[string]any{
+					"frequency_penalty": 0.0,
+					"presence_penalty":  0.0,
+					"temperature":       0.8,
+					"top_p":             1.0,
+					"stop":              []any{"\n", "stop"},
+				},
+				Suffix: "suffix",
+				Stream: &True,
+			},
+		},
+		{
+			name: "completions handler stream with usage",
+			body: `{
+				"model": "test-model",
+				"prompt": "Hello",
+				"stream": true,
+				"stream_options": {"include_usage": true},
+				"temperature": 0.8,
+				"stop": ["\n", "stop"],
+				"suffix": "suffix"
+			}`,
+			req: api.GenerateRequest{
+				Model:  "test-model",
+				Prompt: "Hello",
+				Options: map[string]any{
+					"frequency_penalty": 0.0,
+					"presence_penalty":  0.0,
+					"temperature":       0.8,
+					"top_p":             1.0,
+					"stop":              []any{"\n", "stop"},
+				},
+				Suffix: "suffix",
+				Stream: &True,
+			},
+		},
 		{
 			name: "completions handler error forwarding",
 			body: `{