template.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. package template
  2. import (
  3. "bytes"
  4. "embed"
  5. "encoding/json"
  6. "errors"
  7. "io"
  8. "math"
  9. "slices"
  10. "strings"
  11. "sync"
  12. "text/template"
  13. "text/template/parse"
  14. "github.com/agnivade/levenshtein"
  15. "golang.org/x/exp/maps"
  16. "github.com/ollama/ollama/api"
  17. )
// indexBytes holds the contents of the embedded index.json file. It is
// decoded into the list of named templates the first time templatesOnce runs.
//
//go:embed index.json
var indexBytes []byte

// templatesFS is an embedded file system containing every file matched by the
// *.gotmpl and *.json patterns below. templatesOnce reads "<name>.gotmpl" and
// "<name>.json" from it for each entry listed in the index.
//
//go:embed *.gotmpl
//go:embed *.json
var templatesFS embed.FS
  23. var templatesOnce = sync.OnceValues(func() ([]*named, error) {
  24. var templates []*named
  25. if err := json.Unmarshal(indexBytes, &templates); err != nil {
  26. return nil, err
  27. }
  28. for _, t := range templates {
  29. bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
  30. if err != nil {
  31. return nil, err
  32. }
  33. // normalize line endings
  34. t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
  35. params, err := templatesFS.ReadFile(t.Name + ".json")
  36. if err != nil {
  37. continue
  38. }
  39. if err := json.Unmarshal(params, &t.Parameters); err != nil {
  40. return nil, err
  41. }
  42. }
  43. return templates, nil
  44. })
  45. type named struct {
  46. Name string `json:"name"`
  47. Template string `json:"template"`
  48. Bytes []byte
  49. Parameters *struct {
  50. Stop []string `json:"stop"`
  51. }
  52. }
  53. func (t named) Reader() io.Reader {
  54. return bytes.NewReader(t.Bytes)
  55. }
  56. func Named(s string) (*named, error) {
  57. templates, err := templatesOnce()
  58. if err != nil {
  59. return nil, err
  60. }
  61. var template *named
  62. score := math.MaxInt
  63. for _, t := range templates {
  64. if s := levenshtein.ComputeDistance(s, t.Template); s < score {
  65. score = s
  66. template = t
  67. }
  68. }
  69. if score < 100 {
  70. return template, nil
  71. }
  72. return nil, errors.New("no matching template found")
  73. }
// DefaultTemplate is the template obtained by parsing "{{ .Prompt }}". The
// error returned by Parse is discarded here.
var DefaultTemplate, _ = Parse("{{ .Prompt }}")
// Template wraps a parsed text/template together with the source text it was
// created from.
type Template struct {
	*template.Template
	// raw is the unmodified template text; String returns it.
	raw string
}
// response is a template node that can be added to templates that don't already have one.
// It is an action node whose pipeline consists of a single command with a
// single field argument, Response. Parse appends a pointer to this value to
// the root of templates that reference neither "messages" nor "response".
var response = parse.ActionNode{
	NodeType: parse.NodeAction,
	Pipe: &parse.PipeNode{
		NodeType: parse.NodePipe,
		Cmds: []*parse.CommandNode{
			{
				NodeType: parse.NodeCommand,
				Args: []parse.Node{
					&parse.FieldNode{
						NodeType: parse.NodeField,
						Ident:    []string{"Response"},
					},
				},
			},
		},
	},
}
  97. var funcs = template.FuncMap{
  98. "json": func(v any) string {
  99. b, _ := json.Marshal(v)
  100. return string(b)
  101. },
  102. }
  103. func Parse(s string) (*Template, error) {
  104. tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
  105. tmpl, err := tmpl.Parse(s)
  106. if err != nil {
  107. return nil, err
  108. }
  109. t := Template{Template: tmpl, raw: s}
  110. if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
  111. // touch up the template and append {{ .Response }}
  112. tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
  113. }
  114. return &t, nil
  115. }
// String returns the raw template text stored in t, not a rendering of the
// parse tree.
func (t *Template) String() string {
	return t.raw
}
  119. func (t *Template) Vars() []string {
  120. var vars []string
  121. for _, tt := range t.Templates() {
  122. for _, n := range tt.Root.Nodes {
  123. vars = append(vars, Identifiers(n)...)
  124. }
  125. }
  126. set := make(map[string]struct{})
  127. for _, n := range vars {
  128. set[strings.ToLower(n)] = struct{}{}
  129. }
  130. vars = maps.Keys(set)
  131. slices.Sort(vars)
  132. return vars
  133. }
// Values is the input passed to Template.Execute.
type Values struct {
	// Messages is the conversation; Execute collates it before rendering.
	Messages []api.Message
	// Tools is embedded and passed to the template as "Tools" when the
	// template is rendered with messages.
	api.Tools
	// Prompt and Suffix, when both are non-empty, make Execute render the
	// template with only Prompt, Suffix and an empty Response.
	Prompt string
	Suffix string

	// forceLegacy is a flag used to test compatibility with legacy templates
	forceLegacy bool
}
  142. func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
  143. var walk func(parse.Node) parse.Node
  144. walk = func(n parse.Node) parse.Node {
  145. if fn(n) {
  146. return n
  147. }
  148. switch t := n.(type) {
  149. case *parse.ListNode:
  150. for _, c := range t.Nodes {
  151. if n := walk(c); n != nil {
  152. return n
  153. }
  154. }
  155. case *parse.BranchNode:
  156. for _, n := range []*parse.ListNode{t.List, t.ElseList} {
  157. if n != nil {
  158. if n := walk(n); n != nil {
  159. return n
  160. }
  161. }
  162. }
  163. case *parse.IfNode:
  164. return walk(&t.BranchNode)
  165. case *parse.WithNode:
  166. return walk(&t.BranchNode)
  167. case *parse.RangeNode:
  168. return walk(&t.BranchNode)
  169. }
  170. return nil
  171. }
  172. if n := walk(t.Tree.Root); n != nil {
  173. return (&template.Template{
  174. Tree: &parse.Tree{
  175. Root: &parse.ListNode{
  176. Nodes: []parse.Node{n},
  177. },
  178. },
  179. }).Funcs(funcs)
  180. }
  181. return nil
  182. }
  183. func (t *Template) Execute(w io.Writer, v Values) error {
  184. system, messages := collate(v.Messages)
  185. if v.Prompt != "" && v.Suffix != "" {
  186. return t.Template.Execute(w, map[string]any{
  187. "Prompt": v.Prompt,
  188. "Suffix": v.Suffix,
  189. "Response": "",
  190. })
  191. } else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
  192. return t.Template.Execute(w, map[string]any{
  193. "System": system,
  194. "Messages": messages,
  195. "Tools": v.Tools,
  196. "Response": "",
  197. })
  198. }
  199. system = ""
  200. var b bytes.Buffer
  201. var prompt, response string
  202. for _, m := range messages {
  203. execute := func() error {
  204. if err := t.Template.Execute(&b, map[string]any{
  205. "System": system,
  206. "Prompt": prompt,
  207. "Response": response,
  208. }); err != nil {
  209. return err
  210. }
  211. system = ""
  212. prompt = ""
  213. response = ""
  214. return nil
  215. }
  216. switch m.Role {
  217. case "system":
  218. if prompt != "" || response != "" {
  219. if err := execute(); err != nil {
  220. return err
  221. }
  222. }
  223. system = m.Content
  224. case "user":
  225. if response != "" {
  226. if err := execute(); err != nil {
  227. return err
  228. }
  229. }
  230. prompt = m.Content
  231. case "assistant":
  232. response = m.Content
  233. }
  234. }
  235. var cut bool
  236. nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
  237. if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
  238. cut = true
  239. return false
  240. }
  241. return cut
  242. })
  243. tree := parse.Tree{Root: nodes.(*parse.ListNode)}
  244. if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
  245. "System": system,
  246. "Prompt": prompt,
  247. "Response": response,
  248. }); err != nil {
  249. return err
  250. }
  251. _, err := io.Copy(w, &b)
  252. return err
  253. }
  254. // collate messages based on role. consecutive messages of the same role are merged
  255. // into a single message. collate also collects and returns all system messages.
  256. // collate mutates message content adding image tags ([img-%d]) as needed
  257. func collate(msgs []api.Message) (string, []*api.Message) {
  258. var system []string
  259. var collated []*api.Message
  260. for i := range msgs {
  261. msg := msgs[i]
  262. if msg.Role == "system" {
  263. system = append(system, msg.Content)
  264. }
  265. if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
  266. collated[len(collated)-1].Content += "\n\n" + msg.Content
  267. } else {
  268. collated = append(collated, &msg)
  269. }
  270. }
  271. return strings.Join(system, "\n\n"), collated
  272. }
  273. // Identifiers walks the node tree returning any identifiers it finds along the way
  274. func Identifiers(n parse.Node) []string {
  275. switch n := n.(type) {
  276. case *parse.ListNode:
  277. var names []string
  278. for _, n := range n.Nodes {
  279. names = append(names, Identifiers(n)...)
  280. }
  281. return names
  282. case *parse.TemplateNode:
  283. return Identifiers(n.Pipe)
  284. case *parse.ActionNode:
  285. return Identifiers(n.Pipe)
  286. case *parse.BranchNode:
  287. names := Identifiers(n.Pipe)
  288. for _, n := range []*parse.ListNode{n.List, n.ElseList} {
  289. if n != nil {
  290. names = append(names, Identifiers(n)...)
  291. }
  292. }
  293. return names
  294. case *parse.IfNode:
  295. return Identifiers(&n.BranchNode)
  296. case *parse.RangeNode:
  297. return Identifiers(&n.BranchNode)
  298. case *parse.WithNode:
  299. return Identifiers(&n.BranchNode)
  300. case *parse.PipeNode:
  301. var names []string
  302. for _, c := range n.Cmds {
  303. for _, a := range c.Args {
  304. names = append(names, Identifiers(a)...)
  305. }
  306. }
  307. return names
  308. case *parse.FieldNode:
  309. return n.Ident
  310. case *parse.VariableNode:
  311. return n.Ident
  312. }
  313. return nil
  314. }
  315. // deleteNode walks the node list and deletes nodes that match the predicate
  316. // this is currently to remove the {{ .Response }} node from templates
  317. func deleteNode(n parse.Node, fn func(parse.Node) bool) parse.Node {
  318. var walk func(n parse.Node) parse.Node
  319. walk = func(n parse.Node) parse.Node {
  320. if fn(n) {
  321. return nil
  322. }
  323. switch t := n.(type) {
  324. case *parse.ListNode:
  325. var nodes []parse.Node
  326. for _, c := range t.Nodes {
  327. if n := walk(c); n != nil {
  328. nodes = append(nodes, n)
  329. }
  330. }
  331. t.Nodes = nodes
  332. return t
  333. case *parse.IfNode:
  334. t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
  335. case *parse.WithNode:
  336. t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
  337. case *parse.RangeNode:
  338. t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
  339. case *parse.BranchNode:
  340. t.List = walk(t.List).(*parse.ListNode)
  341. if t.ElseList != nil {
  342. t.ElseList = walk(t.ElseList).(*parse.ListNode)
  343. }
  344. case *parse.ActionNode:
  345. n := walk(t.Pipe)
  346. if n == nil {
  347. return nil
  348. }
  349. t.Pipe = n.(*parse.PipeNode)
  350. case *parse.PipeNode:
  351. var commands []*parse.CommandNode
  352. for _, c := range t.Cmds {
  353. var args []parse.Node
  354. for _, a := range c.Args {
  355. if n := walk(a); n != nil {
  356. args = append(args, n)
  357. }
  358. }
  359. if len(args) == 0 {
  360. return nil
  361. }
  362. c.Args = args
  363. commands = append(commands, c)
  364. }
  365. if len(commands) == 0 {
  366. return nil
  367. }
  368. t.Cmds = commands
  369. }
  370. return n
  371. }
  372. return walk(n)
  373. }