template.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. package template
  2. import (
  3. "bytes"
  4. "embed"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "math"
  10. "slices"
  11. "strings"
  12. "sync"
  13. "text/template"
  14. "text/template/parse"
  15. "time"
  16. "github.com/agnivade/levenshtein"
  17. "github.com/ollama/ollama/api"
  18. "golang.org/x/exp/maps"
  19. )
  20. //go:embed index.json
  21. var indexBytes []byte
  22. //go:embed *.gotmpl
  23. var templatesFS embed.FS
  24. var templatesOnce = sync.OnceValues(func() ([]*named, error) {
  25. var templates []*named
  26. if err := json.Unmarshal(indexBytes, &templates); err != nil {
  27. return nil, err
  28. }
  29. for _, t := range templates {
  30. bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
  31. if err != nil {
  32. return nil, err
  33. }
  34. // normalize line endings
  35. t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
  36. }
  37. return templates, nil
  38. })
  // named is one entry of index.json: the template's file name, the canonical
// template text that Named fuzzy-matches against, and (once loaded by
// templatesOnce) the contents of the corresponding .gotmpl file.
type named struct {
	Name     string `json:"name"`
	Template string `json:"template"`
	Bytes    []byte
}

// Reader returns a new reader positioned at the start of the loaded template
// file contents; each call yields an independent reader.
func (n named) Reader() io.Reader { return bytes.NewReader(n.Bytes) }
  47. func Named(s string) (*named, error) {
  48. templates, err := templatesOnce()
  49. if err != nil {
  50. return nil, err
  51. }
  52. var template *named
  53. score := math.MaxInt
  54. for _, t := range templates {
  55. if s := levenshtein.ComputeDistance(s, t.Template); s < score {
  56. score = s
  57. template = t
  58. }
  59. }
  60. if score < 100 {
  61. return template, nil
  62. }
  63. return nil, errors.New("no matching template found")
  64. }
  65. var DefaultTemplate, _ = Parse("{{ .Prompt }}")
  66. type Template struct {
  67. *template.Template
  68. raw string
  69. }
  // response is the parse tree for a bare {{ .Response }} action. Parse appends
// it to templates that reference neither .Messages nor .Response.
var response = func() parse.ActionNode {
	field := &parse.FieldNode{
		NodeType: parse.NodeField,
		Ident:    []string{"Response"},
	}
	command := &parse.CommandNode{
		NodeType: parse.NodeCommand,
		Args:     []parse.Node{field},
	}
	pipe := &parse.PipeNode{
		NodeType: parse.NodePipe,
		Cmds:     []*parse.CommandNode{command},
	}
	return parse.ActionNode{
		NodeType: parse.NodeAction,
		Pipe:     pipe,
	}
}()
  // funcs are the helper functions installed on every template built by Parse:
//
//   - json marshals its argument with encoding/json defaults (which
//     HTML-escape <, > and &); if marshalling fails it yields "".
//   - now returns the current local time formatted as "2006-01-02 15:04:05".
var funcs = template.FuncMap{
	"json": func(v any) string {
		// A marshal error deliberately renders as an empty string rather than
		// aborting template execution.
		encoded, _ := json.Marshal(v)
		return string(encoded)
	},
	"now": func() string {
		const layout = "2006-01-02 15:04:05"
		return time.Now().Format(layout)
	},
}
  97. func Parse(s string) (*Template, error) {
  98. tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
  99. tmpl, err := tmpl.Parse(s)
  100. if err != nil {
  101. return nil, err
  102. }
  103. t := Template{Template: tmpl, raw: s}
  104. if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
  105. // touch up the template and append {{ .Response }}
  106. tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
  107. }
  108. return &t, nil
  109. }
  110. func (t *Template) String() string {
  111. return t.raw
  112. }
  113. func (t *Template) Vars() []string {
  114. var vars []string
  115. for _, tt := range t.Templates() {
  116. for _, n := range tt.Root.Nodes {
  117. vars = append(vars, Identifiers(n)...)
  118. }
  119. }
  120. set := make(map[string]struct{})
  121. for _, n := range vars {
  122. set[strings.ToLower(n)] = struct{}{}
  123. }
  124. vars = maps.Keys(set)
  125. slices.Sort(vars)
  126. return vars
  127. }
  128. type Values struct {
  129. Messages []api.Message
  130. Tools []api.Tool
  131. // forceLegacy is a flag used to test compatibility with legacy templates
  132. forceLegacy bool
  133. }
  134. func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
  135. var walk func(parse.Node) parse.Node
  136. walk = func(n parse.Node) parse.Node {
  137. if fn(n) {
  138. return n
  139. }
  140. switch t := n.(type) {
  141. case *parse.ListNode:
  142. for _, c := range t.Nodes {
  143. if n := walk(c); n != nil {
  144. return n
  145. }
  146. }
  147. case *parse.BranchNode:
  148. for _, n := range []*parse.ListNode{t.List, t.ElseList} {
  149. if n != nil {
  150. if n := walk(n); n != nil {
  151. return n
  152. }
  153. }
  154. }
  155. case *parse.IfNode:
  156. return walk(&t.BranchNode)
  157. case *parse.WithNode:
  158. return walk(&t.BranchNode)
  159. case *parse.RangeNode:
  160. return walk(&t.BranchNode)
  161. }
  162. return nil
  163. }
  164. if n := walk(t.Tree.Root); n != nil {
  165. return (&template.Template{
  166. Tree: &parse.Tree{
  167. Root: &parse.ListNode{
  168. Nodes: []parse.Node{n},
  169. },
  170. },
  171. }).Funcs(funcs)
  172. }
  173. return nil
  174. }
  175. func (t *Template) Execute(w io.Writer, v Values) error {
  176. system, messages := collate(v.Messages)
  177. if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
  178. return t.Template.Execute(w, map[string]any{
  179. "System": system,
  180. "Messages": messages,
  181. "Tools": v.Tools,
  182. })
  183. }
  184. system = ""
  185. var b bytes.Buffer
  186. var prompt, response string
  187. for _, m := range messages {
  188. execute := func() error {
  189. if err := t.Template.Execute(&b, map[string]any{
  190. "System": system,
  191. "Prompt": prompt,
  192. "Response": response,
  193. }); err != nil {
  194. return err
  195. }
  196. system = ""
  197. prompt = ""
  198. response = ""
  199. return nil
  200. }
  201. switch m.Role {
  202. case "system":
  203. if prompt != "" || response != "" {
  204. if err := execute(); err != nil {
  205. return err
  206. }
  207. }
  208. system = m.Content
  209. case "user":
  210. if response != "" {
  211. if err := execute(); err != nil {
  212. return err
  213. }
  214. }
  215. prompt = m.Content
  216. case "assistant":
  217. response = m.Content
  218. }
  219. }
  220. var cut bool
  221. nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
  222. if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
  223. cut = true
  224. }
  225. return cut
  226. })
  227. tree := parse.Tree{Root: nodes.(*parse.ListNode)}
  228. if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
  229. "System": system,
  230. "Prompt": prompt,
  231. }); err != nil {
  232. return err
  233. }
  234. _, err := io.Copy(w, &b)
  235. return err
  236. }
  237. // collate messages based on role. consecutive messages of the same role are merged
  238. // into a single message. collate also collects and returns all system messages.
  239. // collate mutates message content adding image tags ([img-%d]) as needed
  240. func collate(msgs []api.Message) (string, []*api.Message) {
  241. var n int
  242. var system []string
  243. var collated []*api.Message
  244. for i := range msgs {
  245. msg := msgs[i]
  246. for range msg.Images {
  247. imageTag := fmt.Sprintf("[img-%d]", n)
  248. if !strings.Contains(msg.Content, "[img]") {
  249. msg.Content = strings.TrimSpace("[img] " + msg.Content)
  250. }
  251. msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1)
  252. n++
  253. }
  254. if msg.Role == "system" {
  255. system = append(system, msg.Content)
  256. }
  257. if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
  258. collated[len(collated)-1].Content += "\n\n" + msg.Content
  259. } else {
  260. collated = append(collated, &msg)
  261. }
  262. }
  263. return strings.Join(system, "\n\n"), collated
  264. }
  // Identifiers walks the node tree rooted at n and returns every field,
// variable and chained-field identifier it finds, in traversal order and
// including duplicates:
//   - a field such as .Messages.Content contributes each segment ("Messages",
//     "Content");
//   - a variable such as $.System contributes "$" and "System";
//   - a chain such as (index .Messages 0).Content contributes the identifiers
//     inside the parenthesized operand followed by "Content".
//
// Function names and literals contribute nothing. So do nil nodes, such as the
// absent pipeline of {{ template "name" }}.
func Identifiers(n parse.Node) []string {
	switch n := n.(type) {
	case *parse.ListNode:
		if n == nil {
			return nil
		}
		var names []string
		for _, c := range n.Nodes {
			names = append(names, Identifiers(c)...)
		}
		return names
	case *parse.TemplateNode:
		// Pipe is nil for {{ template "name" }} with no argument; the
		// *parse.PipeNode case handles that instead of dereferencing it.
		return Identifiers(n.Pipe)
	case *parse.ActionNode:
		return Identifiers(n.Pipe)
	case *parse.BranchNode:
		names := Identifiers(n.Pipe)
		for _, l := range []*parse.ListNode{n.List, n.ElseList} {
			if l != nil {
				names = append(names, Identifiers(l)...)
			}
		}
		return names
	case *parse.IfNode:
		return Identifiers(&n.BranchNode)
	case *parse.RangeNode:
		return Identifiers(&n.BranchNode)
	case *parse.WithNode:
		return Identifiers(&n.BranchNode)
	case *parse.PipeNode:
		if n == nil {
			return nil
		}
		var names []string
		for _, c := range n.Cmds {
			for _, a := range c.Args {
				names = append(names, Identifiers(a)...)
			}
		}
		return names
	case *parse.ChainNode:
		// Build a fresh slice so appending n.Field can never write into a
		// slice owned by another node.
		var names []string
		names = append(names, Identifiers(n.Node)...)
		return append(names, n.Field...)
	case *parse.FieldNode:
		return n.Ident
	case *parse.VariableNode:
		return n.Ident
	}
	return nil
}
  // deleteNode walks the tree rooted at n and removes every node for which fn
// returns true. It is currently used to strip the {{ .Response }} node (and,
// with a stateful fn, everything visited after it) from legacy templates.
//
// fn is called on a node before its children; when it returns true the node
// is dropped and its children are not visited. Removal propagates as follows:
//   - a list node keeps only its surviving children;
//   - an action is removed when its pipeline is removed;
//   - a pipeline is removed when any of its commands loses all of its
//     arguments, or when it has no commands;
//   - for if/with/range, a removed body is emptied rather than set to nil,
//     because text/template dereferences the body unconditionally. A removed
//     else-branch is dropped, and the whole node is removed if fn matches its
//     embedded BranchNode. Branch conditions are not walked.
//
// The tree is modified in place, so pass a copy if the original must be kept.
// deleteNode returns the resulting root, or nil if n itself was removed.
func deleteNode(n parse.Node, fn func(parse.Node) bool) parse.Node {
	var walk func(n parse.Node) parse.Node
	walk = func(n parse.Node) parse.Node {
		if fn(n) {
			return nil
		}

		switch t := n.(type) {
		case *parse.ListNode:
			var nodes []parse.Node
			for _, c := range t.Nodes {
				if n := walk(c); n != nil {
					nodes = append(nodes, n)
				}
			}
			t.Nodes = nodes
			return t
		case *parse.IfNode:
			if walk(&t.BranchNode) == nil {
				return nil
			}
		case *parse.WithNode:
			if walk(&t.BranchNode) == nil {
				return nil
			}
		case *parse.RangeNode:
			if walk(&t.BranchNode) == nil {
				return nil
			}
		case *parse.BranchNode:
			// walk returns a nil interface (not a *ListNode) when fn removes a
			// whole list, e.g. an else-branch after {{ .Response }} was cut,
			// so the assertions must be checked rather than panic.
			if l, ok := walk(t.List).(*parse.ListNode); ok {
				t.List = l
			} else if t.List != nil {
				t.List.Nodes = nil
			}
			if t.ElseList != nil {
				l, _ := walk(t.ElseList).(*parse.ListNode)
				t.ElseList = l
			}
		case *parse.ActionNode:
			n := walk(t.Pipe)
			if n == nil {
				return nil
			}
			t.Pipe = n.(*parse.PipeNode)
		case *parse.PipeNode:
			var commands []*parse.CommandNode
			for _, c := range t.Cmds {
				var args []parse.Node
				for _, a := range c.Args {
					if n := walk(a); n != nil {
						args = append(args, n)
					}
				}
				// A command with no arguments left cannot be evaluated, so the
				// whole pipeline goes.
				if len(args) == 0 {
					return nil
				}
				c.Args = args
				commands = append(commands, c)
			}
			if len(commands) == 0 {
				return nil
			}
			t.Cmds = commands
		}
		return n
	}
	return walk(n)
}