template.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. package template
  2. import (
  3. "bytes"
  4. "embed"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "math"
  10. "slices"
  11. "strings"
  12. "sync"
  13. "text/template"
  14. "text/template/parse"
  15. "github.com/agnivade/levenshtein"
  16. "github.com/ollama/ollama/api"
  17. "golang.org/x/exp/maps"
  18. )
  19. //go:embed index.json
  20. var indexBytes []byte
  21. //go:embed *.gotmpl
  22. //go:embed *.json
  23. var templatesFS embed.FS
  24. var templatesOnce = sync.OnceValues(func() ([]*named, error) {
  25. var templates []*named
  26. if err := json.Unmarshal(indexBytes, &templates); err != nil {
  27. return nil, err
  28. }
  29. for _, t := range templates {
  30. bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
  31. if err != nil {
  32. return nil, err
  33. }
  34. // normalize line endings
  35. t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
  36. params, err := templatesFS.ReadFile(t.Name + ".json")
  37. if err != nil {
  38. continue
  39. }
  40. if err := json.Unmarshal(params, &t.Parameters); err != nil {
  41. return nil, err
  42. }
  43. }
  44. return templates, nil
  45. })
  46. type named struct {
  47. Name string `json:"name"`
  48. Template string `json:"template"`
  49. Bytes []byte
  50. Parameters *struct {
  51. Stop []string `json:"stop"`
  52. }
  53. }
  54. func (t named) Reader() io.Reader {
  55. return bytes.NewReader(t.Bytes)
  56. }
  57. func Named(s string) (*named, error) {
  58. templates, err := templatesOnce()
  59. if err != nil {
  60. return nil, err
  61. }
  62. var template *named
  63. score := math.MaxInt
  64. for _, t := range templates {
  65. if s := levenshtein.ComputeDistance(s, t.Template); s < score {
  66. score = s
  67. template = t
  68. }
  69. }
  70. if score < 100 {
  71. return template, nil
  72. }
  73. return nil, errors.New("no matching template found")
  74. }
  75. var DefaultTemplate, _ = Parse("{{ .Prompt }}")
  76. type Template struct {
  77. *template.Template
  78. raw string
  79. }
  80. // response is a template node that can be added to templates that don't already have one
  81. var response = parse.ActionNode{
  82. NodeType: parse.NodeAction,
  83. Pipe: &parse.PipeNode{
  84. NodeType: parse.NodePipe,
  85. Cmds: []*parse.CommandNode{
  86. {
  87. NodeType: parse.NodeCommand,
  88. Args: []parse.Node{
  89. &parse.FieldNode{
  90. NodeType: parse.NodeField,
  91. Ident: []string{"Response"},
  92. },
  93. },
  94. },
  95. },
  96. },
  97. }
  98. func Parse(s string) (*Template, error) {
  99. tmpl := template.New("").Option("missingkey=zero")
  100. tmpl, err := tmpl.Parse(s)
  101. if err != nil {
  102. return nil, err
  103. }
  104. t := Template{Template: tmpl, raw: s}
  105. if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
  106. // touch up the template and append {{ .Response }}
  107. tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
  108. }
  109. return &t, nil
  110. }
  111. func (t *Template) String() string {
  112. return t.raw
  113. }
  114. func (t *Template) Vars() []string {
  115. var vars []string
  116. for _, tt := range t.Templates() {
  117. for _, n := range tt.Root.Nodes {
  118. vars = append(vars, parseNode(n)...)
  119. }
  120. }
  121. set := make(map[string]struct{})
  122. for _, n := range vars {
  123. set[strings.ToLower(n)] = struct{}{}
  124. }
  125. vars = maps.Keys(set)
  126. slices.Sort(vars)
  127. return vars
  128. }
  129. type Values struct {
  130. Messages []api.Message
  131. // forceLegacy is a flag used to test compatibility with legacy templates
  132. forceLegacy bool
  133. }
  134. func (t *Template) Execute(w io.Writer, v Values) error {
  135. system, collated := collate(v.Messages)
  136. if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
  137. return t.Template.Execute(w, map[string]any{
  138. "System": system,
  139. "Messages": collated,
  140. })
  141. }
  142. var b bytes.Buffer
  143. var prompt, response string
  144. for i, m := range collated {
  145. switch m.Role {
  146. case "system":
  147. system = m.Content
  148. case "user":
  149. prompt = m.Content
  150. case "assistant":
  151. response = m.Content
  152. }
  153. if i != len(collated)-1 && prompt != "" && response != "" {
  154. if err := t.Template.Execute(&b, map[string]any{
  155. "System": system,
  156. "Prompt": prompt,
  157. "Response": response,
  158. }); err != nil {
  159. return err
  160. }
  161. system = ""
  162. prompt = ""
  163. response = ""
  164. }
  165. }
  166. var cut bool
  167. nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
  168. switch t := n.(type) {
  169. case *parse.ActionNode:
  170. case *parse.FieldNode:
  171. if slices.Contains(t.Ident, "Response") {
  172. cut = true
  173. }
  174. }
  175. return cut
  176. })
  177. tree := parse.Tree{Root: nodes.(*parse.ListNode)}
  178. if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
  179. "System": "",
  180. "Prompt": prompt,
  181. }); err != nil {
  182. return err
  183. }
  184. _, err := io.Copy(w, &b)
  185. return err
  186. }
  187. // collate messages based on role. consecutive messages of the same role are merged
  188. // into a single message. collate also collects and returns all system messages.
  189. // collate mutates message content adding image tags ([img-%d]) as needed
  190. func collate(msgs []api.Message) (string, []*api.Message) {
  191. var n int
  192. var system []string
  193. var collated []*api.Message
  194. for i := range msgs {
  195. msg := msgs[i]
  196. for range msg.Images {
  197. imageTag := fmt.Sprintf("[img-%d]", n)
  198. if !strings.Contains(msg.Content, "[img]") {
  199. msg.Content = strings.TrimSpace("[img] " + msg.Content)
  200. }
  201. msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1)
  202. n++
  203. }
  204. if msg.Role == "system" {
  205. system = append(system, msg.Content)
  206. }
  207. if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
  208. collated[len(collated)-1].Content += "\n\n" + msg.Content
  209. } else {
  210. collated = append(collated, &msg)
  211. }
  212. }
  213. return strings.Join(system, "\n\n"), collated
  214. }
  215. func parseNode(n parse.Node) []string {
  216. switch n := n.(type) {
  217. case *parse.ActionNode:
  218. return parseNode(n.Pipe)
  219. case *parse.IfNode:
  220. names := parseNode(n.Pipe)
  221. names = append(names, parseNode(n.List)...)
  222. if n.ElseList != nil {
  223. names = append(names, parseNode(n.ElseList)...)
  224. }
  225. return names
  226. case *parse.RangeNode:
  227. names := parseNode(n.Pipe)
  228. names = append(names, parseNode(n.List)...)
  229. if n.ElseList != nil {
  230. names = append(names, parseNode(n.ElseList)...)
  231. }
  232. return names
  233. case *parse.WithNode:
  234. names := parseNode(n.Pipe)
  235. names = append(names, parseNode(n.List)...)
  236. if n.ElseList != nil {
  237. names = append(names, parseNode(n.ElseList)...)
  238. }
  239. return names
  240. case *parse.PipeNode:
  241. var names []string
  242. for _, c := range n.Cmds {
  243. for _, a := range c.Args {
  244. names = append(names, parseNode(a)...)
  245. }
  246. }
  247. return names
  248. case *parse.ListNode:
  249. var names []string
  250. for _, n := range n.Nodes {
  251. names = append(names, parseNode(n)...)
  252. }
  253. return names
  254. case *parse.FieldNode:
  255. return n.Ident
  256. case *parse.TemplateNode:
  257. return parseNode(n.Pipe)
  258. }
  259. return nil
  260. }
  261. // deleteNode walks the node list and deletes nodes that match the predicate
  262. // this is currently to remove the {{ .Response }} node from templates
  263. func deleteNode(n parse.Node, fn func(parse.Node) bool) parse.Node {
  264. var walk func(n parse.Node) parse.Node
  265. walk = func(n parse.Node) parse.Node {
  266. if fn(n) {
  267. return nil
  268. }
  269. switch t := n.(type) {
  270. case *parse.ListNode:
  271. var nodes []parse.Node
  272. for _, c := range t.Nodes {
  273. if n := walk(c); n != nil {
  274. nodes = append(nodes, n)
  275. }
  276. }
  277. t.Nodes = nodes
  278. return t
  279. case *parse.IfNode:
  280. t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
  281. case *parse.WithNode:
  282. t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
  283. case *parse.RangeNode:
  284. t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
  285. case *parse.BranchNode:
  286. t.List = walk(t.List).(*parse.ListNode)
  287. if t.ElseList != nil {
  288. t.ElseList = walk(t.ElseList).(*parse.ListNode)
  289. }
  290. case *parse.ActionNode:
  291. n := walk(t.Pipe)
  292. if n == nil {
  293. return nil
  294. }
  295. t.Pipe = n.(*parse.PipeNode)
  296. case *parse.PipeNode:
  297. var commands []*parse.CommandNode
  298. for _, c := range t.Cmds {
  299. var args []parse.Node
  300. for _, a := range c.Args {
  301. if n := walk(a); n != nil {
  302. args = append(args, n)
  303. }
  304. }
  305. if len(args) == 0 {
  306. return nil
  307. }
  308. c.Args = args
  309. commands = append(commands, c)
  310. }
  311. if len(commands) == 0 {
  312. return nil
  313. }
  314. t.Cmds = commands
  315. }
  316. return n
  317. }
  318. return walk(n)
  319. }