prompt.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. package server
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/binary"
  6. "errors"
  7. "fmt"
  8. "log/slog"
  9. "strings"
  10. "github.com/ollama/ollama/api"
  11. "github.com/ollama/ollama/llm"
  12. "github.com/ollama/ollama/model/models/mllama"
  13. "github.com/ollama/ollama/template"
  14. )
// tokenizeFunc converts a string into the model's token IDs. chatPrompt uses
// it to measure how many context-window tokens a candidate prompt consumes.
type tokenizeFunc func(context.Context, string) ([]int, error)

// errTooManyImages is returned when a message targeted at a vision model that
// supports only one image per message (mllama) contains more than one image.
var errTooManyImages = errors.New("vision model only supports a single image per message")
  17. // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
  18. // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
  19. // latest message and 2) system messages
  20. func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
  21. var system []api.Message
  22. isMllama := checkMllamaModelFamily(m)
  23. var imageNumTokens int
  24. // TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
  25. if isMllama {
  26. // Our mllama implementation packs all of the embeddings into a single token
  27. imageNumTokens = 1
  28. } else {
  29. // Clip images are represented as 768 tokens, each an embedding
  30. imageNumTokens = 768
  31. }
  32. n := len(msgs) - 1
  33. // in reverse, find all messages that fit into context window
  34. for i := n; i >= 0; i-- {
  35. if isMllama && len(msgs[i].Images) > 1 {
  36. return "", nil, errTooManyImages
  37. }
  38. // always include the last message
  39. if i == n {
  40. continue
  41. }
  42. system = make([]api.Message, 0)
  43. for j := range i {
  44. if msgs[j].Role == "system" {
  45. system = append(system, msgs[j])
  46. }
  47. }
  48. var b bytes.Buffer
  49. if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
  50. return "", nil, err
  51. }
  52. s, err := tokenize(ctx, b.String())
  53. if err != nil {
  54. return "", nil, err
  55. }
  56. ctxLen := len(s)
  57. if m.ProjectorPaths != nil {
  58. for _, m := range msgs[i:] {
  59. ctxLen += imageNumTokens * len(m.Images)
  60. }
  61. }
  62. if ctxLen > opts.NumCtx {
  63. slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
  64. break
  65. } else {
  66. n = i
  67. }
  68. }
  69. currMsgIdx := n
  70. for cnt, msg := range msgs[currMsgIdx:] {
  71. prefix := ""
  72. imgPrompt := ""
  73. prompt := msg.Content
  74. for _, i := range msg.Images {
  75. var imgData llm.ImageData
  76. if isMllama {
  77. if len(m.ProjectorPaths) == 0 {
  78. imgData = llm.ImageData{
  79. ID: len(images),
  80. Data: i,
  81. }
  82. } else {
  83. data, opts, err := mllama.Preprocess(bytes.NewReader(i))
  84. if err != nil {
  85. return "", nil, err
  86. }
  87. buf := new(bytes.Buffer)
  88. err = binary.Write(buf, binary.LittleEndian, data)
  89. if err != nil {
  90. return "", nil, err
  91. }
  92. ar, ok := opts["aspectRatioIndex"].(int)
  93. if !ok {
  94. return "", nil, fmt.Errorf("missing aspect ratio for image")
  95. }
  96. imgData = llm.ImageData{
  97. ID: len(images),
  98. Data: buf.Bytes(),
  99. AspectRatioID: ar,
  100. }
  101. }
  102. imgPrompt = "<|image|>"
  103. } else {
  104. imgData = llm.ImageData{
  105. ID: len(images),
  106. Data: i,
  107. }
  108. }
  109. imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
  110. if !strings.Contains(prompt, "[img]") {
  111. prefix += imgTag
  112. } else {
  113. prompt = strings.Replace(prompt, "[img]", imgTag, 1)
  114. }
  115. images = append(images, imgData)
  116. }
  117. msgs[currMsgIdx+cnt].Content = prefix + imgPrompt + prompt
  118. }
  119. // truncate any messages that do not fit into the context window
  120. var b bytes.Buffer
  121. if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
  122. return "", nil, err
  123. }
  124. return b.String(), images, nil
  125. }
  126. func checkMllamaModelFamily(m *Model) bool {
  127. for _, arch := range m.Config.ModelFamilies {
  128. if arch == "mllama" {
  129. return true
  130. }
  131. }
  132. return false
  133. }