Selaa lähdekoodia

Add `done_reason`

jmorganca 1 vuosi sitten
vanhempi
commit
e117483ef6
5 muutettua tiedostoa jossa 70 lisäystä ja 36 poistoa
  1. 5 3
      api/types.go
  2. 20 4
      llm/server.go
  3. 10 8
      server/prompt.go
  4. 1 1
      server/prompt_test.go
  5. 34 20
      server/routes.go

+ 5 - 3
api/types.go

@@ -98,7 +98,8 @@ type ChatResponse struct {
 	CreatedAt time.Time `json:"created_at"`
 	Message   Message   `json:"message"`
 
-	Done bool `json:"done"`
+	Done       bool   `json:"done"`
+	DoneReason string `json:"done_reason,omitempty"`
 
 	Metrics
 }
@@ -265,8 +266,9 @@ type GenerateResponse struct {
 	CreatedAt time.Time `json:"created_at"`
 	Response  string    `json:"response"`
 
-	Done    bool  `json:"done"`
-	Context []int `json:"context,omitempty"`
+	Done       bool   `json:"done"`
+	DoneReason string `json:"done_reason,omitempty"`
+	Context    []int  `json:"context,omitempty"`
 
 	Metrics
 }

+ 20 - 4
llm/server.go

@@ -509,10 +509,13 @@ type ImageData struct {
 }
 
 type completion struct {
-	Content string `json:"content"`
-	Model   string `json:"model"`
-	Prompt  string `json:"prompt"`
-	Stop    bool   `json:"stop"`
+	Content      string `json:"content"`
+	Model        string `json:"model"`
+	Prompt       string `json:"prompt"`
+	Stop         bool   `json:"stop"`
+	StoppedEos   bool   `json:"stopped_eos"`
+	StoppedWord  bool   `json:"stopped_word"`
+	StoppedLimit bool   `json:"stopped_limit"`
 
 	Timings struct {
 		PredictedN  int     `json:"predicted_n"`
@@ -532,6 +535,7 @@ type CompletionRequest struct {
 type CompletionResponse struct {
 	Content            string
 	Done               bool
+	DoneReason         string
 	PromptEvalCount    int
 	PromptEvalDuration time.Duration
 	EvalCount          int
@@ -648,6 +652,8 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
 					return fmt.Errorf("error parsing llm response stream: %s", line)
 				}
 
+				fmt.Println("c", string(evt))
+
 				var c completion
 				if err := json.Unmarshal(evt, &c); err != nil {
 					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
@@ -674,8 +680,18 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
 				}
 
 				if c.Stop {
+					var doneReason string
+					switch {
+					case c.StoppedEos:
+						doneReason = "stop"
+					case c.StoppedWord:
+						doneReason = "stop"
+					case c.StoppedLimit:
+						doneReason = "limit"
+					}
 					fn(CompletionResponse{
 						Done:               true,
+						DoneReason:         doneReason,
 						PromptEvalCount:    c.Timings.PromptN,
 						PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
 						EvalCount:          c.Timings.PredictedN,

+ 10 - 8
server/prompt.go

@@ -91,7 +91,7 @@ func countTokens(tmpl string, system string, prompt string, response string, enc
 }
 
 // ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
-func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
+func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, int, error) {
 	type prompt struct {
 		System   string
 		Prompt   string
@@ -138,7 +138,7 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
 
 			p.Response = msg.Content
 		default:
-			return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
+			return "", 0, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
 		}
 	}
 
@@ -151,7 +151,7 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
 	for i, p := range prompts {
 		tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode)
 		if err != nil {
-			return "", err
+			return "", 0, err
 		}
 
 		prompts[i].tokens = tokens + len(prompts[i].images)*768
@@ -160,15 +160,17 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
 	// truncate images and prompts starting from the beginning of the list
 	// until either one prompt remains or the total tokens fits the context window
 	// TODO (jmorganca): this doesn't account for the context window room required for the response
+	var required int
 	for {
-		var required int
+		required = 0
 		for _, p := range prompts {
 			required += p.tokens
 		}
 
 		required += 1 // for bos token
 
-		if required <= window {
+		// leave ~1024 tokens for generation
+		if required <= max(1024, window/2) {
 			slog.Debug("prompt now fits in context window", "required", required, "window", window)
 			break
 		}
@@ -194,7 +196,7 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
 
 				tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode)
 				if err != nil {
-					return "", err
+					return "", 0, err
 				}
 
 				prompts[0].tokens = tokens + len(prompts[0].images)*768
@@ -212,10 +214,10 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
 		// last prompt should leave the response unrendered (for completion)
 		rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1)
 		if err != nil {
-			return "", err
+			return "", 0, err
 		}
 		sb.WriteString(rendered)
 	}
 
-	return sb.String(), nil
+	return sb.String(), required, nil
 }

+ 1 - 1
server/prompt_test.go

@@ -192,7 +192,7 @@ func TestChatPrompt(t *testing.T) {
 
 	for _, tc := range tests {
 		t.Run(tc.name, func(t *testing.T) {
-			got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
+			got, _, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
 			if err != nil {
 				t.Errorf("error = %v", err)
 			}

+ 34 - 20
server/routes.go

@@ -234,9 +234,10 @@ func GenerateHandler(c *gin.Context) {
 	// of `raw` mode so we need to check for it too
 	if req.Prompt == "" && req.Template == "" && req.System == "" {
 		c.JSON(http.StatusOK, api.GenerateResponse{
-			CreatedAt: time.Now().UTC(),
-			Model:     req.Model,
-			Done:      true,
+			CreatedAt:  time.Now().UTC(),
+			Model:      req.Model,
+			Done:       true,
+			DoneReason: "load",
 		})
 		return
 	}
@@ -289,6 +290,14 @@ func GenerateHandler(c *gin.Context) {
 		prompt = sb.String()
 	}
 
+	tokens, err := loaded.llama.Tokenize(c.Request.Context(), prompt)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	opts.NumPredict = max(opts.NumCtx-len(tokens), 0)
+
 	slog.Debug("generate handler", "prompt", prompt)
 
 	ch := make(chan any)
@@ -307,10 +316,11 @@ func GenerateHandler(c *gin.Context) {
 			}
 
 			resp := api.GenerateResponse{
-				Model:     req.Model,
-				CreatedAt: time.Now().UTC(),
-				Done:      r.Done,
-				Response:  r.Content,
+				Model:      req.Model,
+				CreatedAt:  time.Now().UTC(),
+				Done:       r.Done,
+				DoneReason: r.DoneReason,
+				Response:   r.Content,
 				Metrics: api.Metrics{
 					PromptEvalCount:    r.PromptEvalCount,
 					PromptEvalDuration: r.PromptEvalDuration,
@@ -1219,17 +1229,17 @@ func streamResponse(c *gin.Context, ch chan any) {
 }
 
 // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
-func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
+func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, int, error) {
 	encode := func(s string) ([]int, error) {
 		return loaded.llama.Tokenize(ctx, s)
 	}
 
-	prompt, err := ChatPrompt(template, messages, numCtx, encode)
+	prompt, tokens, err := ChatPrompt(template, messages, numCtx, encode)
 	if err != nil {
-		return "", err
+		return "", 0, err
 	}
 
-	return prompt, nil
+	return prompt, tokens, nil
 }
 
 func ChatHandler(c *gin.Context) {
@@ -1309,19 +1319,22 @@ func ChatHandler(c *gin.Context) {
 		}, req.Messages...)
 	}
 
-	prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
+	prompt, tokens, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
 	if err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
+	opts.NumPredict = max(opts.NumCtx-tokens, 0)
+
 	// an empty request loads the model
 	if len(req.Messages) == 0 || prompt == "" {
 		resp := api.ChatResponse{
-			CreatedAt: time.Now().UTC(),
-			Model:     req.Model,
-			Done:      true,
-			Message:   api.Message{Role: "assistant"},
+			CreatedAt:  time.Now().UTC(),
+			Model:      req.Model,
+			Done:       true,
+			DoneReason: "load",
+			Message:    api.Message{Role: "assistant"},
 		}
 		c.JSON(http.StatusOK, resp)
 		return
@@ -1356,10 +1369,11 @@ func ChatHandler(c *gin.Context) {
 			loaded.expireTimer.Reset(sessionDuration)
 
 			resp := api.ChatResponse{
-				Model:     req.Model,
-				CreatedAt: time.Now().UTC(),
-				Message:   api.Message{Role: "assistant", Content: r.Content},
-				Done:      r.Done,
+				Model:      req.Model,
+				CreatedAt:  time.Now().UTC(),
+				Message:    api.Message{Role: "assistant", Content: r.Content},
+				Done:       r.Done,
+				DoneReason: r.DoneReason,
 				Metrics: api.Metrics{
 					PromptEvalCount:    r.PromptEvalCount,
 					PromptEvalDuration: r.PromptEvalDuration,