template_test.go 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. package template
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding/json"
  6. "io"
  7. "os"
  8. "path/filepath"
  9. "slices"
  10. "strings"
  11. "testing"
  12. "github.com/google/go-cmp/cmp"
  13. "github.com/ollama/ollama/api"
  14. "github.com/ollama/ollama/llm"
  15. )
  16. func TestNamed(t *testing.T) {
  17. f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
  18. if err != nil {
  19. t.Fatal(err)
  20. }
  21. defer f.Close()
  22. scanner := bufio.NewScanner(f)
  23. for scanner.Scan() {
  24. var ss map[string]string
  25. if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil {
  26. t.Fatal(err)
  27. }
  28. for k, v := range ss {
  29. t.Run(k, func(t *testing.T) {
  30. kv := llm.KV{"tokenizer.chat_template": v}
  31. s := kv.ChatTemplate()
  32. r, err := Named(s)
  33. if err != nil {
  34. t.Fatal(err)
  35. }
  36. if r.Name != k {
  37. t.Errorf("expected %q, got %q", k, r.Name)
  38. }
  39. var b bytes.Buffer
  40. if _, err := io.Copy(&b, r.Reader()); err != nil {
  41. t.Fatal(err)
  42. }
  43. tmpl, err := Parse(b.String())
  44. if err != nil {
  45. t.Fatal(err)
  46. }
  47. if tmpl.Tree.Root.String() == "" {
  48. t.Errorf("empty %s template", k)
  49. }
  50. })
  51. }
  52. }
  53. }
  54. func TestTemplate(t *testing.T) {
  55. cases := make(map[string][]api.Message)
  56. for _, mm := range [][]api.Message{
  57. {
  58. {Role: "user", Content: "Hello, how are you?"},
  59. },
  60. {
  61. {Role: "user", Content: "Hello, how are you?"},
  62. {Role: "assistant", Content: "I'm doing great. How can I help you today?"},
  63. {Role: "user", Content: "I'd like to show off how chat templating works!"},
  64. },
  65. {
  66. {Role: "system", Content: "You are a helpful assistant."},
  67. {Role: "user", Content: "Hello, how are you?"},
  68. {Role: "assistant", Content: "I'm doing great. How can I help you today?"},
  69. {Role: "user", Content: "I'd like to show off how chat templating works!"},
  70. },
  71. } {
  72. var roles []string
  73. for _, m := range mm {
  74. roles = append(roles, m.Role)
  75. }
  76. cases[strings.Join(roles, "-")] = mm
  77. }
  78. matches, err := filepath.Glob("*.gotmpl")
  79. if err != nil {
  80. t.Fatal(err)
  81. }
  82. for _, match := range matches {
  83. t.Run(match, func(t *testing.T) {
  84. bts, err := os.ReadFile(match)
  85. if err != nil {
  86. t.Fatal(err)
  87. }
  88. tmpl, err := Parse(string(bts))
  89. if err != nil {
  90. t.Fatal(err)
  91. }
  92. for n, tt := range cases {
  93. var actual bytes.Buffer
  94. t.Run(n, func(t *testing.T) {
  95. if err := tmpl.Execute(&actual, Values{Messages: tt}); err != nil {
  96. t.Fatal(err)
  97. }
  98. expect, err := os.ReadFile(filepath.Join("testdata", match, n))
  99. if err != nil {
  100. t.Fatal(err)
  101. }
  102. bts := actual.Bytes()
  103. if slices.Contains([]string{"chatqa.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && bts[len(bts)-1] == ' ' {
  104. t.Log("removing trailing space from output")
  105. bts = bts[:len(bts)-1]
  106. }
  107. if diff := cmp.Diff(bts, expect); diff != "" {
  108. t.Errorf("mismatch (-got +want):\n%s", diff)
  109. }
  110. })
  111. t.Run("legacy", func(t *testing.T) {
  112. t.Skip("legacy outputs are currently default outputs")
  113. var legacy bytes.Buffer
  114. if err := tmpl.Execute(&legacy, Values{Messages: tt, forceLegacy: true}); err != nil {
  115. t.Fatal(err)
  116. }
  117. legacyBytes := legacy.Bytes()
  118. if slices.Contains([]string{"chatqa.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && legacyBytes[len(legacyBytes)-1] == ' ' {
  119. t.Log("removing trailing space from legacy output")
  120. legacyBytes = legacyBytes[:len(legacyBytes)-1]
  121. } else if slices.Contains([]string{"codellama-70b-instruct.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl"}, match) {
  122. t.Skip("legacy outputs cannot be compared to messages outputs")
  123. }
  124. if diff := cmp.Diff(legacyBytes, actual.Bytes()); diff != "" {
  125. t.Errorf("mismatch (-got +want):\n%s", diff)
  126. }
  127. })
  128. }
  129. })
  130. }
  131. }
  132. func TestParse(t *testing.T) {
  133. cases := []struct {
  134. template string
  135. vars []string
  136. }{
  137. {"{{ .Prompt }}", []string{"prompt", "response"}},
  138. {"{{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system"}},
  139. {"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
  140. {"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
  141. {"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
  142. {`{{- range .Messages }}
  143. {{- if eq .Role "system" }}SYSTEM:
  144. {{- else if eq .Role "user" }}USER:
  145. {{- else if eq .Role "assistant" }}ASSISTANT:
  146. {{- end }} {{ .Content }}
  147. {{- end }}`, []string{"content", "messages", "role"}},
  148. {`{{- if .Messages }}
  149. {{- range .Messages }}<|im_start|>{{ .Role }}
  150. {{ .Content }}<|im_end|>
  151. {{ end }}<|im_start|>assistant
  152. {{ else -}}
  153. {{ if .System }}<|im_start|>system
  154. {{ .System }}<|im_end|>
  155. {{ end }}{{ if .Prompt }}<|im_start|>user
  156. {{ .Prompt }}<|im_end|>
  157. {{ end }}<|im_start|>assistant
  158. {{ .Response }}<|im_end|>
  159. {{- end -}}`, []string{"content", "messages", "prompt", "response", "role", "system"}},
  160. }
  161. for _, tt := range cases {
  162. t.Run("", func(t *testing.T) {
  163. tmpl, err := Parse(tt.template)
  164. if err != nil {
  165. t.Fatal(err)
  166. }
  167. if diff := cmp.Diff(tmpl.Vars(), tt.vars); diff != "" {
  168. t.Errorf("mismatch (-got +want):\n%s", diff)
  169. }
  170. })
  171. }
  172. }
  173. func TestExecuteWithMessages(t *testing.T) {
  174. type template struct {
  175. name string
  176. template string
  177. }
  178. cases := []struct {
  179. name string
  180. templates []template
  181. values Values
  182. expected string
  183. }{
  184. {
  185. "mistral",
  186. []template{
  187. {"no response", `[INST] {{ if .System }}{{ .System }}
  188. {{ end }}{{ .Prompt }}[/INST] `},
  189. {"response", `[INST] {{ if .System }}{{ .System }}
  190. {{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
  191. {"messages", `[INST] {{ if .System }}{{ .System }}
  192. {{ end }}
  193. {{- range .Messages }}
  194. {{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
  195. {{- end }}`},
  196. },
  197. Values{
  198. Messages: []api.Message{
  199. {Role: "user", Content: "Hello friend!"},
  200. {Role: "assistant", Content: "Hello human!"},
  201. {Role: "user", Content: "What is your name?"},
  202. },
  203. },
  204. `[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
  205. },
  206. {
  207. "mistral system",
  208. []template{
  209. {"no response", `[INST] {{ if .System }}{{ .System }}
  210. {{ end }}{{ .Prompt }}[/INST] `},
  211. {"response", `[INST] {{ if .System }}{{ .System }}
  212. {{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
  213. {"messages", `[INST] {{ if .System }}{{ .System }}
  214. {{ end }}
  215. {{- range .Messages }}
  216. {{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
  217. {{- end }}`},
  218. },
  219. Values{
  220. Messages: []api.Message{
  221. {Role: "system", Content: "You are a helpful assistant!"},
  222. {Role: "user", Content: "Hello friend!"},
  223. {Role: "assistant", Content: "Hello human!"},
  224. {Role: "user", Content: "What is your name?"},
  225. },
  226. },
  227. `[INST] You are a helpful assistant!
  228. Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
  229. },
  230. {
  231. "mistral assistant",
  232. []template{
  233. {"no response", `[INST] {{ .Prompt }}[/INST] `},
  234. {"response", `[INST] {{ .Prompt }}[/INST] {{ .Response }}`},
  235. {"messages", `
  236. {{- range $i, $m := .Messages }}
  237. {{- if eq .Role "user" }}[INST] {{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}{{ end }}
  238. {{- end }}`},
  239. },
  240. Values{
  241. Messages: []api.Message{
  242. {Role: "user", Content: "Hello friend!"},
  243. {Role: "assistant", Content: "Hello human!"},
  244. {Role: "user", Content: "What is your name?"},
  245. {Role: "assistant", Content: "My name is Ollama and I"},
  246. },
  247. },
  248. `[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] My name is Ollama and I`,
  249. },
  250. {
  251. "chatml",
  252. []template{
  253. // this does not have a "no response" test because it's impossible to render the same output
  254. {"response", `{{ if .System }}<|im_start|>system
  255. {{ .System }}<|im_end|>
  256. {{ end }}{{ if .Prompt }}<|im_start|>user
  257. {{ .Prompt }}<|im_end|>
  258. {{ end }}<|im_start|>assistant
  259. {{ .Response }}<|im_end|>
  260. `},
  261. {"messages", `
  262. {{- range $index, $_ := .Messages }}<|im_start|>{{ .Role }}
  263. {{ .Content }}<|im_end|>
  264. {{ end }}<|im_start|>assistant
  265. `},
  266. },
  267. Values{
  268. Messages: []api.Message{
  269. {Role: "system", Content: "You are a helpful assistant!"},
  270. {Role: "user", Content: "Hello friend!"},
  271. {Role: "assistant", Content: "Hello human!"},
  272. {Role: "user", Content: "What is your name?"},
  273. },
  274. },
  275. `<|im_start|>system
  276. You are a helpful assistant!<|im_end|>
  277. <|im_start|>user
  278. Hello friend!<|im_end|>
  279. <|im_start|>assistant
  280. Hello human!<|im_end|>
  281. <|im_start|>user
  282. What is your name?<|im_end|>
  283. <|im_start|>assistant
  284. `,
  285. },
  286. }
  287. for _, tt := range cases {
  288. t.Run(tt.name, func(t *testing.T) {
  289. for _, ttt := range tt.templates {
  290. t.Run(ttt.name, func(t *testing.T) {
  291. tmpl, err := Parse(ttt.template)
  292. if err != nil {
  293. t.Fatal(err)
  294. }
  295. var b bytes.Buffer
  296. if err := tmpl.Execute(&b, tt.values); err != nil {
  297. t.Fatal(err)
  298. }
  299. if diff := cmp.Diff(b.String(), tt.expected); diff != "" {
  300. t.Errorf("mismatch (-got +want):\n%s", diff)
  301. }
  302. })
  303. }
  304. })
  305. }
  306. }
  307. func TestExecuteWithSuffix(t *testing.T) {
  308. tmpl, err := Parse(`{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
  309. {{- else }}{{ .Prompt }}
  310. {{- end }}`)
  311. if err != nil {
  312. t.Fatal(err)
  313. }
  314. cases := []struct {
  315. name string
  316. values Values
  317. expect string
  318. }{
  319. {
  320. "message", Values{Messages: []api.Message{{Role: "user", Content: "hello"}}}, "hello",
  321. },
  322. {
  323. "prompt suffix", Values{Prompt: "def add(", Suffix: "return x"}, "<PRE> def add( <SUF>return x <MID>",
  324. },
  325. }
  326. for _, tt := range cases {
  327. t.Run(tt.name, func(t *testing.T) {
  328. var b bytes.Buffer
  329. if err := tmpl.Execute(&b, tt.values); err != nil {
  330. t.Fatal(err)
  331. }
  332. if diff := cmp.Diff(b.String(), tt.expect); diff != "" {
  333. t.Errorf("mismatch (-got +want):\n%s", diff)
  334. }
  335. })
  336. }
  337. }