template.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. package template
  2. import (
  3. "bytes"
  4. "embed"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "math"
  10. "slices"
  11. "strings"
  12. "sync"
  13. "text/template"
  14. "text/template/parse"
  15. "github.com/agnivade/levenshtein"
  16. "github.com/ollama/ollama/api"
  17. "golang.org/x/exp/maps"
  18. )
  19. //go:embed index.json
  20. var indexBytes []byte
  21. //go:embed *.gotmpl
  22. var templatesFS embed.FS
  23. var templatesOnce = sync.OnceValues(func() ([]*named, error) {
  24. var templates []*named
  25. if err := json.Unmarshal(indexBytes, &templates); err != nil {
  26. return nil, err
  27. }
  28. for _, t := range templates {
  29. bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
  30. if err != nil {
  31. return nil, err
  32. }
  33. // normalize line endings
  34. t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
  35. }
  36. return templates, nil
  37. })
  38. type named struct {
  39. Name string `json:"name"`
  40. Template string `json:"template"`
  41. Bytes []byte
  42. }
  43. func (t named) Reader() io.Reader {
  44. return bytes.NewReader(t.Bytes)
  45. }
  46. func Named(s string) (*named, error) {
  47. templates, err := templatesOnce()
  48. if err != nil {
  49. return nil, err
  50. }
  51. var template *named
  52. score := math.MaxInt
  53. for _, t := range templates {
  54. if s := levenshtein.ComputeDistance(s, t.Template); s < score {
  55. score = s
  56. template = t
  57. }
  58. }
  59. if score < 100 {
  60. return template, nil
  61. }
  62. return nil, errors.New("no matching template found")
  63. }
  64. var DefaultTemplate, _ = Parse("{{ .Prompt }}")
  65. type Template struct {
  66. *template.Template
  67. raw string
  68. }
  69. // response is a template node that can be added to templates that don't already have one
  70. var response = parse.ActionNode{
  71. NodeType: parse.NodeAction,
  72. Pipe: &parse.PipeNode{
  73. NodeType: parse.NodePipe,
  74. Cmds: []*parse.CommandNode{
  75. {
  76. NodeType: parse.NodeCommand,
  77. Args: []parse.Node{
  78. &parse.FieldNode{
  79. NodeType: parse.NodeField,
  80. Ident: []string{"Response"},
  81. },
  82. },
  83. },
  84. },
  85. },
  86. }
  87. func Parse(s string) (*Template, error) {
  88. tmpl := template.New("").Option("missingkey=zero")
  89. tmpl, err := tmpl.Parse(s)
  90. if err != nil {
  91. return nil, err
  92. }
  93. t := Template{Template: tmpl, raw: s}
  94. if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
  95. // touch up the template and append {{ .Response }}
  96. tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
  97. }
  98. return &t, nil
  99. }
  100. func (t *Template) String() string {
  101. return t.raw
  102. }
  103. func (t *Template) Vars() []string {
  104. var vars []string
  105. for _, tt := range t.Templates() {
  106. for _, n := range tt.Root.Nodes {
  107. vars = append(vars, parseNode(n)...)
  108. }
  109. }
  110. set := make(map[string]struct{})
  111. for _, n := range vars {
  112. set[strings.ToLower(n)] = struct{}{}
  113. }
  114. vars = maps.Keys(set)
  115. slices.Sort(vars)
  116. return vars
  117. }
  118. type Values struct {
  119. Messages []api.Message
  120. }
  121. func (t *Template) Execute(w io.Writer, v Values) error {
  122. system, collated := collate(v.Messages)
  123. if slices.Contains(t.Vars(), "messages") {
  124. return t.Template.Execute(w, map[string]any{
  125. "System": system,
  126. "Messages": collated,
  127. })
  128. }
  129. var b bytes.Buffer
  130. var prompt, response string
  131. for i, m := range collated {
  132. if m.Role == "user" {
  133. prompt = m.Content
  134. } else {
  135. response = m.Content
  136. }
  137. if i != len(collated)-1 && prompt != "" && response != "" {
  138. if err := t.Template.Execute(&b, map[string]any{
  139. "System": "",
  140. "Prompt": prompt,
  141. "Response": response,
  142. }); err != nil {
  143. return err
  144. }
  145. prompt = ""
  146. response = ""
  147. }
  148. }
  149. var cut bool
  150. tree := t.Template.Copy()
  151. // for the last message, cut everything after "{{ .Response }}"
  152. tree.Root.Nodes = slices.DeleteFunc(tree.Root.Nodes, func(n parse.Node) bool {
  153. if slices.Contains(parseNode(n), "Response") {
  154. cut = true
  155. }
  156. return cut
  157. })
  158. if err := template.Must(template.New("").AddParseTree("", tree)).Execute(&b, map[string]any{
  159. "System": system,
  160. "Prompt": prompt,
  161. }); err != nil {
  162. return err
  163. }
  164. _, err := io.Copy(w, &b)
  165. return err
  166. }
  167. type messages []*api.Message
  168. // collate messages based on role. consecutive messages of the same role are merged
  169. // into a single message. collate also pulls out and merges messages with Role == "system"
  170. // which are templated separately. As a side effect, it mangles message content adding image
  171. // tags ([img-%d]) as needed
  172. func collate(msgs []api.Message) (system string, collated messages) {
  173. var n int
  174. for i := range msgs {
  175. msg := msgs[i]
  176. if msg.Role == "system" {
  177. if system != "" {
  178. system += "\n\n"
  179. }
  180. system += msg.Content
  181. continue
  182. }
  183. for range msg.Images {
  184. imageTag := fmt.Sprintf("[img-%d]", n)
  185. if !strings.Contains(msg.Content, "[img]") {
  186. msg.Content = strings.TrimSpace("[img] " + msg.Content)
  187. }
  188. msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1)
  189. n++
  190. }
  191. if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
  192. collated[len(collated)-1].Content += "\n\n" + msg.Content
  193. } else {
  194. collated = append(collated, &msg)
  195. }
  196. }
  197. return
  198. }
  199. func parseNode(n parse.Node) []string {
  200. switch n := n.(type) {
  201. case *parse.ActionNode:
  202. return parseNode(n.Pipe)
  203. case *parse.IfNode:
  204. names := parseNode(n.Pipe)
  205. names = append(names, parseNode(n.List)...)
  206. if n.ElseList != nil {
  207. names = append(names, parseNode(n.ElseList)...)
  208. }
  209. return names
  210. case *parse.RangeNode:
  211. names := parseNode(n.Pipe)
  212. names = append(names, parseNode(n.List)...)
  213. if n.ElseList != nil {
  214. names = append(names, parseNode(n.ElseList)...)
  215. }
  216. return names
  217. case *parse.WithNode:
  218. names := parseNode(n.Pipe)
  219. names = append(names, parseNode(n.List)...)
  220. if n.ElseList != nil {
  221. names = append(names, parseNode(n.ElseList)...)
  222. }
  223. return names
  224. case *parse.PipeNode:
  225. var names []string
  226. for _, c := range n.Cmds {
  227. for _, a := range c.Args {
  228. names = append(names, parseNode(a)...)
  229. }
  230. }
  231. return names
  232. case *parse.ListNode:
  233. var names []string
  234. for _, n := range n.Nodes {
  235. names = append(names, parseNode(n)...)
  236. }
  237. return names
  238. case *parse.FieldNode:
  239. return n.Ident
  240. case *parse.TemplateNode:
  241. return parseNode(n.Pipe)
  242. }
  243. return nil
  244. }