|
@@ -3,7 +3,10 @@ package server
|
|
|
import (
|
|
|
"bytes"
|
|
|
"context"
|
|
|
+ "encoding/binary"
|
|
|
+ "fmt"
|
|
|
"log/slog"
|
|
|
+ "strings"
|
|
|
|
|
|
"github.com/ollama/ollama/api"
|
|
|
"github.com/ollama/ollama/llm"
|
|
@@ -18,6 +21,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
|
|
|
// latest message and 2) system messages
|
|
|
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
|
|
|
var system []api.Message
|
|
|
+
|
|
|
// always include the last message
|
|
|
n := len(msgs) - 1
|
|
|
// in reverse, find all messages that fit into context window
|
|
@@ -39,16 +43,16 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|
|
return "", nil, err
|
|
|
}
|
|
|
|
|
|
- c := len(s)
|
|
|
+ ctxLen := len(s)
|
|
|
if m.ProjectorPaths != nil {
|
|
|
for _, m := range msgs[i:] {
|
|
|
// images are represented as 768 sized embeddings
|
|
|
// TODO: get embedding length from project metadata
|
|
|
- c += 768 * len(m.Images)
|
|
|
+ ctxLen += 768 * len(m.Images)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- if c > opts.NumCtx {
|
|
|
+ if ctxLen > opts.NumCtx {
|
|
|
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
|
|
|
break
|
|
|
} else {
|
|
@@ -56,35 +60,58 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // truncate any messages that do not fit into the context window
|
|
|
- var b bytes.Buffer
|
|
|
- if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil {
|
|
|
- return "", nil, err
|
|
|
+ currMsgIdx := n
|
|
|
+
|
|
|
+ if checkMllamaModelFamily(m) {
|
|
|
+ lastMsgIdx := len(msgs) - 1
|
|
|
+ if len(msgs[lastMsgIdx].Images) == 1 {
|
|
|
+ data, aspectRatioID, err := imageproc.Preprocess(msgs[lastMsgIdx].Images[0])
|
|
|
+ if err != nil {
|
|
|
+ return "", nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ buf := new(bytes.Buffer)
|
|
|
+ err = binary.Write(buf, binary.LittleEndian, data)
|
|
|
+ if err != nil {
|
|
|
+ return "", nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ imgData := llm.ImageData{
|
|
|
+ Data: buf.Bytes(),
|
|
|
+ AspectRatioID: aspectRatioID,
|
|
|
+ }
|
|
|
+
|
|
|
+ msgs[lastMsgIdx].Content = strings.TrimSpace("<|image|>" + msgs[lastMsgIdx].Content)
|
|
|
+ images = append(images, imgData)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- preprocess := checkMllamaModelFamily(m)
|
|
|
-
|
|
|
- for _, m := range msgs[n:] {
|
|
|
- for _, i := range m.Images {
|
|
|
- if preprocess {
|
|
|
- data, aspectRatioID, err := imageproc.Preprocess(i)
|
|
|
- if err != nil {
|
|
|
- return "", nil, err
|
|
|
- }
|
|
|
- images = append(images, llm.ImageData{
|
|
|
- ID: len(images),
|
|
|
- ImageData: data,
|
|
|
- AspectRatioID: aspectRatioID,
|
|
|
- })
|
|
|
- } else {
|
|
|
- images = append(images, llm.ImageData{
|
|
|
- ID: len(images),
|
|
|
- Data: i,
|
|
|
- })
|
|
|
+ for cnt, msg := range msgs[currMsgIdx:] {
|
|
|
+ for _, i := range msg.Images {
|
|
|
+ imgData := llm.ImageData{
|
|
|
+ ID: len(images),
|
|
|
+ Data: i,
|
|
|
}
|
|
|
+
|
|
|
+ imageTag := fmt.Sprintf("[img-%d]", imgData.ID)
|
|
|
+ prompt := msg.Content
|
|
|
+
|
|
|
+ if !strings.Contains(prompt, "[img]") {
|
|
|
+ prompt = strings.TrimSpace("[img] " + prompt)
|
|
|
+ }
|
|
|
+ prompt = strings.Replace(prompt, "[img]", imageTag, 1)
|
|
|
+ msgs[currMsgIdx+cnt].Content = prompt
|
|
|
+
|
|
|
+ images = append(images, imgData)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ // truncate any messages that do not fit into the context window
|
|
|
+ var b bytes.Buffer
|
|
|
+ if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
|
|
|
+ return "", nil, err
|
|
|
+ }
|
|
|
+
|
|
|
return b.String(), images, nil
|
|
|
}
|
|
|
|