template.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429
  1. package template
  2. import (
  3. "bytes"
  4. "embed"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "math"
  10. "slices"
  11. "strings"
  12. "sync"
  13. "text/template"
  14. "text/template/parse"
  15. "github.com/agnivade/levenshtein"
  16. "github.com/ollama/ollama/api"
  17. "golang.org/x/exp/maps"
  18. )
  19. //go:embed index.json
  20. var indexBytes []byte
  21. //go:embed *.gotmpl
  22. var templatesFS embed.FS
  23. var templatesOnce = sync.OnceValues(func() ([]*named, error) {
  24. var templates []*named
  25. if err := json.Unmarshal(indexBytes, &templates); err != nil {
  26. return nil, err
  27. }
  28. for _, t := range templates {
  29. bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
  30. if err != nil {
  31. return nil, err
  32. }
  33. // normalize line endings
  34. t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
  35. }
  36. return templates, nil
  37. })
  38. type named struct {
  39. Name string `json:"name"`
  40. Template string `json:"template"`
  41. Bytes []byte
  42. }
  43. func (t named) Reader() io.Reader {
  44. return bytes.NewReader(t.Bytes)
  45. }
  46. func Named(s string) (*named, error) {
  47. templates, err := templatesOnce()
  48. if err != nil {
  49. return nil, err
  50. }
  51. var template *named
  52. score := math.MaxInt
  53. for _, t := range templates {
  54. if s := levenshtein.ComputeDistance(s, t.Template); s < score {
  55. score = s
  56. template = t
  57. }
  58. }
  59. if score < 100 {
  60. return template, nil
  61. }
  62. return nil, errors.New("no matching template found")
  63. }
  64. var DefaultTemplate, _ = Parse("{{ .Prompt }}")
  65. type Template struct {
  66. *template.Template
  67. raw string
  68. }
  69. // response is a template node that can be added to templates that don't already have one
  70. var response = parse.ActionNode{
  71. NodeType: parse.NodeAction,
  72. Pipe: &parse.PipeNode{
  73. NodeType: parse.NodePipe,
  74. Cmds: []*parse.CommandNode{
  75. {
  76. NodeType: parse.NodeCommand,
  77. Args: []parse.Node{
  78. &parse.FieldNode{
  79. NodeType: parse.NodeField,
  80. Ident: []string{"Response"},
  81. },
  82. },
  83. },
  84. },
  85. },
  86. }
  87. var funcs = template.FuncMap{
  88. "json": func(v any) string {
  89. b, _ := json.Marshal(v)
  90. return string(b)
  91. },
  92. }
  93. func Parse(s string) (*Template, error) {
  94. tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
  95. tmpl, err := tmpl.Parse(s)
  96. if err != nil {
  97. return nil, err
  98. }
  99. t := Template{Template: tmpl, raw: s}
  100. if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
  101. // touch up the template and append {{ .Response }}
  102. tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
  103. }
  104. return &t, nil
  105. }
  106. func (t *Template) String() string {
  107. return t.raw
  108. }
  109. func (t *Template) Vars() []string {
  110. var vars []string
  111. for _, tt := range t.Templates() {
  112. for _, n := range tt.Root.Nodes {
  113. vars = append(vars, Identifiers(n)...)
  114. }
  115. }
  116. set := make(map[string]struct{})
  117. for _, n := range vars {
  118. set[strings.ToLower(n)] = struct{}{}
  119. }
  120. vars = maps.Keys(set)
  121. slices.Sort(vars)
  122. return vars
  123. }
  124. type Values struct {
  125. Messages []api.Message
  126. Tools []api.Tool
  127. Prompt string
  128. Suffix string
  129. // forceLegacy is a flag used to test compatibility with legacy templates
  130. forceLegacy bool
  131. }
  132. func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
  133. var walk func(parse.Node) parse.Node
  134. walk = func(n parse.Node) parse.Node {
  135. if fn(n) {
  136. return n
  137. }
  138. switch t := n.(type) {
  139. case *parse.ListNode:
  140. for _, c := range t.Nodes {
  141. if n := walk(c); n != nil {
  142. return n
  143. }
  144. }
  145. case *parse.BranchNode:
  146. for _, n := range []*parse.ListNode{t.List, t.ElseList} {
  147. if n != nil {
  148. if n := walk(n); n != nil {
  149. return n
  150. }
  151. }
  152. }
  153. case *parse.IfNode:
  154. return walk(&t.BranchNode)
  155. case *parse.WithNode:
  156. return walk(&t.BranchNode)
  157. case *parse.RangeNode:
  158. return walk(&t.BranchNode)
  159. }
  160. return nil
  161. }
  162. if n := walk(t.Tree.Root); n != nil {
  163. return (&template.Template{
  164. Tree: &parse.Tree{
  165. Root: &parse.ListNode{
  166. Nodes: []parse.Node{n},
  167. },
  168. },
  169. }).Funcs(funcs)
  170. }
  171. return nil
  172. }
  173. func (t *Template) Execute(w io.Writer, v Values) error {
  174. system, messages := collate(v.Messages)
  175. if v.Prompt != "" && v.Suffix != "" {
  176. return t.Template.Execute(w, map[string]any{
  177. "Prompt": v.Prompt,
  178. "Suffix": v.Suffix,
  179. "Response": "",
  180. })
  181. } else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
  182. return t.Template.Execute(w, map[string]any{
  183. "System": system,
  184. "Messages": messages,
  185. "Tools": v.Tools,
  186. })
  187. }
  188. system = ""
  189. var b bytes.Buffer
  190. var prompt, response string
  191. for _, m := range messages {
  192. execute := func() error {
  193. if err := t.Template.Execute(&b, map[string]any{
  194. "System": system,
  195. "Prompt": prompt,
  196. "Response": response,
  197. }); err != nil {
  198. return err
  199. }
  200. system = ""
  201. prompt = ""
  202. response = ""
  203. return nil
  204. }
  205. switch m.Role {
  206. case "system":
  207. if prompt != "" || response != "" {
  208. if err := execute(); err != nil {
  209. return err
  210. }
  211. }
  212. system = m.Content
  213. case "user":
  214. if response != "" {
  215. if err := execute(); err != nil {
  216. return err
  217. }
  218. }
  219. prompt = m.Content
  220. case "assistant":
  221. response = m.Content
  222. }
  223. }
  224. var cut bool
  225. nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
  226. if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
  227. cut = true
  228. }
  229. return cut
  230. })
  231. tree := parse.Tree{Root: nodes.(*parse.ListNode)}
  232. if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
  233. "System": system,
  234. "Prompt": prompt,
  235. }); err != nil {
  236. return err
  237. }
  238. _, err := io.Copy(w, &b)
  239. return err
  240. }
  241. // collate messages based on role. consecutive messages of the same role are merged
  242. // into a single message. collate also collects and returns all system messages.
  243. // collate mutates message content adding image tags ([img-%d]) as needed
  244. func collate(msgs []api.Message) (string, []*api.Message) {
  245. var n int
  246. var system []string
  247. var collated []*api.Message
  248. for i := range msgs {
  249. msg := msgs[i]
  250. for range msg.Images {
  251. imageTag := fmt.Sprintf("[img-%d]", n)
  252. if !strings.Contains(msg.Content, "[img]") {
  253. msg.Content = strings.TrimSpace("[img] " + msg.Content)
  254. }
  255. msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1)
  256. n++
  257. }
  258. if msg.Role == "system" {
  259. system = append(system, msg.Content)
  260. }
  261. if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
  262. collated[len(collated)-1].Content += "\n\n" + msg.Content
  263. } else {
  264. collated = append(collated, &msg)
  265. }
  266. }
  267. return strings.Join(system, "\n\n"), collated
  268. }
  269. // Identifiers walks the node tree returning any identifiers it finds along the way
  270. func Identifiers(n parse.Node) []string {
  271. switch n := n.(type) {
  272. case *parse.ListNode:
  273. var names []string
  274. for _, n := range n.Nodes {
  275. names = append(names, Identifiers(n)...)
  276. }
  277. return names
  278. case *parse.TemplateNode:
  279. return Identifiers(n.Pipe)
  280. case *parse.ActionNode:
  281. return Identifiers(n.Pipe)
  282. case *parse.BranchNode:
  283. names := Identifiers(n.Pipe)
  284. for _, n := range []*parse.ListNode{n.List, n.ElseList} {
  285. if n != nil {
  286. names = append(names, Identifiers(n)...)
  287. }
  288. }
  289. return names
  290. case *parse.IfNode:
  291. return Identifiers(&n.BranchNode)
  292. case *parse.RangeNode:
  293. return Identifiers(&n.BranchNode)
  294. case *parse.WithNode:
  295. return Identifiers(&n.BranchNode)
  296. case *parse.PipeNode:
  297. var names []string
  298. for _, c := range n.Cmds {
  299. for _, a := range c.Args {
  300. names = append(names, Identifiers(a)...)
  301. }
  302. }
  303. return names
  304. case *parse.FieldNode:
  305. return n.Ident
  306. case *parse.VariableNode:
  307. return n.Ident
  308. }
  309. return nil
  310. }
  311. // deleteNode walks the node list and deletes nodes that match the predicate
  312. // this is currently to remove the {{ .Response }} node from templates
  313. func deleteNode(n parse.Node, fn func(parse.Node) bool) parse.Node {
  314. var walk func(n parse.Node) parse.Node
  315. walk = func(n parse.Node) parse.Node {
  316. if fn(n) {
  317. return nil
  318. }
  319. switch t := n.(type) {
  320. case *parse.ListNode:
  321. var nodes []parse.Node
  322. for _, c := range t.Nodes {
  323. if n := walk(c); n != nil {
  324. nodes = append(nodes, n)
  325. }
  326. }
  327. t.Nodes = nodes
  328. return t
  329. case *parse.IfNode:
  330. t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
  331. case *parse.WithNode:
  332. t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
  333. case *parse.RangeNode:
  334. t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
  335. case *parse.BranchNode:
  336. t.List = walk(t.List).(*parse.ListNode)
  337. if t.ElseList != nil {
  338. t.ElseList = walk(t.ElseList).(*parse.ListNode)
  339. }
  340. case *parse.ActionNode:
  341. n := walk(t.Pipe)
  342. if n == nil {
  343. return nil
  344. }
  345. t.Pipe = n.(*parse.PipeNode)
  346. case *parse.PipeNode:
  347. var commands []*parse.CommandNode
  348. for _, c := range t.Cmds {
  349. var args []parse.Node
  350. for _, a := range c.Args {
  351. if n := walk(a); n != nil {
  352. args = append(args, n)
  353. }
  354. }
  355. if len(args) == 0 {
  356. return nil
  357. }
  358. c.Args = args
  359. commands = append(commands, c)
  360. }
  361. if len(commands) == 0 {
  362. return nil
  363. }
  364. t.Cmds = commands
  365. }
  366. return n
  367. }
  368. return walk(n)
  369. }