prompt.go

package server

import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"log/slog"
	"strings"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/model/models/mllama"
	"github.com/ollama/ollama/template"
)
type tokenizeFunc func(context.Context, string) ([]int, error)

var errTooManyImages = errors.New("vision model only supports a single image per message")

// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
	var system []api.Message

	isMllama := checkMllamaModelFamily(m)

	var imageNumTokens int
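	// imageNumTokens estimates how many context tokens each attached image
	// consumes; it is only used for the truncation accounting below.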
	// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
	if isMllama {
		// Our mllama implementation packs all of the embeddings into a single token
		imageNumTokens = 1
	} else {
		// Clip images are represented as 768 tokens, each an embedding
		imageNumTokens = 768
	}
	n := len(msgs) - 1
	// in reverse, find all messages that fit into context window
	for i := n; i >= 0; i-- {
		if isMllama && len(msgs[i].Images) > 1 {
			return "", nil, errTooManyImages
		}

		// always include the last message
		if i == n {
			continue
		}
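		// collect every system message that precedes the candidate cutoff so
		// system prompts survive truncation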
		system = make([]api.Message, 0)
		for j := range i {
			if msgs[j].Role == "system" {
				system = append(system, msgs[j])
			}
		}
		var b bytes.Buffer
		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
			return "", nil, err
		}

		s, err := tokenize(ctx, b.String())
		if err != nil {
			return "", nil, err
		}
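		// estimate the context cost of any attached images when a projector
		// (multimodal model) is loaded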
		ctxLen := len(s)
		if m.ProjectorPaths != nil {
			for _, m := range msgs[i:] {
				ctxLen += imageNumTokens * len(m.Images)
			}
		}

		if ctxLen > opts.NumCtx {
			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
			break
		} else {
			n = i
		}
	}
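	// currMsgIdx marks the oldest message that still fits in the context window;
	// older messages, apart from the collected system messages, are dropped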
	currMsgIdx := n

	for cnt, msg := range msgs[currMsgIdx:] {
		prefix := ""
		imgPrompt := ""
		prompt := msg.Content

		for _, i := range msg.Images {
			var imgData llm.ImageData

			if isMllama {
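				// the new engine accepts the raw image bytes; the legacy path
				// preprocesses the image here before handing it to the runner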
				if envconfig.NewEngine() {
					imgData = llm.ImageData{
						ID:   len(images),
						Data: i,
					}
				} else {
					data, opts, err := mllama.Preprocess(bytes.NewReader(i))
					if err != nil {
						return "", nil, err
					}

					buf := new(bytes.Buffer)
					err = binary.Write(buf, binary.LittleEndian, data)
					if err != nil {
						return "", nil, err
					}

					ar, ok := opts["aspectRatioIndex"].(int)
					if !ok {
						return "", nil, fmt.Errorf("missing aspect ratio for image")
					}

					imgData = llm.ImageData{
						ID:            len(images),
						Data:          buf.Bytes(),
						AspectRatioID: ar,
					}
				}
				imgPrompt = "<|image|>"
			} else {
				imgData = llm.ImageData{
					ID:   len(images),
					Data: i,
				}
			}
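			// messages may reference an image with an explicit "[img]" placeholder;
			// replace the first occurrence with an indexed tag, otherwise prepend the tag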
			imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
			if !strings.Contains(prompt, "[img]") {
				prefix += imgTag
			} else {
				prompt = strings.Replace(prompt, "[img]", imgTag, 1)
			}

			images = append(images, imgData)
		}
		msgs[currMsgIdx+cnt].Content = prefix + imgPrompt + prompt
	}
	// truncate any messages that do not fit into the context window
	var b bytes.Buffer
	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
		return "", nil, err
	}

	return b.String(), images, nil
}
func checkMllamaModelFamily(m *Model) bool {
	for _, arch := range m.Config.ModelFamilies {
		if arch == "mllama" {
			return true
		}
	}

	return false
}
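// Example call site (illustrative sketch, not part of the original file): a
// handler would pass its runner's tokenizer together with the request options,
// e.g.
//
//	prompt, images, err := chatPrompt(ctx, m, runner.Tokenize, opts, msgs, tools)
//
// where runner.Tokenize is any function satisfying tokenizeFunc; runner, opts,
// msgs and tools are placeholder names.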