@@ -234,9 +234,10 @@ func GenerateHandler(c *gin.Context) {
 	// of `raw` mode so we need to check for it too
 	if req.Prompt == "" && req.Template == "" && req.System == "" {
 		c.JSON(http.StatusOK, api.GenerateResponse{
-			CreatedAt: time.Now().UTC(),
-			Model:     req.Model,
-			Done:      true,
+			CreatedAt:  time.Now().UTC(),
+			Model:      req.Model,
+			Done:       true,
+			DoneReason: "load",
 		})
 		return
 	}
@@ -289,6 +290,14 @@ func GenerateHandler(c *gin.Context) {
 		prompt = sb.String()
 	}
 
+	tokens, err := loaded.llama.Tokenize(c.Request.Context(), prompt)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	opts.NumPredict = max(opts.NumCtx-len(tokens), 0)
+
 	slog.Debug("generate handler", "prompt", prompt)
 
 	ch := make(chan any)
@@ -307,10 +316,11 @@ func GenerateHandler(c *gin.Context) {
 		}
 
 		resp := api.GenerateResponse{
-			Model:     req.Model,
-			CreatedAt: time.Now().UTC(),
-			Done:      r.Done,
-			Response:  r.Content,
+			Model:      req.Model,
+			CreatedAt:  time.Now().UTC(),
+			Done:       r.Done,
+			DoneReason: r.DoneReason,
+			Response:   r.Content,
 			Metrics: api.Metrics{
 				PromptEvalCount:    r.PromptEvalCount,
 				PromptEvalDuration: r.PromptEvalDuration,
@@ -1219,17 +1229,17 @@ func streamResponse(c *gin.Context, ch chan any) {
 }
 
 // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
-func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
+func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, int, error) {
 	encode := func(s string) ([]int, error) {
 		return loaded.llama.Tokenize(ctx, s)
 	}
 
-	prompt, err := ChatPrompt(template, messages, numCtx, encode)
+	prompt, tokens, err := ChatPrompt(template, messages, numCtx, encode)
 	if err != nil {
-		return "", err
+		return "", 0, err
 	}
 
-	return prompt, nil
+	return prompt, tokens, nil
 }
 
 func ChatHandler(c *gin.Context) {
@@ -1309,19 +1319,22 @@ func ChatHandler(c *gin.Context) {
 		}, req.Messages...)
 	}
 
-	prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
+	prompt, tokens, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
 	if err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
+	opts.NumPredict = max(opts.NumCtx-tokens, 0)
+
 	// an empty request loads the model
 	if len(req.Messages) == 0 || prompt == "" {
 		resp := api.ChatResponse{
-			CreatedAt: time.Now().UTC(),
-			Model:     req.Model,
-			Done:      true,
-			Message:   api.Message{Role: "assistant"},
+			CreatedAt:  time.Now().UTC(),
+			Model:      req.Model,
+			Done:       true,
+			DoneReason: "load",
+			Message:    api.Message{Role: "assistant"},
 		}
 		c.JSON(http.StatusOK, resp)
 		return
@@ -1356,10 +1369,11 @@ func ChatHandler(c *gin.Context) {
 		loaded.expireTimer.Reset(sessionDuration)
 
 		resp := api.ChatResponse{
-			Model:     req.Model,
-			CreatedAt: time.Now().UTC(),
-			Message:   api.Message{Role: "assistant", Content: r.Content},
-			Done:      r.Done,
+			Model:      req.Model,
+			CreatedAt:  time.Now().UTC(),
+			Message:    api.Message{Role: "assistant", Content: r.Content},
+			Done:       r.Done,
+			DoneReason: r.DoneReason,
 			Metrics: api.Metrics{
 				PromptEvalCount:    r.PromptEvalCount,
 				PromptEvalDuration: r.PromptEvalDuration,
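
Note on the NumPredict change above: both handlers now tokenize the final
prompt and cap the number of tokens to generate at whatever remains of the
context window. A minimal sketch of that rule, with illustrative names
(clampNumPredict is not a function in this patch):

	package main

	import "fmt"

	// clampNumPredict caps generation so prompt + completion cannot
	// exceed the context window; max is the Go 1.21+ builtin.
	func clampNumPredict(numCtx, promptTokens int) int {
		return max(numCtx-promptTokens, 0)
	}

	func main() {
		fmt.Println(clampNumPredict(2048, 512))  // 1536: room left to generate
		fmt.Println(clampNumPredict(2048, 3000)) // 0: prompt already fills the window
	}

With the clamp in place, a prompt that fills the window yields NumPredict == 0,
and the new DoneReason field lets clients see why a response finished (e.g.
"load" for the empty-request case above) instead of inferring it from Done alone.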