// prompt.go

package server

import (
	"bytes"
	"context"
	"encoding/binary"
	"fmt"
	"log/slog"
	"strings"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/server/imageproc"
	"github.com/ollama/ollama/template"
)

// tokenizeFunc converts a prompt string into token IDs so its length can be
// measured against the model's context window.
type tokenizeFunc func(context.Context, string) ([]int, error)

// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages.
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
	var system []api.Message

	// always include the last message
	n := len(msgs) - 1
	// in reverse, find all messages that fit into the context window
	for i := n - 1; i >= 0; i-- {
		system = make([]api.Message, 0)
		for j := range i {
			if msgs[j].Role == "system" {
				system = append(system, msgs[j])
			}
		}

		var b bytes.Buffer
		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
			return "", nil, err
		}

		s, err := tokenize(ctx, b.String())
		if err != nil {
			return "", nil, err
		}

		ctxLen := len(s)
		if m.ProjectorPaths != nil {
			for _, m := range msgs[i:] {
				// each image is represented as a 768-token embedding
				// TODO: get embedding length from projector metadata
				ctxLen += 768 * len(m.Images)
			}
		}

		if ctxLen > opts.NumCtx {
			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
			break
		} else {
			n = i
		}
	}
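
	// For example, given [system, user1, assistant1, user2] and a window that
	// only fits the last message plus the system prompt, the loop above leaves
	// n at the index of user2, so user1 and assistant1 are dropped while the
	// system message is kept.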

	currMsgIdx := n

	if checkMllamaModelFamily(m) {
		lastMsgIdx := len(msgs) - 1
		if len(msgs[lastMsgIdx].Images) == 1 {
			// mllama takes a single image: preprocess it, serialize the
			// result little-endian, and tag the content with <|image|>
			data, aspectRatioID, err := imageproc.Preprocess(msgs[lastMsgIdx].Images[0])
			if err != nil {
				return "", nil, err
			}

			buf := new(bytes.Buffer)
			err = binary.Write(buf, binary.LittleEndian, data)
			if err != nil {
				return "", nil, err
			}

			imgData := llm.ImageData{
				Data:          buf.Bytes(),
				AspectRatioID: aspectRatioID,
			}

			msgs[lastMsgIdx].Content = strings.TrimSpace("<|image|>" + msgs[lastMsgIdx].Content)
			images = append(images, imgData)
		}
	} else {
		for cnt, msg := range msgs[currMsgIdx:] {
			for _, i := range msg.Images {
				imgData := llm.ImageData{
					ID:   len(images),
					Data: i,
				}

				imageTag := fmt.Sprintf("[img-%d]", imgData.ID)
				prompt := msg.Content
				if !strings.Contains(prompt, "[img]") {
					prompt = strings.TrimSpace("[img] " + prompt)
				}

				prompt = strings.Replace(prompt, "[img]", imageTag, 1)
				msgs[currMsgIdx+cnt].Content = prompt

				images = append(images, imgData)
			}
		}
	}
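
	// Illustrative only: in the non-mllama path above, a message whose content
	// is "describe this" and which carries one image ends up as
	// "[img-0] describe this", with the raw bytes appended to images under ID 0.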

	// render the prompt using only the system messages and the messages that
	// fit into the context window
	var b bytes.Buffer
	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
		return "", nil, err
	}

	return b.String(), images, nil
}
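
// A minimal usage sketch (illustrative only; model, opts, msgs, and tools are
// assumed to come from the caller, and runnerTokenize is a hypothetical
// tokenizeFunc backed by the loaded runner):
//
//	prompt, images, err := chatPrompt(ctx, model, runnerTokenize, opts, msgs, tools)
//	if err != nil {
//		return err
//	}
//	// prompt is the rendered template for the messages that fit in
//	// opts.NumCtx; images holds any image payloads the prompt references.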

// checkMllamaModelFamily reports whether any of the model's families is
// "mllama", which selects the image handling used above.
func checkMllamaModelFamily(m *Model) bool {
	for _, arch := range m.Config.ModelFamilies {
		if arch == "mllama" {
			return true
		}
	}

	return false
}