|
@@ -61,6 +61,21 @@ type Usage struct {
|
|
|
TotalTokens int `json:"total_tokens"`
|
|
|
}
|
|
|
|
|
|
+// ChunkUsage is an alias for Usage with the ability to marshal a marker
|
|
|
+// value as null. This is to allow omitting the field in chunks when usage
|
|
|
+// isn't requested, and otherwise return null on non-final chunks when it
|
|
|
+// is requested to follow OpenAI's behavior.
|
|
|
+type ChunkUsage = Usage
|
|
|
+
|
|
|
+var nullChunkUsage = ChunkUsage{}
|
|
|
+
|
|
|
+func (u *ChunkUsage) MarshalJSON() ([]byte, error) {
|
|
|
+ if u == &nullChunkUsage {
|
|
|
+ return []byte("null"), nil
|
|
|
+ }
|
|
|
+ return json.Marshal(*u)
|
|
|
+}
|
|
|
+
|
|
|
type ResponseFormat struct {
|
|
|
Type string `json:"type"`
|
|
|
}
|
|
@@ -70,10 +85,15 @@ type EmbedRequest struct {
|
|
|
Model string `json:"model"`
|
|
|
}
|
|
|
|
|
|
+type StreamOptions struct {
|
|
|
+ IncludeUsage bool `json:"include_usage"`
|
|
|
+}
|
|
|
+
|
|
|
type ChatCompletionRequest struct {
|
|
|
Model string `json:"model"`
|
|
|
Messages []Message `json:"messages"`
|
|
|
Stream bool `json:"stream"`
|
|
|
+ StreamOptions *StreamOptions `json:"stream_options"`
|
|
|
MaxTokens *int `json:"max_tokens"`
|
|
|
Seed *int `json:"seed"`
|
|
|
Stop any `json:"stop"`
|
|
@@ -102,21 +122,23 @@ type ChatCompletionChunk struct {
|
|
|
Model string `json:"model"`
|
|
|
SystemFingerprint string `json:"system_fingerprint"`
|
|
|
Choices []ChunkChoice `json:"choices"`
|
|
|
+ Usage *ChunkUsage `json:"usage,omitempty"`
|
|
|
}
|
|
|
|
|
|
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
|
|
|
type CompletionRequest struct {
|
|
|
- Model string `json:"model"`
|
|
|
- Prompt string `json:"prompt"`
|
|
|
- FrequencyPenalty float32 `json:"frequency_penalty"`
|
|
|
- MaxTokens *int `json:"max_tokens"`
|
|
|
- PresencePenalty float32 `json:"presence_penalty"`
|
|
|
- Seed *int `json:"seed"`
|
|
|
- Stop any `json:"stop"`
|
|
|
- Stream bool `json:"stream"`
|
|
|
- Temperature *float32 `json:"temperature"`
|
|
|
- TopP float32 `json:"top_p"`
|
|
|
- Suffix string `json:"suffix"`
|
|
|
+ Model string `json:"model"`
|
|
|
+ Prompt string `json:"prompt"`
|
|
|
+ FrequencyPenalty float32 `json:"frequency_penalty"`
|
|
|
+ MaxTokens *int `json:"max_tokens"`
|
|
|
+ PresencePenalty float32 `json:"presence_penalty"`
|
|
|
+ Seed *int `json:"seed"`
|
|
|
+ Stop any `json:"stop"`
|
|
|
+ Stream bool `json:"stream"`
|
|
|
+ StreamOptions *StreamOptions `json:"stream_options"`
|
|
|
+ Temperature *float32 `json:"temperature"`
|
|
|
+ TopP float32 `json:"top_p"`
|
|
|
+ Suffix string `json:"suffix"`
|
|
|
}
|
|
|
|
|
|
type Completion struct {
|
|
@@ -136,6 +158,7 @@ type CompletionChunk struct {
|
|
|
Choices []CompleteChunkChoice `json:"choices"`
|
|
|
Model string `json:"model"`
|
|
|
SystemFingerprint string `json:"system_fingerprint"`
|
|
|
+ Usage *ChunkUsage `json:"usage,omitempty"`
|
|
|
}
|
|
|
|
|
|
type ToolCall struct {
|
|
@@ -200,6 +223,14 @@ func toolCallId() string {
|
|
|
return "call_" + strings.ToLower(string(b))
|
|
|
}
|
|
|
|
|
|
+func toUsage(r api.ChatResponse) Usage {
|
|
|
+ return Usage{
|
|
|
+ PromptTokens: r.PromptEvalCount,
|
|
|
+ CompletionTokens: r.EvalCount,
|
|
|
+ TotalTokens: r.PromptEvalCount + r.EvalCount,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
|
|
toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
|
|
|
for i, tc := range r.Message.ToolCalls {
|
|
@@ -235,11 +266,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
|
|
return nil
|
|
|
}(r.DoneReason),
|
|
|
}},
|
|
|
- Usage: Usage{
|
|
|
- PromptTokens: r.PromptEvalCount,
|
|
|
- CompletionTokens: r.EvalCount,
|
|
|
- TotalTokens: r.PromptEvalCount + r.EvalCount,
|
|
|
- },
|
|
|
+ Usage: toUsage(r),
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -263,6 +290,14 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+func toUsageGenerate(r api.GenerateResponse) Usage {
|
|
|
+ return Usage{
|
|
|
+ PromptTokens: r.PromptEvalCount,
|
|
|
+ CompletionTokens: r.EvalCount,
|
|
|
+ TotalTokens: r.PromptEvalCount + r.EvalCount,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func toCompletion(id string, r api.GenerateResponse) Completion {
|
|
|
return Completion{
|
|
|
Id: id,
|
|
@@ -280,11 +315,7 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
|
|
|
return nil
|
|
|
}(r.DoneReason),
|
|
|
}},
|
|
|
- Usage: Usage{
|
|
|
- PromptTokens: r.PromptEvalCount,
|
|
|
- CompletionTokens: r.EvalCount,
|
|
|
- TotalTokens: r.PromptEvalCount + r.EvalCount,
|
|
|
- },
|
|
|
+ Usage: toUsageGenerate(r),
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -546,14 +577,16 @@ type BaseWriter struct {
|
|
|
}
|
|
|
|
|
|
type ChatWriter struct {
|
|
|
- stream bool
|
|
|
- id string
|
|
|
+ stream bool
|
|
|
+ streamUsage bool
|
|
|
+ id string
|
|
|
BaseWriter
|
|
|
}
|
|
|
|
|
|
type CompleteWriter struct {
|
|
|
- stream bool
|
|
|
- id string
|
|
|
+ stream bool
|
|
|
+ streamUsage bool
|
|
|
+ id string
|
|
|
BaseWriter
|
|
|
}
|
|
|
|
|
@@ -596,7 +629,11 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
|
|
|
|
|
// chat chunk
|
|
|
if w.stream {
|
|
|
- d, err := json.Marshal(toChunk(w.id, chatResponse))
|
|
|
+ c := toChunk(w.id, chatResponse)
|
|
|
+ if w.streamUsage {
|
|
|
+ c.Usage = &nullChunkUsage
|
|
|
+ }
|
|
|
+ d, err := json.Marshal(c)
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
@@ -608,6 +645,17 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
|
|
}
|
|
|
|
|
|
if chatResponse.Done {
|
|
|
+ if w.streamUsage {
|
|
|
+ u := toUsage(chatResponse)
|
|
|
+ d, err := json.Marshal(ChatCompletionChunk{Usage: &u})
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+ _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+ }
|
|
|
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
@@ -645,7 +693,11 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
|
|
|
|
|
|
// completion chunk
|
|
|
if w.stream {
|
|
|
- d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
|
|
|
+ c := toCompleteChunk(w.id, generateResponse)
|
|
|
+ if w.streamUsage {
|
|
|
+ c.Usage = &nullChunkUsage
|
|
|
+ }
|
|
|
+ d, err := json.Marshal(c)
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
@@ -657,6 +709,17 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
|
|
|
}
|
|
|
|
|
|
if generateResponse.Done {
|
|
|
+ if w.streamUsage {
|
|
|
+ u := toUsageGenerate(generateResponse)
|
|
|
+ d, err := json.Marshal(CompletionChunk{Usage: &u})
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+ _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+ }
|
|
|
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
@@ -819,9 +882,10 @@ func CompletionsMiddleware() gin.HandlerFunc {
|
|
|
c.Request.Body = io.NopCloser(&b)
|
|
|
|
|
|
w := &CompleteWriter{
|
|
|
- BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
|
- stream: req.Stream,
|
|
|
- id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
|
|
+ BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
|
+ stream: req.Stream,
|
|
|
+ id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
|
|
+ streamUsage: req.StreamOptions != nil && req.StreamOptions.IncludeUsage,
|
|
|
}
|
|
|
|
|
|
c.Writer = w
|
|
@@ -901,9 +965,10 @@ func ChatMiddleware() gin.HandlerFunc {
|
|
|
c.Request.Body = io.NopCloser(&b)
|
|
|
|
|
|
w := &ChatWriter{
|
|
|
- BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
|
- stream: req.Stream,
|
|
|
- id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
|
|
+ BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
|
+ stream: req.Stream,
|
|
|
+ id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
|
|
+ streamUsage: req.StreamOptions != nil && req.StreamOptions.IncludeUsage,
|
|
|
}
|
|
|
|
|
|
c.Writer = w
|