template.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. package template
  2. import (
  3. "bytes"
  4. "embed"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "math"
  10. "slices"
  11. "strings"
  12. "sync"
  13. "text/template"
  14. "text/template/parse"
  15. "github.com/agnivade/levenshtein"
  16. "github.com/ollama/ollama/api"
  17. "golang.org/x/exp/maps"
  18. )
  19. //go:embed index.json
  20. var indexBytes []byte
  21. //go:embed *.gotmpl
  22. var templatesFS embed.FS
  23. var templatesOnce = sync.OnceValues(func() ([]*named, error) {
  24. var templates []*named
  25. if err := json.Unmarshal(indexBytes, &templates); err != nil {
  26. return nil, err
  27. }
  28. for _, t := range templates {
  29. bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
  30. if err != nil {
  31. return nil, err
  32. }
  33. // normalize line endings
  34. t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
  35. }
  36. return templates, nil
  37. })
  // named is one entry of the embedded template index. Name selects the
// bundled "<Name>.gotmpl" file, Template is the text that Named compares
// input against, and Bytes holds the bundled file's contents.
type named struct {
	Name     string `json:"name"`
	Template string `json:"template"`
	Bytes    []byte
}

// Reader returns a new reader positioned at the start of the template file
// contents.
func (n named) Reader() io.Reader {
	r := bytes.NewReader(n.Bytes)
	return r
}
  46. func Named(s string) (*named, error) {
  47. templates, err := templatesOnce()
  48. if err != nil {
  49. return nil, err
  50. }
  51. var template *named
  52. score := math.MaxInt
  53. for _, t := range templates {
  54. if s := levenshtein.ComputeDistance(s, t.Template); s < score {
  55. score = s
  56. template = t
  57. }
  58. }
  59. if score < 100 {
  60. return template, nil
  61. }
  62. return nil, errors.New("no matching template found")
  63. }
  64. var DefaultTemplate, _ = Parse("{{ .Prompt }}")
  65. type Template struct {
  66. *template.Template
  67. raw string
  68. }
  // response is a prebuilt {{ .Response }} action node. Parse appends it to
// templates that do not already reference .Messages or .Response.
var response = func() parse.ActionNode {
	field := &parse.FieldNode{
		NodeType: parse.NodeField,
		Ident:    []string{"Response"},
	}
	cmd := &parse.CommandNode{
		NodeType: parse.NodeCommand,
		Args:     []parse.Node{field},
	}
	pipe := &parse.PipeNode{
		NodeType: parse.NodePipe,
		Cmds:     []*parse.CommandNode{cmd},
	}
	return parse.ActionNode{NodeType: parse.NodeAction, Pipe: pipe}
}()
  87. func Parse(s string) (*Template, error) {
  88. tmpl := template.New("").Option("missingkey=zero")
  89. tmpl, err := tmpl.Parse(s)
  90. if err != nil {
  91. return nil, err
  92. }
  93. t := Template{Template: tmpl, raw: s}
  94. if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
  95. // touch up the template and append {{ .Response }}
  96. tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
  97. }
  98. return &t, nil
  99. }
  100. func (t *Template) String() string {
  101. return t.raw
  102. }
  103. func (t *Template) Vars() []string {
  104. var vars []string
  105. for _, tt := range t.Templates() {
  106. for _, n := range tt.Root.Nodes {
  107. vars = append(vars, parseNode(n)...)
  108. }
  109. }
  110. set := make(map[string]struct{})
  111. for _, n := range vars {
  112. set[strings.ToLower(n)] = struct{}{}
  113. }
  114. vars = maps.Keys(set)
  115. slices.Sort(vars)
  116. return vars
  117. }
  118. type Values struct {
  119. Messages []api.Message
  120. // forceLegacy is a flag used to test compatibility with legacy templates
  121. forceLegacy bool
  122. }
  123. func (t *Template) Execute(w io.Writer, v Values) error {
  124. system, collated := collate(v.Messages)
  125. if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
  126. return t.Template.Execute(w, map[string]any{
  127. "System": system,
  128. "Messages": collated,
  129. })
  130. }
  131. var b bytes.Buffer
  132. var prompt, response string
  133. for i, m := range collated {
  134. switch m.Role {
  135. case "system":
  136. system = m.Content
  137. case "user":
  138. prompt = m.Content
  139. case "assistant":
  140. response = m.Content
  141. }
  142. if i != len(collated)-1 && prompt != "" && response != "" {
  143. if err := t.Template.Execute(&b, map[string]any{
  144. "System": system,
  145. "Prompt": prompt,
  146. "Response": response,
  147. }); err != nil {
  148. return err
  149. }
  150. system = ""
  151. prompt = ""
  152. response = ""
  153. }
  154. }
  155. var cut bool
  156. nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
  157. switch t := n.(type) {
  158. case *parse.ActionNode:
  159. case *parse.FieldNode:
  160. if slices.Contains(t.Ident, "Response") {
  161. cut = true
  162. }
  163. }
  164. return cut
  165. })
  166. tree := parse.Tree{Root: nodes.(*parse.ListNode)}
  167. if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
  168. "System": "",
  169. "Prompt": prompt,
  170. }); err != nil {
  171. return err
  172. }
  173. _, err := io.Copy(w, &b)
  174. return err
  175. }
  176. // collate messages based on role. consecutive messages of the same role are merged
  177. // into a single message. collate also collects and returns all system messages.
  178. // collate mutates message content adding image tags ([img-%d]) as needed
  179. func collate(msgs []api.Message) (string, []*api.Message) {
  180. var n int
  181. var system []string
  182. var collated []*api.Message
  183. for i := range msgs {
  184. msg := msgs[i]
  185. for range msg.Images {
  186. imageTag := fmt.Sprintf("[img-%d]", n)
  187. if !strings.Contains(msg.Content, "[img]") {
  188. msg.Content = strings.TrimSpace("[img] " + msg.Content)
  189. }
  190. msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1)
  191. n++
  192. }
  193. if msg.Role == "system" {
  194. system = append(system, msg.Content)
  195. }
  196. if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
  197. collated[len(collated)-1].Content += "\n\n" + msg.Content
  198. } else {
  199. collated = append(collated, &msg)
  200. }
  201. }
  202. return strings.Join(system, "\n\n"), collated
  203. }
  // parseNode returns the identifiers of every field node reachable from n,
// in traversal order and possibly with repeats. For example, {{ .Prompt }}
// yields "Prompt", and {{ .Messages.Content }} yields "Messages", "Content".
// It descends through actions, if/range/with branches (condition, body and
// else body), pipelines, lists and the data argument of {{ template }} calls.
// Any other node kind contributes nothing.
func parseNode(n parse.Node) []string {
	switch n := n.(type) {
	case *parse.ActionNode:
		return parseNode(n.Pipe)
	case *parse.IfNode:
		return parseBranch(&n.BranchNode)
	case *parse.RangeNode:
		return parseBranch(&n.BranchNode)
	case *parse.WithNode:
		return parseBranch(&n.BranchNode)
	case *parse.PipeNode:
		// {{ template "name" }} without a data argument has a nil pipeline,
		// which arrives here as a typed nil; dereferencing it would panic.
		if n == nil {
			return nil
		}
		var names []string
		for _, c := range n.Cmds {
			for _, a := range c.Args {
				names = append(names, parseNode(a)...)
			}
		}
		return names
	case *parse.ListNode:
		// Absent else-branches arrive as a typed-nil list.
		if n == nil {
			return nil
		}
		var names []string
		for _, child := range n.Nodes {
			names = append(names, parseNode(child)...)
		}
		return names
	case *parse.FieldNode:
		return n.Ident
	case *parse.TemplateNode:
		return parseNode(n.Pipe)
	}
	return nil
}

// parseBranch collects field identifiers from a branch node's condition
// pipeline, its body, and its optional else body, in that order.
func parseBranch(b *parse.BranchNode) []string {
	names := parseNode(b.Pipe)
	names = append(names, parseNode(b.List)...)
	return append(names, parseNode(b.ElseList)...)
}
  // deleteNode walks the tree rooted at n in document order and removes every
// node for which fn returns true. fn is called on a node before its children,
// and the children of a removed node are not visited. Nodes are modified in
// place, so callers that need the original tree should pass a copy.
//
// Removal propagates upward when a parent would otherwise be left invalid:
//   - a command left with no arguments removes its whole pipeline;
//   - an action whose pipeline is removed is itself removed;
//   - an if/range/with whose body list is removed is removed.
// An else-branch that is removed is simply dropped.
//
// It returns n (possibly modified), or nil if n itself was removed. Execute
// uses it to remove the {{ .Response }} node and everything after it.
func deleteNode(n parse.Node, fn func(parse.Node) bool) parse.Node {
	var walk func(n parse.Node) parse.Node
	walk = func(n parse.Node) parse.Node {
		if fn(n) {
			return nil
		}

		switch t := n.(type) {
		case *parse.ListNode:
			var nodes []parse.Node
			for _, c := range t.Nodes {
				if kept := walk(c); kept != nil {
					nodes = append(nodes, kept)
				}
			}
			t.Nodes = nodes
			return t
		case *parse.IfNode:
			// walk edits the embedded BranchNode in place; only removal matters.
			if _, ok := walk(&t.BranchNode).(*parse.BranchNode); !ok {
				return nil
			}
		case *parse.WithNode:
			if _, ok := walk(&t.BranchNode).(*parse.BranchNode); !ok {
				return nil
			}
		case *parse.RangeNode:
			if _, ok := walk(&t.BranchNode).(*parse.BranchNode); !ok {
				return nil
			}
		case *parse.BranchNode:
			list, ok := walk(t.List).(*parse.ListNode)
			if !ok {
				// Execution requires a non-nil body, so drop the whole branch.
				return nil
			}
			t.List = list
			if t.ElseList != nil {
				// A removed else-branch (e.g. everything after a cut point)
				// becomes nil instead of panicking on a nil type assertion.
				elseList, _ := walk(t.ElseList).(*parse.ListNode)
				t.ElseList = elseList
			}
		case *parse.ActionNode:
			pipe, ok := walk(t.Pipe).(*parse.PipeNode)
			if !ok {
				return nil
			}
			t.Pipe = pipe
		case *parse.PipeNode:
			var commands []*parse.CommandNode
			for _, c := range t.Cmds {
				var args []parse.Node
				for _, a := range c.Args {
					if kept := walk(a); kept != nil {
						args = append(args, kept)
					}
				}
				if len(args) == 0 {
					return nil
				}
				c.Args = args
				commands = append(commands, c)
			}
			if len(commands) == 0 {
				return nil
			}
			t.Cmds = commands
		}
		return n
	}
	return walk(n)
}