// template.go (9.1 KB)

  1. package template
import (
	"bytes"
	"embed"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"math"
	"slices"
	"strings"
	"sync"
	"text/template"
	"text/template/parse"

	"github.com/agnivade/levenshtein"
	"golang.org/x/exp/maps"

	"github.com/ollama/ollama/api"
)
// indexBytes is the embedded index.json catalog. templatesOnce decodes it into
// a list of *named entries (each with a "name" and a "template" key).
//
//go:embed index.json
var indexBytes []byte

// templatesFS holds the embedded *.gotmpl template bodies and *.json parameter
// files; templatesOnce reads <name>.gotmpl and <name>.json from it.
//
//go:embed *.gotmpl
//go:embed *.json
var templatesFS embed.FS
  24. var templatesOnce = sync.OnceValues(func() ([]*named, error) {
  25. var templates []*named
  26. if err := json.Unmarshal(indexBytes, &templates); err != nil {
  27. return nil, err
  28. }
  29. for _, t := range templates {
  30. bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
  31. if err != nil {
  32. return nil, err
  33. }
  34. // normalize line endings
  35. t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
  36. params, err := templatesFS.ReadFile(t.Name + ".json")
  37. if err != nil {
  38. continue
  39. }
  40. if err := json.Unmarshal(params, &t.Parameters); err != nil {
  41. return nil, err
  42. }
  43. }
  44. return templates, nil
  45. })
  46. type named struct {
  47. Name string `json:"name"`
  48. Template string `json:"template"`
  49. Bytes []byte
  50. Parameters *struct {
  51. Stop []string `json:"stop"`
  52. }
  53. }
  54. func (t named) Reader() io.Reader {
  55. return bytes.NewReader(t.Bytes)
  56. }
  57. func Named(s string) (*named, error) {
  58. templates, err := templatesOnce()
  59. if err != nil {
  60. return nil, err
  61. }
  62. var template *named
  63. score := math.MaxInt
  64. for _, t := range templates {
  65. if s := levenshtein.ComputeDistance(s, t.Template); s < score {
  66. score = s
  67. template = t
  68. }
  69. }
  70. if score < 100 {
  71. return template, nil
  72. }
  73. return nil, errors.New("no matching template found")
  74. }
  75. var DefaultTemplate, _ = Parse("{{ .Prompt }}")
// Template is a parsed prompt template: the compiled text/template together
// with the source text it was parsed from.
type Template struct {
	*template.Template
	// raw is the source text passed to Parse; String returns it unchanged.
	raw string
}
// response is a prebuilt {{ .Response }} action node. Parse appends it to the
// root of templates that reference neither .Messages nor .Response.
//
// NOTE(review): Parse appends a pointer to this single package-level value, so
// every such template shares the same node; it must never be mutated in place
// (tree rewrites should operate on a Copy()).
var response = parse.ActionNode{
	NodeType: parse.NodeAction,
	Pipe: &parse.PipeNode{
		NodeType: parse.NodePipe,
		Cmds: []*parse.CommandNode{
			{
				NodeType: parse.NodeCommand,
				Args: []parse.Node{
					&parse.FieldNode{
						NodeType: parse.NodeField,
						Ident:    []string{"Response"},
					},
				},
			},
		},
	},
}
  98. var funcs = template.FuncMap{
  99. "json": func(v any) string {
  100. b, _ := json.Marshal(v)
  101. return string(b)
  102. },
  103. }
  104. func Parse(s string) (*Template, error) {
  105. tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
  106. tmpl, err := tmpl.Parse(s)
  107. if err != nil {
  108. return nil, err
  109. }
  110. t := Template{Template: tmpl, raw: s}
  111. if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
  112. // touch up the template and append {{ .Response }}
  113. tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
  114. }
  115. return &t, nil
  116. }
  117. func (t *Template) String() string {
  118. return t.raw
  119. }
  120. func (t *Template) Vars() []string {
  121. var vars []string
  122. for _, tt := range t.Templates() {
  123. for _, n := range tt.Root.Nodes {
  124. vars = append(vars, Identifiers(n)...)
  125. }
  126. }
  127. set := make(map[string]struct{})
  128. for _, n := range vars {
  129. set[strings.ToLower(n)] = struct{}{}
  130. }
  131. vars = maps.Keys(set)
  132. slices.Sort(vars)
  133. return vars
  134. }
// Values holds the inputs for Template.Execute.
type Values struct {
	// Messages is the conversation to render; Execute collates it before use.
	Messages []api.Message
	// Tools is exposed as .Tools to templates that reference .Messages.
	api.Tools
	// When both Prompt and Suffix are non-empty, Execute renders the template
	// once with only .Prompt, .Suffix and an empty .Response.
	Prompt string
	Suffix string

	// forceLegacy is a flag used to test compatibility with legacy templates
	forceLegacy bool
}
  143. func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
  144. var walk func(parse.Node) parse.Node
  145. walk = func(n parse.Node) parse.Node {
  146. if fn(n) {
  147. return n
  148. }
  149. switch t := n.(type) {
  150. case *parse.ListNode:
  151. for _, c := range t.Nodes {
  152. if n := walk(c); n != nil {
  153. return n
  154. }
  155. }
  156. case *parse.BranchNode:
  157. for _, n := range []*parse.ListNode{t.List, t.ElseList} {
  158. if n != nil {
  159. if n := walk(n); n != nil {
  160. return n
  161. }
  162. }
  163. }
  164. case *parse.IfNode:
  165. return walk(&t.BranchNode)
  166. case *parse.WithNode:
  167. return walk(&t.BranchNode)
  168. case *parse.RangeNode:
  169. return walk(&t.BranchNode)
  170. }
  171. return nil
  172. }
  173. if n := walk(t.Tree.Root); n != nil {
  174. return (&template.Template{
  175. Tree: &parse.Tree{
  176. Root: &parse.ListNode{
  177. Nodes: []parse.Node{n},
  178. },
  179. },
  180. }).Funcs(funcs)
  181. }
  182. return nil
  183. }
  184. func (t *Template) Execute(w io.Writer, v Values) error {
  185. system, messages := collate(v.Messages)
  186. if v.Prompt != "" && v.Suffix != "" {
  187. return t.Template.Execute(w, map[string]any{
  188. "Prompt": v.Prompt,
  189. "Suffix": v.Suffix,
  190. "Response": "",
  191. })
  192. } else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
  193. return t.Template.Execute(w, map[string]any{
  194. "System": system,
  195. "Messages": messages,
  196. "Tools": v.Tools,
  197. "Response": "",
  198. })
  199. }
  200. system = ""
  201. var b bytes.Buffer
  202. var prompt, response string
  203. for _, m := range messages {
  204. execute := func() error {
  205. if err := t.Template.Execute(&b, map[string]any{
  206. "System": system,
  207. "Prompt": prompt,
  208. "Response": response,
  209. }); err != nil {
  210. return err
  211. }
  212. system = ""
  213. prompt = ""
  214. response = ""
  215. return nil
  216. }
  217. switch m.Role {
  218. case "system":
  219. if prompt != "" || response != "" {
  220. if err := execute(); err != nil {
  221. return err
  222. }
  223. }
  224. system = m.Content
  225. case "user":
  226. if response != "" {
  227. if err := execute(); err != nil {
  228. return err
  229. }
  230. }
  231. prompt = m.Content
  232. case "assistant":
  233. response = m.Content
  234. }
  235. }
  236. var cut bool
  237. nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
  238. if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
  239. cut = true
  240. return false
  241. }
  242. return cut
  243. })
  244. tree := parse.Tree{Root: nodes.(*parse.ListNode)}
  245. if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
  246. "System": system,
  247. "Prompt": prompt,
  248. "Response": response,
  249. }); err != nil {
  250. return err
  251. }
  252. _, err := io.Copy(w, &b)
  253. return err
  254. }
  255. // collate messages based on role. consecutive messages of the same role are merged
  256. // into a single message. collate also collects and returns all system messages.
  257. // collate mutates message content adding image tags ([img-%d]) as needed
  258. func collate(msgs []api.Message) (string, []*api.Message) {
  259. var n int
  260. var system []string
  261. var collated []*api.Message
  262. for i := range msgs {
  263. msg := msgs[i]
  264. for range msg.Images {
  265. imageTag := fmt.Sprintf("[img-%d]", n)
  266. if !strings.Contains(msg.Content, "[img]") {
  267. msg.Content = strings.TrimSpace("[img] " + msg.Content)
  268. }
  269. msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1)
  270. n++
  271. }
  272. if msg.Role == "system" {
  273. system = append(system, msg.Content)
  274. }
  275. if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
  276. collated[len(collated)-1].Content += "\n\n" + msg.Content
  277. } else {
  278. collated = append(collated, &msg)
  279. }
  280. }
  281. return strings.Join(system, "\n\n"), collated
  282. }
  283. // Identifiers walks the node tree returning any identifiers it finds along the way
  284. func Identifiers(n parse.Node) []string {
  285. switch n := n.(type) {
  286. case *parse.ListNode:
  287. var names []string
  288. for _, n := range n.Nodes {
  289. names = append(names, Identifiers(n)...)
  290. }
  291. return names
  292. case *parse.TemplateNode:
  293. return Identifiers(n.Pipe)
  294. case *parse.ActionNode:
  295. return Identifiers(n.Pipe)
  296. case *parse.BranchNode:
  297. names := Identifiers(n.Pipe)
  298. for _, n := range []*parse.ListNode{n.List, n.ElseList} {
  299. if n != nil {
  300. names = append(names, Identifiers(n)...)
  301. }
  302. }
  303. return names
  304. case *parse.IfNode:
  305. return Identifiers(&n.BranchNode)
  306. case *parse.RangeNode:
  307. return Identifiers(&n.BranchNode)
  308. case *parse.WithNode:
  309. return Identifiers(&n.BranchNode)
  310. case *parse.PipeNode:
  311. var names []string
  312. for _, c := range n.Cmds {
  313. for _, a := range c.Args {
  314. names = append(names, Identifiers(a)...)
  315. }
  316. }
  317. return names
  318. case *parse.FieldNode:
  319. return n.Ident
  320. case *parse.VariableNode:
  321. return n.Ident
  322. }
  323. return nil
  324. }
  325. // deleteNode walks the node list and deletes nodes that match the predicate
  326. // this is currently to remove the {{ .Response }} node from templates
  327. func deleteNode(n parse.Node, fn func(parse.Node) bool) parse.Node {
  328. var walk func(n parse.Node) parse.Node
  329. walk = func(n parse.Node) parse.Node {
  330. if fn(n) {
  331. return nil
  332. }
  333. switch t := n.(type) {
  334. case *parse.ListNode:
  335. var nodes []parse.Node
  336. for _, c := range t.Nodes {
  337. if n := walk(c); n != nil {
  338. nodes = append(nodes, n)
  339. }
  340. }
  341. t.Nodes = nodes
  342. return t
  343. case *parse.IfNode:
  344. t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
  345. case *parse.WithNode:
  346. t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
  347. case *parse.RangeNode:
  348. t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
  349. case *parse.BranchNode:
  350. t.List = walk(t.List).(*parse.ListNode)
  351. if t.ElseList != nil {
  352. t.ElseList = walk(t.ElseList).(*parse.ListNode)
  353. }
  354. case *parse.ActionNode:
  355. n := walk(t.Pipe)
  356. if n == nil {
  357. return nil
  358. }
  359. t.Pipe = n.(*parse.PipeNode)
  360. case *parse.PipeNode:
  361. var commands []*parse.CommandNode
  362. for _, c := range t.Cmds {
  363. var args []parse.Node
  364. for _, a := range c.Args {
  365. if n := walk(a); n != nil {
  366. args = append(args, n)
  367. }
  368. }
  369. if len(args) == 0 {
  370. return nil
  371. }
  372. c.Args = args
  373. commands = append(commands, c)
  374. }
  375. if len(commands) == 0 {
  376. return nil
  377. }
  378. t.Cmds = commands
  379. }
  380. return n
  381. }
  382. return walk(n)
  383. }