// prompt.go
  1. package server
import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"log/slog"
	"slices"
	"strings"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/server/imageproc"
	"github.com/ollama/ollama/template"
)
// tokenizeFunc converts a rendered prompt string into model token IDs; it is
// injected so chatPrompt can measure context length without owning a runner.
type tokenizeFunc func(context.Context, string) ([]int, error)

// errTooManyImages is returned when a single message carries more than one
// image for a vision model family (mllama) that supports only one per message.
var errTooManyImages = errors.New("vision model only supports a single image per message")
  17. // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
  18. // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
  19. // latest message and 2) system messages
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
	var system []api.Message

	// mllama vision models only accept a single image per message; that
	// constraint is enforced inside the truncation loop below.
	isMllama := checkMllamaModelFamily(m)

	// n tracks the oldest message index known to fit in the context window.
	n := len(msgs) - 1
	// in reverse, find all messages that fit into context window
	for i := n; i >= 0; i-- {
		if isMllama && len(msgs[i].Images) > 1 {
			return "", nil, errTooManyImages
		}

		// always include the last message
		if i == n {
			continue
		}

		// Collect every system message that precedes the candidate window so
		// its tokens count toward the budget alongside msgs[i:].
		system = make([]api.Message, 0)
		for j := range i {
			if msgs[j].Role == "system" {
				system = append(system, msgs[j])
			}
		}

		// Render the candidate window through the model template and tokenize
		// it to measure its real context cost.
		var b bytes.Buffer
		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
			return "", nil, err
		}

		s, err := tokenize(ctx, b.String())
		if err != nil {
			return "", nil, err
		}

		ctxLen := len(s)
		// Projector-based (vision) models spend context on image embeddings
		// that the text tokenizer cannot see; add a fixed per-image cost.
		if m.ProjectorPaths != nil {
			for _, m := range msgs[i:] {
				// images are represented as 768 sized embeddings
				// TODO: get embedding length from project metadata
				ctxLen += 768 * len(m.Images)
			}
		}

		if ctxLen > opts.NumCtx {
			// NOTE(review): if msgs[i] itself is a system message it is dropped
			// here — it is neither in `system` (built from msgs[:i]) nor in
			// msgs[n:] — which appears to contradict the "always include system
			// messages" guarantee above; confirm intended behavior.
			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
			break
		} else {
			n = i
		}
	}

	// currMsgIdx is the first message included in the final prompt.
	currMsgIdx := n

	if isMllama {
		// Only the most recent image-bearing message in the window is
		// preprocessed and sent; the loop breaks after the first hit.
		lastMsgIdx := len(msgs) - 1
		for i := lastMsgIdx; i >= currMsgIdx; i-- {
			if len(msgs[i].Images) > 0 {
				data, aspectRatioID, err := imageproc.Preprocess(msgs[i].Images[0])
				if err != nil {
					return "", nil, err
				}

				// Serialize the preprocessed float data to little-endian bytes
				// for transport to the runner.
				buf := new(bytes.Buffer)
				err = binary.Write(buf, binary.LittleEndian, data)
				if err != nil {
					return "", nil, err
				}

				imgData := llm.ImageData{
					Data:          buf.Bytes(),
					AspectRatioID: aspectRatioID,
				}

				// Prefix the message content with the mllama image marker.
				msgs[i].Content = strings.TrimSpace("<|image|>" + msgs[i].Content)
				images = append(images, imgData)
				break
			}
		}
	} else {
		// Non-mllama path: assign each image a sequential ID and reference it
		// in the message text as [img-N], replacing an explicit [img]
		// placeholder when present, otherwise prefixing the tag.
		for cnt, msg := range msgs[currMsgIdx:] {
			prefix := ""
			prompt := msg.Content
			for _, i := range msg.Images {
				imgData := llm.ImageData{
					ID:   len(images),
					Data: i,
				}

				imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
				if !strings.Contains(prompt, "[img]") {
					prefix += imgTag
				} else {
					prompt = strings.Replace(prompt, "[img]", imgTag, 1)
				}

				images = append(images, imgData)
			}
			msgs[currMsgIdx+cnt].Content = strings.TrimSpace(prefix + " " + prompt)
		}
	}

	// truncate any messages that do not fit into the context window
	var b bytes.Buffer
	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
		return "", nil, err
	}

	return b.String(), images, nil
}
  112. func checkMllamaModelFamily(m *Model) bool {
  113. for _, arch := range m.Config.ModelFamilies {
  114. if arch == "mllama" {
  115. return true
  116. }
  117. }
  118. return false
  119. }