@@ -244,6 +244,10 @@ func GenerateHandler(c *gin.Context) {
 		promptVars.System = model.System
 	}
 
+	for i := range req.Images {
+		promptVars.Prompt += fmt.Sprintf(" [img-%d]", i)
+	}
+
 	p, err := model.PreResponsePrompt(promptVars)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
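
The generate path tags the prompt text with one " [img-N]" placeholder per attached
image, so the prompt the model sees references each image by index; the runner is
presumably expected to splice the corresponding image embedding in at each tag. A
minimal standalone sketch of the resulting prompt text (all names are local to the
example):

    package main

    import "fmt"

    func main() {
        prompt := "Describe the attached images."
        for i := 0; i < 2; i++ { // pretend the request carried two images
            prompt += fmt.Sprintf(" [img-%d]", i)
        }
        fmt.Println(prompt) // Describe the attached images. [img-0] [img-1]
    }
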
@@ -308,11 +312,19 @@ func GenerateHandler(c *gin.Context) {
 			ch <- resp
 		}
 
+		var images []llm.ImageData
+		for i := range req.Images {
+			images = append(images, llm.ImageData{
+				ID:   i,
+				Data: req.Images[i],
+			})
+		}
+
 		// Start prediction
 		predictReq := llm.PredictOpts{
 			Prompt:  prompt,
 			Format:  req.Format,
-			Images:  req.Images,
+			Images:  images,
 			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
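
Rather than passing the raw request images straight through, the handler now wraps
each one in an llm.ImageData that pairs the bytes with the same index used in its
"[img-N]" tag. The definition below is a sketch inferred from the construction
above; the real type lives in the llm package and may carry more fields:

    // ImageData as implied by its use in GenerateHandler (assumed shape).
    type ImageData struct {
        ID   int    // index referenced by the " [img-<ID>]" tag in the prompt
        Data []byte // raw image bytes; req.Images is assumed to hold byte slices
    }

Carrying the ID alongside the bytes lets the runner keep each image matched to its
tag even after some images are dropped during trimming.
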
@@ -1139,7 +1151,8 @@ func ChatHandler(c *gin.Context) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
-	prompt, err := trimmedPrompt(c.Request.Context(), chat, model)
+
+	prompt, images, err := trimmedPrompt(c.Request.Context(), chat, model)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -1182,7 +1195,7 @@ func ChatHandler(c *gin.Context) {
 		predictReq := llm.PredictOpts{
 			Prompt:  prompt,
 			Format:  req.Format,
-			Images:  chat.CurrentImages,
+			Images:  images,
 			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
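
Note the change of source for images on the chat path: ChatHandler previously read
chat.CurrentImages, but now forwards whatever trimmedPrompt returns. Because
trimming can drop images (next hunk), this keeps the ImageData slice consistent
with the [img-N] tags that actually remain in the prompt text.
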
@@ -1229,34 +1242,47 @@ type promptInfo struct {
 
 // trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
 // while preserving the most recent system message.
-func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, error) {
+func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, []llm.ImageData, error) {
 	if len(chat.Prompts) == 0 {
-		return "", nil
+		return "", nil, nil
 	}
 
 	var promptsToAdd []promptInfo
 	var totalTokenLength int
 	var systemPromptIncluded bool
 
+	var images []llm.ImageData
 	// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
 	for i := len(chat.Prompts) - 1; i >= 0; i-- {
-		promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
+		prompt := chat.Prompts[i]
+		promptText, err := promptString(model, prompt, i == len(chat.Prompts)-1)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 
 		encodedTokens, err := loaded.runner.Encode(ctx, promptText)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 
 		if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
 			break // reached max context length, stop adding more prompts
 		}
 
+		for j := range prompt.Images {
+			if totalTokenLength+768 > loaded.NumCtx {
+				// stripping the tag shortens the prompt, so the token count above is now an overestimate; erring high is fine
+				prompt.Prompt = strings.ReplaceAll(prompt.Prompt, fmt.Sprintf(" [img-%d]", prompt.Images[j].ID), "")
+				continue
+			}
+
+			totalTokenLength += 768
+			images = append(images, prompt.Images[j])
+		}
+
 		totalTokenLength += len(encodedTokens)
-		systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
-		promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
+		systemPromptIncluded = systemPromptIncluded || prompt.System != ""
+		promptsToAdd = append(promptsToAdd, promptInfo{vars: prompt, tokenLen: len(encodedTokens)})
 	}
 
 	// ensure the system prompt is included, if not already
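
Each image is charged a flat 768 tokens against the context window, and an image
that no longer fits is dropped outright: its tag is stripped from the prompt text
rather than the image being truncated. A worked example under assumed numbers:
with loaded.NumCtx = 2048 and totalTokenLength = 1400 when an image is considered,
1400 + 768 = 2168 > 2048, so the tag is removed and the loop moves on. The flat
768 is presumably a deliberate overestimate of the embeddings produced per image,
erring on the safe side of the context limit.
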
@@ -1264,7 +1290,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 		var err error
 		promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 	}
 
@@ -1275,11 +1301,12 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 	for i, prompt := range promptsToAdd {
 		promptText, err := promptString(model, prompt.vars, i == 0)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 		result = promptText + result
 	}
-	return result, nil
+
+	return result, images, nil
 }
 
 // promptString applies the model template to the prompt
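
The tag-stripping fallback is easy to exercise in isolation. A minimal,
self-contained sketch (the helper name is invented for this example; it mirrors
the strings.ReplaceAll call in trimmedPrompt):

    package main

    import (
        "fmt"
        "strings"
    )

    // dropImageTag removes the placeholder for one image ID from the prompt
    // text, exactly as trimmedPrompt does when an image exceeds the budget.
    func dropImageTag(prompt string, id int) string {
        return strings.ReplaceAll(prompt, fmt.Sprintf(" [img-%d]", id), "")
    }

    func main() {
        p := "What is in these pictures? [img-0] [img-1]"
        fmt.Println(dropImageTag(p, 1)) // What is in these pictures? [img-0]
    }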