Browse Source

use done reason enum

Bruce MacDonald 1 tháng trước
mục cha
commit
771c88b3ad
4 tập tin đã thay đổi với 30 bổ sung và 13 xóa
  1. 26 9
      llm/server.go
  2. 1 1
      runner/llamarunner/runner.go
  3. 1 1
      runner/ollamarunner/runner.go
  4. 2 2
      server/routes.go

+ 26 - 9
llm/server.go

@@ -675,9 +675,34 @@ type CompletionRequest struct {
 	Grammar string // set before sending the request to the subprocess
 }
 
+// DoneReason is a string-backed type naming the reason why a completion response is done.
+type DoneReason string
+
+const (
+	// DoneReasonStop indicates the completion stopped naturally; ParseDoneReason also returns it for any unrecognized input.
+	DoneReasonStop DoneReason = "stop"
+	// DoneReasonLength indicates the completion stopped because a length limit was reached.
+	DoneReasonLength DoneReason = "length"
+)
+
+func (d DoneReason) String() string {
+	return string(d) // plain conversion to the underlying string ("stop" / "length" for the declared constants)
+}
+
+// ParseDoneReason converts a raw reason string into a DoneReason.
+// "limit" and "length" both map to DoneReasonLength; any other input, including the empty string, defaults to DoneReasonStop.
+func ParseDoneReason(reason string) DoneReason {
+	switch reason {
+	case "limit", "length":
+		return DoneReasonLength
+	default:
+		return DoneReasonStop
+	}
+}
+
 type CompletionResponse struct {
 	Content            string        `json:"content"`
-	DoneReason         string        `json:"done_reason"`
+	DoneReason         DoneReason    `json:"done_reason"`
 	Done               bool          `json:"done"`
 	PromptEvalCount    int           `json:"prompt_eval_count"`
 	PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
@@ -786,7 +811,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 				continue
 			}
 
-			// slog.Debug("got line", "line", string(line))
 			evt, ok := bytes.CutPrefix(line, []byte("data: "))
 			if !ok {
 				evt = line
@@ -796,13 +820,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			if err := json.Unmarshal(evt, &c); err != nil {
 				return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
 			}
-			// convert internal done reason to one of our standard api format done reasons
-			switch c.DoneReason {
-			case "limit":
-				c.DoneReason = "length"
-			default:
-				c.DoneReason = "stop"
-			}
 
 			switch {
 			case strings.TrimSpace(c.Content) == lastToken:

+ 1 - 1
runner/llamarunner/runner.go

@@ -649,7 +649,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			} else {
 				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
 					Done:               true,
-					DoneReason:         seq.doneReason,
+					DoneReason:         llm.ParseDoneReason(seq.doneReason),
 					PromptEvalCount:    seq.numPromptInputs,
 					PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
 					EvalCount:          seq.numDecoded,

+ 1 - 1
runner/ollamarunner/runner.go

@@ -629,7 +629,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			} else {
 				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
 					Done:               true,
-					DoneReason:         seq.doneReason,
+					DoneReason:         llm.ParseDoneReason(seq.doneReason),
 					PromptEvalCount:    seq.numPromptInputs,
 					PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
 					EvalCount:          seq.numPredicted,

+ 2 - 2
server/routes.go

@@ -312,7 +312,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 				CreatedAt:  time.Now().UTC(),
 				Response:   cr.Content,
 				Done:       cr.Done,
-				DoneReason: cr.DoneReason,
+				DoneReason: cr.DoneReason.String(),
 				Metrics: api.Metrics{
 					PromptEvalCount:    cr.PromptEvalCount,
 					PromptEvalDuration: cr.PromptEvalDuration,
@@ -1536,7 +1536,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 				CreatedAt:  time.Now().UTC(),
 				Message:    api.Message{Role: "assistant", Content: r.Content},
 				Done:       r.Done,
-				DoneReason: r.DoneReason,
+				DoneReason: r.DoneReason.String(),
 				Metrics: api.Metrics{
 					PromptEvalCount:    r.PromptEvalCount,
 					PromptEvalDuration: r.PromptEvalDuration,