template.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. package template
  2. import (
  3. "bytes"
  4. "embed"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "math"
  10. "slices"
  11. "strings"
  12. "sync"
  13. "text/template"
  14. "text/template/parse"
  15. "github.com/agnivade/levenshtein"
  16. "github.com/ollama/ollama/api"
  17. "golang.org/x/exp/maps"
  18. )
  19. //go:embed index.json
  20. var indexBytes []byte
  21. //go:embed *.gotmpl
  22. var templatesFS embed.FS
  23. var templatesOnce = sync.OnceValues(func() ([]*named, error) {
  24. var templates []*named
  25. if err := json.Unmarshal(indexBytes, &templates); err != nil {
  26. return nil, err
  27. }
  28. for _, t := range templates {
  29. bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
  30. if err != nil {
  31. return nil, err
  32. }
  33. // normalize line endings
  34. t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
  35. }
  36. return templates, nil
  37. })
  38. type named struct {
  39. Name string `json:"name"`
  40. Template string `json:"template"`
  41. Bytes []byte
  42. }
  43. func (t named) Reader() io.Reader {
  44. return bytes.NewReader(t.Bytes)
  45. }
  46. func Named(s string) (*named, error) {
  47. templates, err := templatesOnce()
  48. if err != nil {
  49. return nil, err
  50. }
  51. var template *named
  52. score := math.MaxInt
  53. for _, t := range templates {
  54. if s := levenshtein.ComputeDistance(s, t.Template); s < score {
  55. score = s
  56. template = t
  57. }
  58. }
  59. if score < 100 {
  60. return template, nil
  61. }
  62. return nil, errors.New("no matching template found")
  63. }
// DefaultTemplate renders just the prompt.
//
// Its text references neither .Messages nor .Response, so Parse also appends a
// {{ .Response }} node to it. The error is deliberately discarded because this
// fixed template text is expected to always parse.
var DefaultTemplate, _ = Parse("{{ .Prompt }}")

// Template is a parsed text/template together with the source text it was
// parsed from.
type Template struct {
	*template.Template
	// raw is the unmodified source passed to Parse; String returns it.
	raw string
}
// response is a template node that can be added to templates that don't already have one.
//
// It is the parsed equivalent of {{ .Response }}: an action whose pipeline has
// a single command whose only argument is the field .Response. Parse appends a
// pointer to this one package-level value to every template that references
// neither .Messages nor .Response. The node is therefore shared between
// templates and should not be modified in place; Execute edits a Root.Copy()
// instead.
var response = parse.ActionNode{
	NodeType: parse.NodeAction,
	Pipe: &parse.PipeNode{
		NodeType: parse.NodePipe,
		Cmds: []*parse.CommandNode{
			{
				NodeType: parse.NodeCommand,
				Args: []parse.Node{
					&parse.FieldNode{
						NodeType: parse.NodeField,
						Ident:    []string{"Response"},
					},
				},
			},
		},
	},
}
  87. var funcs = template.FuncMap{
  88. "json": func(v any) string {
  89. b, _ := json.Marshal(v)
  90. return string(b)
  91. },
  92. }
  93. func Parse(s string) (*Template, error) {
  94. tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
  95. tmpl, err := tmpl.Parse(s)
  96. if err != nil {
  97. return nil, err
  98. }
  99. t := Template{Template: tmpl, raw: s}
  100. if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
  101. // touch up the template and append {{ .Response }}
  102. tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
  103. }
  104. return &t, nil
  105. }
  106. func (t *Template) String() string {
  107. return t.raw
  108. }
  109. func (t *Template) Vars() []string {
  110. var vars []string
  111. for _, tt := range t.Templates() {
  112. for _, n := range tt.Root.Nodes {
  113. vars = append(vars, Identifiers(n)...)
  114. }
  115. }
  116. set := make(map[string]struct{})
  117. for _, n := range vars {
  118. set[strings.ToLower(n)] = struct{}{}
  119. }
  120. vars = maps.Keys(set)
  121. slices.Sort(vars)
  122. return vars
  123. }
// Values holds the data Execute renders a Template with.
type Values struct {
	// Messages is the conversation. Execute collates it before rendering,
	// merging consecutive messages that share a role.
	Messages []api.Message
	// Tools is passed to the template as .Tools, but only when the template
	// is rendered from .Messages.
	api.Tools
	// Prompt and Suffix are only used by Execute when both are non-empty
	// (the prompt/suffix mode). Otherwise the prompt comes from Messages.
	Prompt string
	Suffix string

	// forceLegacy is a flag used to test compatibility with legacy templates
	forceLegacy bool
}
  132. func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
  133. var walk func(parse.Node) parse.Node
  134. walk = func(n parse.Node) parse.Node {
  135. if fn(n) {
  136. return n
  137. }
  138. switch t := n.(type) {
  139. case *parse.ListNode:
  140. for _, c := range t.Nodes {
  141. if n := walk(c); n != nil {
  142. return n
  143. }
  144. }
  145. case *parse.BranchNode:
  146. for _, n := range []*parse.ListNode{t.List, t.ElseList} {
  147. if n != nil {
  148. if n := walk(n); n != nil {
  149. return n
  150. }
  151. }
  152. }
  153. case *parse.IfNode:
  154. return walk(&t.BranchNode)
  155. case *parse.WithNode:
  156. return walk(&t.BranchNode)
  157. case *parse.RangeNode:
  158. return walk(&t.BranchNode)
  159. }
  160. return nil
  161. }
  162. if n := walk(t.Tree.Root); n != nil {
  163. return (&template.Template{
  164. Tree: &parse.Tree{
  165. Root: &parse.ListNode{
  166. Nodes: []parse.Node{n},
  167. },
  168. },
  169. }).Funcs(funcs)
  170. }
  171. return nil
  172. }
  173. func (t *Template) Execute(w io.Writer, v Values) error {
  174. system, messages := collate(v.Messages)
  175. if v.Prompt != "" && v.Suffix != "" {
  176. return t.Template.Execute(w, map[string]any{
  177. "Prompt": v.Prompt,
  178. "Suffix": v.Suffix,
  179. "Response": "",
  180. })
  181. } else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
  182. return t.Template.Execute(w, map[string]any{
  183. "System": system,
  184. "Messages": messages,
  185. "Tools": v.Tools,
  186. "Response": "",
  187. })
  188. }
  189. system = ""
  190. var b bytes.Buffer
  191. var prompt, response string
  192. for _, m := range messages {
  193. execute := func() error {
  194. if err := t.Template.Execute(&b, map[string]any{
  195. "System": system,
  196. "Prompt": prompt,
  197. "Response": response,
  198. }); err != nil {
  199. return err
  200. }
  201. system = ""
  202. prompt = ""
  203. response = ""
  204. return nil
  205. }
  206. switch m.Role {
  207. case "system":
  208. if prompt != "" || response != "" {
  209. if err := execute(); err != nil {
  210. return err
  211. }
  212. }
  213. system = m.Content
  214. case "user":
  215. if response != "" {
  216. if err := execute(); err != nil {
  217. return err
  218. }
  219. }
  220. prompt = m.Content
  221. case "assistant":
  222. response = m.Content
  223. }
  224. }
  225. var cut bool
  226. nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
  227. if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
  228. cut = true
  229. return false
  230. }
  231. return cut
  232. })
  233. tree := parse.Tree{Root: nodes.(*parse.ListNode)}
  234. if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
  235. "System": system,
  236. "Prompt": prompt,
  237. "Response": response,
  238. }); err != nil {
  239. return err
  240. }
  241. _, err := io.Copy(w, &b)
  242. return err
  243. }
  244. // collate messages based on role. consecutive messages of the same role are merged
  245. // into a single message. collate also collects and returns all system messages.
  246. // collate mutates message content adding image tags ([img-%d]) as needed
  247. func collate(msgs []api.Message) (string, []*api.Message) {
  248. var n int
  249. var system []string
  250. var collated []*api.Message
  251. for i := range msgs {
  252. msg := msgs[i]
  253. for range msg.Images {
  254. imageTag := fmt.Sprintf("[img-%d]", n)
  255. if !strings.Contains(msg.Content, "[img]") {
  256. msg.Content = strings.TrimSpace("[img] " + msg.Content)
  257. }
  258. msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1)
  259. n++
  260. }
  261. if msg.Role == "system" {
  262. system = append(system, msg.Content)
  263. }
  264. if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
  265. collated[len(collated)-1].Content += "\n\n" + msg.Content
  266. } else {
  267. collated = append(collated, &msg)
  268. }
  269. }
  270. return strings.Join(system, "\n\n"), collated
  271. }
  272. // Identifiers walks the node tree returning any identifiers it finds along the way
  273. func Identifiers(n parse.Node) []string {
  274. switch n := n.(type) {
  275. case *parse.ListNode:
  276. var names []string
  277. for _, n := range n.Nodes {
  278. names = append(names, Identifiers(n)...)
  279. }
  280. return names
  281. case *parse.TemplateNode:
  282. return Identifiers(n.Pipe)
  283. case *parse.ActionNode:
  284. return Identifiers(n.Pipe)
  285. case *parse.BranchNode:
  286. names := Identifiers(n.Pipe)
  287. for _, n := range []*parse.ListNode{n.List, n.ElseList} {
  288. if n != nil {
  289. names = append(names, Identifiers(n)...)
  290. }
  291. }
  292. return names
  293. case *parse.IfNode:
  294. return Identifiers(&n.BranchNode)
  295. case *parse.RangeNode:
  296. return Identifiers(&n.BranchNode)
  297. case *parse.WithNode:
  298. return Identifiers(&n.BranchNode)
  299. case *parse.PipeNode:
  300. var names []string
  301. for _, c := range n.Cmds {
  302. for _, a := range c.Args {
  303. names = append(names, Identifiers(a)...)
  304. }
  305. }
  306. return names
  307. case *parse.FieldNode:
  308. return n.Ident
  309. case *parse.VariableNode:
  310. return n.Ident
  311. }
  312. return nil
  313. }
  314. // deleteNode walks the node list and deletes nodes that match the predicate
  315. // this is currently to remove the {{ .Response }} node from templates
  316. func deleteNode(n parse.Node, fn func(parse.Node) bool) parse.Node {
  317. var walk func(n parse.Node) parse.Node
  318. walk = func(n parse.Node) parse.Node {
  319. if fn(n) {
  320. return nil
  321. }
  322. switch t := n.(type) {
  323. case *parse.ListNode:
  324. var nodes []parse.Node
  325. for _, c := range t.Nodes {
  326. if n := walk(c); n != nil {
  327. nodes = append(nodes, n)
  328. }
  329. }
  330. t.Nodes = nodes
  331. return t
  332. case *parse.IfNode:
  333. t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
  334. case *parse.WithNode:
  335. t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
  336. case *parse.RangeNode:
  337. t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
  338. case *parse.BranchNode:
  339. t.List = walk(t.List).(*parse.ListNode)
  340. if t.ElseList != nil {
  341. t.ElseList = walk(t.ElseList).(*parse.ListNode)
  342. }
  343. case *parse.ActionNode:
  344. n := walk(t.Pipe)
  345. if n == nil {
  346. return nil
  347. }
  348. t.Pipe = n.(*parse.PipeNode)
  349. case *parse.PipeNode:
  350. var commands []*parse.CommandNode
  351. for _, c := range t.Cmds {
  352. var args []parse.Node
  353. for _, a := range c.Args {
  354. if n := walk(a); n != nil {
  355. args = append(args, n)
  356. }
  357. }
  358. if len(args) == 0 {
  359. return nil
  360. }
  361. c.Args = args
  362. commands = append(commands, c)
  363. }
  364. if len(commands) == 0 {
  365. return nil
  366. }
  367. t.Cmds = commands
  368. }
  369. return n
  370. }
  371. return walk(n)
  372. }