prompt.go

package server

import (
	"bytes"
	"context"
	"log/slog"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/server/imageproc"
	"github.com/ollama/ollama/template"
)

type tokenizeFunc func(context.Context, string) ([]int, error)

// chatPrompt accepts a list of messages and returns the prompt and images that
// should be used for the next chat turn. chatPrompt truncates any messages
// that exceed the context window of the model, making sure to always include
// 1) the latest message and 2) system messages.
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
	var system []api.Message

	// always include the last message
	n := len(msgs) - 1
	// walk backwards, growing the suffix of messages until it no longer
	// fits into the context window
	for i := n - 1; i >= 0; i-- {
		// system messages that precede the suffix are always retained
		system = make([]api.Message, 0)
		for j := range i {
			if msgs[j].Role == "system" {
				system = append(system, msgs[j])
			}
		}

		var b bytes.Buffer
		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
			return "", nil, err
		}

		s, err := tokenize(ctx, b.String())
		if err != nil {
			return "", nil, err
		}

		c := len(s)
		if m.ProjectorPaths != nil {
			for _, msg := range msgs[i:] {
				// each image is counted as a 768-element embedding
				// TODO: get embedding length from projector metadata
				c += 768 * len(msg.Images)
			}
		}
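		// e.g. a suffix containing two images adds 2*768 = 1536 to the
		// token estimate, regardless of image resolution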
		if c > opts.NumCtx {
			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
			break
		} else {
			n = i
		}
	}
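
	// msgs[n:] is now the longest suffix that fits; earlier non-system
	// messages are dropped, while the system messages collected above are kept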
	// render the final prompt from the retained system messages and the
	// (possibly truncated) suffix
	var b bytes.Buffer
	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil {
		return "", nil, err
	}

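	// mllama images must be preprocessed into model-ready pixel data plus an
	// aspect-ratio ID; other multimodal models receive the raw image bytes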
	preprocess := checkMllamaModelFamily(m)
	for _, msg := range msgs[n:] {
		for _, img := range msg.Images {
			if preprocess {
				data, aspectRatioID, err := imageproc.Preprocess(img)
				if err != nil {
					return "", nil, err
				}

				images = append(images, llm.ImageData{
					ID:            len(images),
					Data:          data,
					AspectRatioID: aspectRatioID,
				})
			} else {
				images = append(images, llm.ImageData{
					ID:   len(images),
					Data: img,
				})
			}
		}
	}

	return b.String(), images, nil
}

func checkMllamaModelFamily(m *Model) bool {
	for _, arch := range m.Config.ModelFamilies {
		if arch == "mllama" {
			return true
		}
	}

	return false
}
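
// The sketch below is illustrative and not part of the original file: it wires
// chatPrompt to a trivial stand-in tokenizer so the truncation logic can be
// exercised without a running model. countRunes, chatPromptSketch, and the
// template string are assumptions made for the example; a real *Model is
// normally loaded via GetModel and tokenization is performed by the loaded
// runner.

// countRunes approximates tokenization by treating every rune as one token.
func countRunes(_ context.Context, s string) ([]int, error) {
	tokens := make([]int, 0, len(s))
	for range s {
		tokens = append(tokens, 0)
	}
	return tokens, nil
}

// chatPromptSketch renders a two-message conversation with chatPrompt.
func chatPromptSketch(ctx context.Context) error {
	tmpl, err := template.Parse("{{ range .Messages }}{{ .Role }}: {{ .Content }}\n{{ end }}")
	if err != nil {
		return err
	}

	opts := api.DefaultOptions()
	msgs := []api.Message{
		{Role: "system", Content: "You are concise."},
		{Role: "user", Content: "Hello!"},
	}

	prompt, images, err := chatPrompt(ctx, &Model{Template: tmpl}, countRunes, &opts, msgs, nil)
	if err != nil {
		return err
	}

	slog.Info("rendered", "prompt", prompt, "images", len(images))
	return nil
}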