@@ -214,6 +214,8 @@ func GenerateHandler(c *gin.Context) {
 	}

 	// an empty request loads the model
+	// note: for a short while template was used in lieu
+	// of `raw` mode so we need to check for it too
 	if req.Prompt == "" && req.Template == "" && req.System == "" {
 		c.JSON(http.StatusOK, api.GenerateResponse{
 			CreatedAt: time.Now().UTC(),
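(A quick note on the branch above, for illustration only: with prompt, template, and system all empty, e.g. a request body of just {"model": "llama2"}, the handler returns immediately after loading the model; per the new comment, requests from the period when template stood in for raw mode must also hit this path, which is why template participates in the emptiness check.)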
@@ -226,50 +228,48 @@ func GenerateHandler(c *gin.Context) {
 	checkpointLoaded := time.Now()

 	var prompt string
-	var promptVars PromptVars
 	switch {
 	case req.Raw:
 		prompt = req.Prompt
 	case req.Prompt != "":
-		if req.Template != "" {
-			// override the default model template
-			model.Template = req.Template
+		if req.Template == "" {
+			req.Template = model.Template
 		}

-		var rebuild strings.Builder
+		if req.System == "" {
+			req.System = model.System
+		}
+
+		slog.Debug("generate handler", "prompt", req.Prompt)
+		slog.Debug("generate handler", "template", req.Template)
+		slog.Debug("generate handler", "system", req.System)
+
+		var sb strings.Builder
 		if req.Context != nil {
-			// TODO: context is deprecated, at some point the context logic within this conditional should be removed
-			prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context)
+			prev, err := loaded.runner.Decode(c.Request.Context(), req.Context)
 			if err != nil {
 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				return
 			}

-			// Remove leading spaces from prevCtx if present
-			prevCtx = strings.TrimPrefix(prevCtx, " ")
-			rebuild.WriteString(prevCtx)
-		}
-		promptVars = PromptVars{
-			System: req.System,
-			Prompt: req.Prompt,
-			First:  len(req.Context) == 0,
-		}
-
-		if promptVars.System == "" {
-			promptVars.System = model.System
+			sb.WriteString(prev)
 		}

+		// write image tags
+		// TODO: limit the number of images to fit in the context similar to the chat endpoint
 		for i := range req.Images {
-			promptVars.Prompt += fmt.Sprintf(" [img-%d]", i)
+			req.Prompt += fmt.Sprintf(" [img-%d]", i)
 		}

-		p, err := model.PreResponsePrompt(promptVars)
+		p, err := Prompt(req.Template, req.System, req.Prompt, "", true)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
 		}
-		rebuild.WriteString(p)
-		prompt = rebuild.String()
+
+		sb.WriteString(p)
+
+		prompt = sb.String()
 	}

 	slog.Debug("generate handler", "prompt", prompt)
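Prompt's definition is not part of this diff, so the following is only a rough, standalone sketch of a function with the shape the two call sites assume: render a Go text/template over System/Prompt/Response variables, and when generate is true stop at the response placeholder so the model continues from there. The function name matches the call sites, but the body and the naive placeholder match are assumptions; the real implementation may differ.

package main

import (
	"fmt"
	"strings"
	"text/template"
)

// Prompt (hypothetical sketch) renders tmpl with the given variables. When
// generate is true it truncates the template at the literal "{{ .Response }}"
// placeholder; the real code presumably handles template syntax more robustly.
func Prompt(tmpl, system, prompt, response string, generate bool) (string, error) {
	if generate {
		if idx := strings.Index(tmpl, "{{ .Response }}"); idx >= 0 {
			tmpl = tmpl[:idx]
		}
	}

	t, err := template.New("").Parse(tmpl)
	if err != nil {
		return "", err
	}

	var sb strings.Builder
	if err := t.Execute(&sb, map[string]string{
		"System": system, "Prompt": prompt, "Response": response,
	}); err != nil {
		return "", err
	}

	return sb.String(), nil
}

func main() {
	p, _ := Prompt("{{ .System }} USER: {{ .Prompt }} ASSISTANT: {{ .Response }}",
		"You are concise.", "hello", "", true)
	fmt.Printf("%q\n", p) // "You are concise. USER: hello ASSISTANT: "
}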
@@ -308,19 +308,20 @@ func GenerateHandler(c *gin.Context) {
 		resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)

 		if !req.Raw {
-			// append the generated text to the history and template it if needed
-			promptVars.Response = generated.String()
-			result, err := model.PostResponseTemplate(promptVars)
+			p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false)
 			if err != nil {
-				ch <- gin.H{"error": err.Error()}
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				return
 			}
-			embd, err := loaded.runner.Encode(c.Request.Context(), prompt+result)
+
+			// TODO (jmorganca): encode() should not strip special tokens
+			tokens, err := loaded.runner.Encode(c.Request.Context(), p)
 			if err != nil {
 				ch <- gin.H{"error": err.Error()}
 				return
 			}
-			resp.Context = embd
+
+			resp.Context = append(req.Context, tokens...)
 		}
 	}

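The shape of the returned context changes here: instead of re-encoding the entire accumulated text (prompt+result above), the server appends only the tokens of the latest templated exchange to whatever context the client sent in. A tiny self-contained illustration of that bookkeeping; the token values are made up:

package main

import "fmt"

func main() {
	reqContext := []int{1, 15043, 3186}   // carried over from the previous turn (illustrative values)
	newTokens := []int{29871, 3834, 5837} // Encode() of the newly templated exchange (illustrative)

	// mirrors resp.Context = append(req.Context, tokens...) above
	respContext := append(reqContext, newTokens...)
	fmt.Println(respContext) // [1 15043 3186 29871 3834 5837], sent back as req.Context next turn
}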
@@ -1090,6 +1091,20 @@ func streamResponse(c *gin.Context, ch chan any) {
 	})
 }

+// chatPrompt builds up a prompt from a series of messages for the currently `loaded` model
+func chatPrompt(ctx context.Context, messages []api.Message) (string, error) {
+	encode := func(s string) ([]int, error) {
+		return loaded.runner.Encode(ctx, s)
+	}
+
+	prompt, err := ChatPrompt(loaded.Model.Template, loaded.Model.System, messages, loaded.Options.NumCtx, encode)
+	if err != nil {
+		return "", err
+	}
+
+	return prompt, nil
+}
+
 func ChatHandler(c *gin.Context) {
 	loaded.mu.Lock()
 	defer loaded.mu.Unlock()
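ChatPrompt, which does the actual templating and window fitting, is defined outside this diff. Below is a rough standalone sketch, under stated assumptions, of what a function with that signature plausibly does: always keep the system prompt, then take the longest suffix of messages whose encoded length fits within numCtx. The real function presumably renders each message through the model's template; this sketch just concatenates role-prefixed lines.

package main

import (
	"fmt"
	"strings"
)

type Message struct {
	Role    string
	Content string
}

// chatPromptSketch is a hypothetical stand-in for ChatPrompt: it returns the
// longest suffix of msgs (plus the system prompt) that encodes to <= numCtx tokens.
func chatPromptSketch(system string, msgs []Message, numCtx int, encode func(string) ([]int, error)) (string, error) {
	for start := 0; start <= len(msgs); start++ {
		var sb strings.Builder
		sb.WriteString(system + "\n")
		for _, m := range msgs[start:] {
			fmt.Fprintf(&sb, "%s: %s\n", m.Role, m.Content)
		}

		tokens, err := encode(sb.String())
		if err != nil {
			return "", err
		}

		if len(tokens) <= numCtx {
			return sb.String(), nil
		}
	}

	return "", fmt.Errorf("system prompt alone exceeds the context window")
}

func main() {
	// toy encoder: one token per whitespace-separated word
	encode := func(s string) ([]int, error) { return make([]int, len(strings.Fields(s))), nil }

	msgs := []Message{
		{"user", "first question"},
		{"assistant", "first answer"},
		{"user", "second question"},
	}

	prompt, _ := chatPromptSketch("be brief", msgs, 8, encode)
	fmt.Print(prompt) // the oldest turn is dropped so the prompt fits
}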
@@ -1117,15 +1132,6 @@ func ChatHandler(c *gin.Context) {
 		return
 	}

-	for _, msg := range req.Messages {
-		for _, img := range msg.Images {
-			if !isSupportedImageType(img) {
-				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
-				return
-			}
-		}
-	}
-
 	model, err := GetModel(req.Model)
 	if err != nil {
 		var pErr *fs.PathError
@@ -1161,20 +1167,14 @@ func ChatHandler(c *gin.Context) {

 	checkpointLoaded := time.Now()

-	chat, err := model.ChatPrompts(req.Messages)
+	prompt, err := chatPrompt(c.Request.Context(), req.Messages)
 	if err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}

-	prompt, images, err := trimmedPrompt(c.Request.Context(), chat, model)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
 	// an empty request loads the model
-	if len(prompt) == 0 {
+	if len(req.Messages) == 0 || prompt == "" {
 		resp := api.ChatResponse{
 			CreatedAt: time.Now().UTC(),
 			Model:     req.Model,
@@ -1185,7 +1185,24 @@ func ChatHandler(c *gin.Context) {
 		return
 	}

-	slog.Debug("chat handler", "prompt", prompt)
+	// only send images that are in the prompt
+	var i int
+	var images []llm.ImageData
+	for _, m := range req.Messages {
+		for _, img := range m.Images {
+			if !isSupportedImageType(img) {
+				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
+				return
+			}
+
+			if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) {
+				images = append(images, llm.ImageData{Data: img, ID: i})
+			}
+			i += 1
+		}
+	}
+
+	slog.Debug("chat handler", "prompt", prompt, "images", len(images))

 	ch := make(chan any)

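Because chatPrompt may have dropped older turns to fit the context window, an image can pass validation yet never appear in the final prompt; only images whose [img-N] tag survived into the rendered text are forwarded to the runner. A small standalone illustration of that matching:

package main

import (
	"fmt"
	"strings"
)

func main() {
	// suppose truncation dropped the first turn, and with it the "[img-0]" tag
	prompt := "USER: what is in this image? [img-1]\nASSISTANT:"

	raw := [][]byte{[]byte("png-bytes-0"), []byte("png-bytes-1")} // illustrative payloads

	var images [][]byte
	for i, img := range raw {
		if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) {
			images = append(images, img)
		}
	}

	fmt.Println(len(images)) // 1: only the image still referenced by the prompt is sent
}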
@@ -1260,115 +1277,3 @@ func ChatHandler(c *gin.Context) {

 	streamResponse(c, ch)
 }
-
-// promptInfo stores the variables used to template a prompt, and the token length of the resulting template for some model
-type promptInfo struct {
-	vars     PromptVars
-	tokenLen int
-}
-
-// trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
-// while preserving the most recent system message.
-func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, []llm.ImageData, error) {
-	if len(chat.Prompts) == 0 {
-		return "", nil, nil
-	}
-
-	var promptsToAdd []promptInfo
-	var totalTokenLength int
-	var systemPromptIncluded bool
-
-	var images []llm.ImageData
-	// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
-	for i := len(chat.Prompts) - 1; i >= 0; i-- {
-		prompt := chat.Prompts[i]
-		promptText, err := promptString(model, prompt, i == len(chat.Prompts)-1)
-		if err != nil {
-			return "", nil, err
-		}
-
-		encodedTokens, err := loaded.runner.Encode(ctx, promptText)
-		if err != nil {
-			return "", nil, err
-		}
-
-		if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
-			break // reached max context length, stop adding more prompts
-		}
-
-		for j := range prompt.Images {
-			if totalTokenLength+768 > loaded.NumCtx {
-				// this decreases the token length but overestimating is fine
-				prompt.Prompt = strings.ReplaceAll(prompt.Prompt, fmt.Sprintf(" [img-%d]", prompt.Images[j].ID), "")
-				continue
-			}
-
-			totalTokenLength += 768
-			images = append(images, prompt.Images[j])
-		}
-
-		totalTokenLength += len(encodedTokens)
-		systemPromptIncluded = systemPromptIncluded || prompt.System != ""
-		promptsToAdd = append(promptsToAdd, promptInfo{vars: prompt, tokenLen: len(encodedTokens)})
-	}
-
-	// ensure the system prompt is included, if not already
-	if chat.LastSystem != "" && !systemPromptIncluded {
-		var err error
-		promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
-		if err != nil {
-			return "", nil, err
-		}
-	}
-
-	promptsToAdd[len(promptsToAdd)-1].vars.First = true
-
-	// construct the final prompt string from the prompts which fit within the context window
-	var result string
-	for i, prompt := range promptsToAdd {
-		promptText, err := promptString(model, prompt.vars, i == 0)
-		if err != nil {
-			return "", nil, err
-		}
-		result = promptText + result
-	}
-
-	return result, images, nil
-}
-
-// promptString applies the model template to the prompt
-func promptString(model *Model, vars PromptVars, isMostRecent bool) (string, error) {
-	if isMostRecent {
-		p, err := model.PreResponsePrompt(vars)
-		if err != nil {
-			return "", fmt.Errorf("pre-response template: %w", err)
-		}
-		return p, nil
-	}
-	p, err := Prompt(model.Template, vars)
-	if err != nil {
-		return "", err
-	}
-	return p, nil
-}
-
-// includeSystemPrompt adjusts the prompts to include the system prompt.
-func includeSystemPrompt(ctx context.Context, systemPrompt string, totalTokenLength int, promptsToAdd []promptInfo) ([]promptInfo, error) {
-	systemTokens, err := loaded.runner.Encode(ctx, systemPrompt)
-	if err != nil {
-		return nil, err
-	}
-
-	for i := len(promptsToAdd) - 1; i >= 0; i-- {
-		if totalTokenLength+len(systemTokens) <= loaded.NumCtx {
-			promptsToAdd[i].vars.System = systemPrompt
-			return promptsToAdd[:i+1], nil
-		}
-		totalTokenLength -= promptsToAdd[i].tokenLen
-	}
-
-	// if got here, system did not fit anywhere, so return the most recent prompt with the system message set
-	recent := promptsToAdd[len(promptsToAdd)-1]
-	recent.vars.System = systemPrompt
-	return []promptInfo{recent}, nil
-}
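For reference, the deleted helper budgeted a flat 768 tokens per image against loaded.NumCtx and dropped images (and their prompt tags) that did not fit; the TODO left in the generate path above asks for the same kind of limit there. A tiny standalone sketch of that budgeting rule, with hypothetical names:

package main

import "fmt"

// fitImages (hypothetical) counts how many images fit after usedTokens are
// already spent, charging the flat 768-token overestimate the deleted code used.
func fitImages(numImages, usedTokens, numCtx int) int {
	const perImage = 768
	kept := 0
	for i := 0; i < numImages; i++ {
		if usedTokens+perImage > numCtx {
			break // no room left in the window; drop the remaining images
		}
		usedTokens += perImage
		kept++
	}
	return kept
}

func main() {
	fmt.Println(fitImages(3, 1500, 2048)) // 0: the prompt nearly fills a 2k window
	fmt.Println(fitImages(3, 100, 2048))  // 2: a third image would not fit
}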