Browse Source

llm: set done reason at server level

No functional change. Many different done reasons can be set at the runner
level, so rather than obscuring them, we should return them to the server
process and let it choose what to do with the done reason. This separates
the API concerns from the runner.
Bruce MacDonald 1 month ago
Parent
Commit
22f2f6e229
3 changed files with 10 additions and 12 deletions
  1. llm/server.go (+8 -0)
  2. runner/llamarunner/runner.go (+1 -6)
  3. runner/ollamarunner/runner.go (+1 -6)

+ 8 - 0
llm/server.go

@@ -796,6 +796,14 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			if err := json.Unmarshal(evt, &c); err != nil {
 				return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
 			}
+			// convert internal done reason to one of our standard api format done reasons
+			switch c.DoneReason {
+			case "limit":
+				c.DoneReason = "length"
+			default:
+				c.DoneReason = "stop"
+			}
+
 			switch {
 			case strings.TrimSpace(c.Content) == lastToken:
 				tokenRepeat++
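
The switch above is the whole policy: internal runner reasons are collapsed to the two API-facing values, "length" and "stop". A minimal standalone sketch of the same normalization, runnable on its own (the "connection_closed" value is a hypothetical extra internal reason, included only to exercise the default branch; it is not taken from this commit):

    package main

    import "fmt"

    // normalizeDoneReason mirrors the switch added to llm/server.go above:
    // a token-limit stop ("limit") becomes the API's "length", and any
    // other internal reason is reported to clients as a plain "stop".
    func normalizeDoneReason(internal string) string {
    	switch internal {
    	case "limit":
    		return "length"
    	default:
    		return "stop"
    	}
    }

    func main() {
    	// "connection_closed" is hypothetical, standing in for the many
    	// runner-level reasons the commit message mentions.
    	for _, r := range []string{"limit", "stop", "connection_closed"} {
    		fmt.Printf("%s -> %s\n", r, normalizeDoneReason(r))
    	}
    }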

+ 1 - 6
runner/llamarunner/runner.go

@@ -647,14 +647,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {

 				flusher.Flush()
 			} else {
-				// Send the final response
-				doneReason := "stop"
-				if seq.doneReason == "limit" {
-					doneReason = "length"
-				}
 				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
 					Done:               true,
-					DoneReason:         doneReason,
+					DoneReason:         seq.doneReason,
 					PromptEvalCount:    seq.numPromptInputs,
 					PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
 					EvalCount:          seq.numDecoded,

+ 1 - 6
runner/ollamarunner/runner.go

@@ -627,14 +627,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {

 				flusher.Flush()
 			} else {
-				// Send the final response
-				doneReason := "stop"
-				if seq.doneReason == "limit" {
-					doneReason = "length"
-				}
 				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
 					Done:               true,
-					DoneReason:         doneReason,
+					DoneReason:         seq.doneReason,
 					PromptEvalCount:    seq.numPromptInputs,
 					PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
 					EvalCount:          seq.numPredicted,
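
One consequence of the server-side default branch: any internal reason a runner reports other than "limit" now surfaces to API clients as "stop". A new runner-level reason that needs distinct API treatment would also need its own case in the switch in llm/server.go.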