prompt.go

package server

import (
	"bytes"
	"context"
	"log/slog"
	"slices"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/template"
)

// tokenizeFunc converts a rendered prompt into model token IDs so that its
// length can be measured against the context window.
type tokenizeFunc func(context.Context, string) ([]int, error)
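// Illustrative sketch, not part of the original file: any closure with the
// matching signature satisfies tokenizeFunc. Assuming a hypothetical runner
// value exposing Tokenize(ctx context.Context, s string) ([]int, error), an
// adapter would be:
//
//	tokenize := func(ctx context.Context, s string) ([]int, error) {
//		return runner.Tokenize(ctx, s)
//	}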
// chatPrompt accepts a list of messages and returns the prompt and images that
// should be used for the next chat turn. chatPrompt truncates any messages
// that exceed the model's context window, making sure to always include
// 1) the latest message and 2) system messages.
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
	// pull out any system messages, which should always be included in the prompt
	var system []api.Message
	msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
		if m.Role == "system" {
			system = append(system, m)
			return true
		}

		return false
	})

	if len(system) == 0 && m.System != "" {
		// the request carried no system message, so fall back to the model's own
		system = append(system, api.Message{Role: "system", Content: m.System})
	}
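	// the system messages collected above are prepended to every candidate
	// window below, so they are never subject to truncation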
	// always include the last message
	n := len(msgs) - 1
	// walking backwards, find the longest suffix of messages that fits into the context window
	for i := n - 1; i >= 0; i-- {
		var b bytes.Buffer
		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
			return "", nil, err
		}

		s, err := tokenize(ctx, b.String())
		if err != nil {
			return "", nil, err
		}

		c := len(s)
		if m.ProjectorPaths != nil {
			for _, m := range msgs[i:] {
				// each image is represented as a 768-token embedding
				// TODO: get the embedding length from the projector metadata
				c += 768 * len(m.Images)
			}
		}

		if c > opts.NumCtx {
			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
			break
		} else {
			n = i
		}
	}
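	// at this point msgs[n:] is the longest suffix that fit; note that the
	// latest message is always retained, even if it alone exceeds the window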
	// render the prompt one final time with only the messages that fit
	var b bytes.Buffer
	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
		return "", nil, err
	}

	// collect image data from the retained messages, assigning sequential IDs
	for _, m := range msgs[n:] {
		for _, i := range m.Images {
			images = append(images, llm.ImageData{
				ID:   len(images),
				Data: i,
			})
		}
	}

	return b.String(), images, nil
}
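// Illustrative usage sketch, not part of the original file: a stub tokenizer
// that counts one token per whitespace-separated word (via strings.Fields) is
// enough to exercise the truncation logic, assuming model and msgs are
// already populated:
//
//	tokenize := func(ctx context.Context, s string) ([]int, error) {
//		return make([]int, len(strings.Fields(s))), nil
//	}
//
//	opts := &api.Options{}
//	opts.NumCtx = 4096
//	prompt, images, err := chatPrompt(ctx, model, tokenize, opts, msgs)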