
add done_reason to the api (#4235)

Bruce MacDonald 11 months ago
parent
commit
cfa84b8470
4 changed files with 44 additions and 40 deletions
  1. api/types.go (+7 −3)
  2. llm/server.go (+12 −4)
  3. openai/openai.go (+6 −18)
  4. server/routes.go (+19 −15)

api/types.go (+7 −3)

@@ -114,9 +114,10 @@ type Message struct {
 // ChatResponse is the response returned by [Client.Chat]. Its fields are
 // similar to [GenerateResponse].
 type ChatResponse struct {
-	Model     string    `json:"model"`
-	CreatedAt time.Time `json:"created_at"`
-	Message   Message   `json:"message"`
+	Model      string    `json:"model"`
+	CreatedAt  time.Time `json:"created_at"`
+	Message    Message   `json:"message"`
+	DoneReason string    `json:"done_reason"`
 
 	Done bool `json:"done"`
 
@@ -309,6 +310,9 @@ type GenerateResponse struct {
 	// Done specifies if the response is complete.
 	Done bool `json:"done"`
 
+	// DoneReason is the reason the model stopped generating text.
+	DoneReason string `json:"done_reason"`
+
 	// Context is an encoding of the conversation used in this response; this
 	// can be sent in the next request to keep a conversational memory.
 	Context []int `json:"context,omitempty"`
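
The new field is visible to Go clients immediately. Below is a minimal sketch of reading it through the repo's own github.com/ollama/ollama/api client package; the model name is a placeholder, and streaming is the default for Chat:

	package main

	import (
		"context"
		"fmt"
		"log"

		"github.com/ollama/ollama/api"
	)

	func main() {
		client, err := api.ClientFromEnvironment()
		if err != nil {
			log.Fatal(err)
		}

		req := &api.ChatRequest{
			Model:    "llama3", // placeholder model name
			Messages: []api.Message{{Role: "user", Content: "Why is the sky blue?"}},
		}

		// The callback fires once per streamed chunk; DoneReason is only
		// meaningful on the final one, where Done is true.
		err = client.Chat(context.Background(), req, func(r api.ChatResponse) error {
			fmt.Print(r.Message.Content)
			if r.Done {
				fmt.Printf("\n-- done_reason: %s\n", r.DoneReason)
			}
			return nil
		})
		if err != nil {
			log.Fatal(err)
		}
	}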

llm/server.go (+12 −4)

@@ -576,10 +576,11 @@ type ImageData struct {
 }
 
 type completion struct {
-	Content string `json:"content"`
-	Model   string `json:"model"`
-	Prompt  string `json:"prompt"`
-	Stop    bool   `json:"stop"`
+	Content      string `json:"content"`
+	Model        string `json:"model"`
+	Prompt       string `json:"prompt"`
+	Stop         bool   `json:"stop"`
+	StoppedLimit bool   `json:"stopped_limit"`
 
 	Timings struct {
 		PredictedN  int     `json:"predicted_n"`
@@ -598,6 +599,7 @@ type CompletionRequest struct {
 
 type CompletionResponse struct {
 	Content            string
+	DoneReason         string
 	Done               bool
 	PromptEvalCount    int
 	PromptEvalDuration time.Duration
@@ -739,8 +741,14 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			}
 
 			if c.Stop {
+				doneReason := "stop"
+				if c.StoppedLimit {
+					doneReason = "length"
+				}
+
 				fn(CompletionResponse{
 					Done:               true,
+					DoneReason:         doneReason,
 					PromptEvalCount:    c.Timings.PromptN,
 					PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
 					EvalCount:          c.Timings.PredictedN,
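
The mapping is small enough to isolate. Here is a self-contained sketch of the branch above, assuming only the two llama.cpp server fields shown in the diff (stop and stopped_limit); everything else is illustrative:

	package main

	import "fmt"

	// completion mirrors the two fields of the llama.cpp server payload
	// that matter for done_reason.
	type completion struct {
		Stop         bool `json:"stop"`
		StoppedLimit bool `json:"stopped_limit"`
	}

	// doneReason reproduces the branch in (*llmServer).Completion:
	// "length" when the prediction limit cut generation short,
	// otherwise "stop".
	func doneReason(c completion) string {
		if !c.Stop {
			return "" // still generating; no reason yet
		}
		if c.StoppedLimit {
			return "length"
		}
		return "stop"
	}

	func main() {
		fmt.Println(doneReason(completion{Stop: true}))                     // stop
		fmt.Println(doneReason(completion{Stop: true, StoppedLimit: true})) // length
	}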

openai/openai.go (+6 −18)

@@ -107,15 +107,9 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 		Model:             r.Model,
 		SystemFingerprint: "fp_ollama",
 		Choices: []Choice{{
-			Index:   0,
-			Message: Message{Role: r.Message.Role, Content: r.Message.Content},
-			FinishReason: func(done bool) *string {
-				if done {
-					reason := "stop"
-					return &reason
-				}
-				return nil
-			}(r.Done),
+			Index:        0,
+			Message:      Message{Role: r.Message.Role, Content: r.Message.Content},
+			FinishReason: &r.DoneReason,
 		}},
 		Usage: Usage{
 			// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
@@ -135,15 +129,9 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
 		SystemFingerprint: "fp_ollama",
 		Choices: []ChunkChoice{
 			{
-				Index: 0,
-				Delta: Message{Role: "assistant", Content: r.Message.Content},
-				FinishReason: func(done bool) *string {
-					if done {
-						reason := "stop"
-						return &reason
-					}
-					return nil
-				}(r.Done),
+				Index:        0,
+				Delta:        Message{Role: "assistant", Content: r.Message.Content},
+				FinishReason: &r.DoneReason,
 			},
 		},
 	}
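
One consequence worth noting: FinishReason now always points at r.DoneReason, so mid-stream chunks serialize an empty string where the old closure returned nil (JSON null). A trimmed sketch of the serialization, where this Choice struct is a stand-in for the real openai package type:

	package main

	import (
		"encoding/json"
		"fmt"
	)

	// Choice is a trimmed stand-in for the openai package type; only the
	// fields needed to show the serialization are included.
	type Choice struct {
		Index        int     `json:"index"`
		FinishReason *string `json:"finish_reason"`
	}

	func main() {
		for _, reason := range []string{"", "length"} {
			reason := reason // take the address of a fresh copy
			b, _ := json.Marshal(Choice{FinishReason: &reason})
			fmt.Println(string(b))
		}
		// Output:
		// {"index":0,"finish_reason":""}
		// {"index":0,"finish_reason":"length"}
	}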

server/routes.go (+19 −15)

@@ -152,9 +152,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	// of `raw` mode so we need to check for it too
 	if req.Prompt == "" && req.Template == "" && req.System == "" {
 		c.JSON(http.StatusOK, api.GenerateResponse{
-			CreatedAt: time.Now().UTC(),
-			Model:     req.Model,
-			Done:      true,
+			CreatedAt:  time.Now().UTC(),
+			Model:      req.Model,
+			Done:       true,
+			DoneReason: "load",
 		})
 		return
 	}
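
From the client's side, this short-circuit means an empty generate request now reports why it finished. A hedged sketch against the default local endpoint (the model name is a placeholder):

	package main

	import (
		"bytes"
		"encoding/json"
		"fmt"
		"log"
		"net/http"
	)

	func main() {
		// No prompt, template, or system message: the handler only loads
		// the model and returns immediately.
		body := bytes.NewBufferString(`{"model":"llama3"}`)
		resp, err := http.Post("http://localhost:11434/api/generate", "application/json", body)
		if err != nil {
			log.Fatal(err)
		}
		defer resp.Body.Close()

		var r struct {
			Done       bool   `json:"done"`
			DoneReason string `json:"done_reason"`
		}
		if err := json.NewDecoder(resp.Body).Decode(&r); err != nil {
			log.Fatal(err)
		}
		fmt.Printf("done=%v done_reason=%q\n", r.Done, r.DoneReason) // done=true done_reason="load"
	}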
@@ -222,10 +223,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 			}
 
 			resp := api.GenerateResponse{
-				Model:     req.Model,
-				CreatedAt: time.Now().UTC(),
-				Done:      r.Done,
-				Response:  r.Content,
+				Model:      req.Model,
+				CreatedAt:  time.Now().UTC(),
+				Done:       r.Done,
+				Response:   r.Content,
+				DoneReason: r.DoneReason,
 				Metrics: api.Metrics{
 					PromptEvalCount:    r.PromptEvalCount,
 					PromptEvalDuration: r.PromptEvalDuration,
@@ -1215,10 +1217,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	// an empty request loads the model
 	if len(req.Messages) == 0 || prompt == "" {
 		resp := api.ChatResponse{
-			CreatedAt: time.Now().UTC(),
-			Model:     req.Model,
-			Done:      true,
-			Message:   api.Message{Role: "assistant"},
+			CreatedAt:  time.Now().UTC(),
+			Model:      req.Model,
+			Done:       true,
+			DoneReason: "load",
+			Message:    api.Message{Role: "assistant"},
 		}
 		c.JSON(http.StatusOK, resp)
 		return
@@ -1251,10 +1254,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		fn := func(r llm.CompletionResponse) {
 
 			resp := api.ChatResponse{
-				Model:     req.Model,
-				CreatedAt: time.Now().UTC(),
-				Message:   api.Message{Role: "assistant", Content: r.Content},
-				Done:      r.Done,
+				Model:      req.Model,
+				CreatedAt:  time.Now().UTC(),
+				Message:    api.Message{Role: "assistant", Content: r.Content},
+				Done:       r.Done,
+				DoneReason: r.DoneReason,
 				Metrics: api.Metrics{
 					PromptEvalCount:    r.PromptEvalCount,
 					PromptEvalDuration: r.PromptEvalDuration,
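
To make the wire format concrete, marshaling the final ChatResponse the handler builds above produces the last line of the NDJSON stream. A minimal sketch using the types from this change; the model name and timestamp are placeholders:

	package main

	import (
		"encoding/json"
		"fmt"
		"time"

		"github.com/ollama/ollama/api"
	)

	func main() {
		resp := api.ChatResponse{
			Model:      "llama3", // placeholder
			CreatedAt:  time.Date(2024, 5, 9, 0, 0, 0, 0, time.UTC), // placeholder
			Message:    api.Message{Role: "assistant", Content: ""},
			Done:       true,
			DoneReason: "stop",
		}
		b, _ := json.Marshal(resp)
		fmt.Println(string(b))
		// prints something like:
		// {"model":"llama3","created_at":"2024-05-09T00:00:00Z",
		//  "message":{"role":"assistant","content":""},"done_reason":"stop","done":true}
	}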