prompt.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. package server
  2. import (
  3. "fmt"
  4. "log/slog"
  5. "strings"
  6. "text/template"
  7. "text/template/parse"
  8. "github.com/jmorganca/ollama/api"
  9. )
  10. // isResponseNode checks if the node contains .Response
  11. func isResponseNode(node *parse.ActionNode) bool {
  12. for _, cmd := range node.Pipe.Cmds {
  13. for _, arg := range cmd.Args {
  14. if fieldNode, ok := arg.(*parse.FieldNode); ok && len(fieldNode.Ident) > 0 {
  15. if fieldNode.Ident[0] == "Response" {
  16. return true
  17. }
  18. }
  19. }
  20. }
  21. return false
  22. }
  23. // formatTemplateForResponse formats the template AST to:
  24. // 1. remove all nodes after the first .Response (if generate=true)
  25. // 2. add a .Response node to the end if it doesn't exist
  26. // TODO(jmorganca): this should recursively cut the template before the first .Response
  27. func formatTemplateForResponse(tmpl *template.Template, generate bool) {
  28. var found bool
  29. for i, node := range tmpl.Tree.Root.Nodes {
  30. if actionNode, ok := node.(*parse.ActionNode); ok {
  31. if isResponseNode(actionNode) {
  32. found = true
  33. if generate {
  34. tmpl.Tree.Root.Nodes = tmpl.Tree.Root.Nodes[:i+1]
  35. break
  36. }
  37. }
  38. }
  39. }
  40. if !found {
  41. // add the response node if it doesn't exist
  42. responseFieldNode := &parse.FieldNode{NodeType: parse.NodeField, Ident: []string{"Response"}}
  43. responsePipeNode := &parse.PipeNode{NodeType: parse.NodePipe, Cmds: []*parse.CommandNode{{NodeType: parse.NodeCommand, Args: []parse.Node{responseFieldNode}}}}
  44. responseActionNode := &parse.ActionNode{NodeType: parse.NodeAction, Pipe: responsePipeNode}
  45. tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, responseActionNode)
  46. }
  47. }
  48. // Prompt renders a prompt from a template. If generate is set to true,
  49. // the response and parts of the template following it are not rendered
  50. func Prompt(tmpl, system, prompt, response string, generate bool) (string, error) {
  51. parsed, err := template.New("").Option("missingkey=zero").Parse(tmpl)
  52. if err != nil {
  53. return "", err
  54. }
  55. formatTemplateForResponse(parsed, generate)
  56. vars := map[string]any{
  57. "System": system,
  58. "Prompt": prompt,
  59. "Response": response,
  60. }
  61. var sb strings.Builder
  62. if err := parsed.Execute(&sb, vars); err != nil {
  63. return "", err
  64. }
  65. return sb.String(), nil
  66. }
  67. func countTokens(tmpl string, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
  68. rendered, err := Prompt(tmpl, system, prompt, response, false)
  69. if err != nil {
  70. return 0, err
  71. }
  72. tokens, err := encode(rendered)
  73. if err != nil {
  74. slog.Error("failed to encode prompt", "err", err)
  75. return 0, err
  76. }
  77. return len(tokens), err
  78. }
  79. // ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
  80. func ChatPrompt(tmpl string, system string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
  81. type prompt struct {
  82. System string
  83. Prompt string
  84. Response string
  85. images []int
  86. tokens int
  87. }
  88. var p prompt
  89. // Set the first system prompt to the model's system prompt
  90. if system != "" {
  91. p.System = system
  92. }
  93. // iterate through messages to build up {system,user,response} prompts
  94. var imgId int
  95. var prompts []prompt
  96. for _, msg := range messages {
  97. switch strings.ToLower(msg.Role) {
  98. case "system":
  99. if p.System != "" || p.Prompt != "" || p.Response != "" {
  100. prompts = append(prompts, p)
  101. p = prompt{}
  102. }
  103. p.System = msg.Content
  104. case "user":
  105. if p.Prompt != "" || p.Response != "" {
  106. prompts = append(prompts, p)
  107. p = prompt{}
  108. }
  109. p.Prompt = msg.Content
  110. for range msg.Images {
  111. p.Prompt += fmt.Sprintf(" [img-%d]", imgId)
  112. p.images = append(p.images, imgId)
  113. imgId += 1
  114. }
  115. case "assistant":
  116. if p.Response != "" {
  117. prompts = append(prompts, p)
  118. p = prompt{}
  119. }
  120. p.Response = msg.Content
  121. default:
  122. return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
  123. }
  124. }
  125. // add final prompt
  126. if p.System != "" || p.Prompt != "" || p.Response != "" {
  127. prompts = append(prompts, p)
  128. }
  129. // calculate token lengths for each prompt, estimating 768 tokens per images
  130. for i, p := range prompts {
  131. tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode)
  132. if err != nil {
  133. return "", err
  134. }
  135. prompts[i].tokens = tokens + len(prompts[i].images)*768
  136. }
  137. // truncate images and prompts starting from the beginning of the list
  138. // until either one prompt remains or the total tokens fits the context window
  139. // TODO (jmorganca): this doesn't account for the context window room required for the response
  140. for {
  141. var required int
  142. for _, p := range prompts {
  143. required += p.tokens
  144. }
  145. required += 1 // for bos token
  146. if required <= window {
  147. slog.Debug("prompt now fits in context window", "required", required, "window", window)
  148. break
  149. }
  150. prompt := &prompts[0]
  151. if len(prompt.images) > 1 {
  152. img := prompt.images[0]
  153. slog.Debug("prompt longer than context window, removing image", "id", img, "required", required, "window", window)
  154. prompt.images = prompt.images[1:]
  155. prompt.Prompt = strings.Replace(prompt.Prompt, fmt.Sprintf(" [img-%d]", img), "", 1)
  156. prompt.tokens -= 768
  157. continue
  158. }
  159. if len(prompts) > 1 {
  160. slog.Debug("required tokens longer than context window, removing first prompt", "prompt", prompts[0].tokens, "required", required, "window", window)
  161. system := prompt.System
  162. prompts = prompts[1:]
  163. if system != "" && prompts[0].System == "" {
  164. prompts[0].System = system
  165. tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode)
  166. if err != nil {
  167. return "", err
  168. }
  169. prompts[0].tokens = tokens + len(prompts[0].images)*768
  170. }
  171. continue
  172. }
  173. // stop truncating if there's only one prompt left
  174. break
  175. }
  176. var sb strings.Builder
  177. for i, p := range prompts {
  178. // last prompt should leave the response unrendered (for completion)
  179. rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1)
  180. if err != nil {
  181. return "", err
  182. }
  183. sb.WriteString(rendered)
  184. }
  185. return sb.String(), nil
  186. }