template_test.go 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. package template
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding/json"
  6. "io"
  7. "os"
  8. "path/filepath"
  9. "slices"
  10. "strings"
  11. "testing"
  12. "github.com/google/go-cmp/cmp"
  13. "github.com/ollama/ollama/api"
  14. "github.com/ollama/ollama/llm"
  15. )
  16. func TestNamed(t *testing.T) {
  17. f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
  18. if err != nil {
  19. t.Fatal(err)
  20. }
  21. defer f.Close()
  22. scanner := bufio.NewScanner(f)
  23. for scanner.Scan() {
  24. var ss map[string]string
  25. if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil {
  26. t.Fatal(err)
  27. }
  28. for k, v := range ss {
  29. t.Run(k, func(t *testing.T) {
  30. kv := llm.KV{"tokenizer.chat_template": v}
  31. s := kv.ChatTemplate()
  32. r, err := Named(s)
  33. if err != nil {
  34. t.Fatal(err)
  35. }
  36. if r.Name != k {
  37. t.Errorf("expected %q, got %q", k, r.Name)
  38. }
  39. var b bytes.Buffer
  40. if _, err := io.Copy(&b, r.Reader()); err != nil {
  41. t.Fatal(err)
  42. }
  43. tmpl, err := Parse(b.String())
  44. if err != nil {
  45. t.Fatal(err)
  46. }
  47. if tmpl.Tree.Root.String() == "" {
  48. t.Errorf("empty %s template", k)
  49. }
  50. })
  51. }
  52. }
  53. }
  54. func TestTemplate(t *testing.T) {
  55. cases := make(map[string][]api.Message)
  56. for _, mm := range [][]api.Message{
  57. {
  58. {Role: "user", Content: "Hello, how are you?"},
  59. },
  60. {
  61. {Role: "user", Content: "Hello, how are you?"},
  62. {Role: "assistant", Content: "I'm doing great. How can I help you today?"},
  63. {Role: "user", Content: "I'd like to show off how chat templating works!"},
  64. },
  65. {
  66. {Role: "system", Content: "You are a helpful assistant."},
  67. {Role: "user", Content: "Hello, how are you?"},
  68. {Role: "assistant", Content: "I'm doing great. How can I help you today?"},
  69. {Role: "user", Content: "I'd like to show off how chat templating works!"},
  70. },
  71. } {
  72. var roles []string
  73. for _, m := range mm {
  74. roles = append(roles, m.Role)
  75. }
  76. cases[strings.Join(roles, "-")] = mm
  77. }
  78. matches, err := filepath.Glob("*.gotmpl")
  79. if err != nil {
  80. t.Fatal(err)
  81. }
  82. for _, match := range matches {
  83. t.Run(match, func(t *testing.T) {
  84. bts, err := os.ReadFile(match)
  85. if err != nil {
  86. t.Fatal(err)
  87. }
  88. tmpl, err := Parse(string(bts))
  89. if err != nil {
  90. t.Fatal(err)
  91. }
  92. for n, tt := range cases {
  93. var actual bytes.Buffer
  94. t.Run(n, func(t *testing.T) {
  95. if err := tmpl.Execute(&actual, Values{Messages: tt}); err != nil {
  96. t.Fatal(err)
  97. }
  98. expect, err := os.ReadFile(filepath.Join("testdata", match, n))
  99. if err != nil {
  100. t.Fatal(err)
  101. }
  102. if diff := cmp.Diff(actual.Bytes(), expect); diff != "" {
  103. t.Errorf("mismatch (-got +want):\n%s", diff)
  104. }
  105. })
  106. t.Run("legacy", func(t *testing.T) {
  107. var legacy bytes.Buffer
  108. if err := tmpl.Execute(&legacy, Values{Messages: tt, forceLegacy: true}); err != nil {
  109. t.Fatal(err)
  110. }
  111. legacyBytes := legacy.Bytes()
  112. if slices.Contains([]string{"chatqa.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && legacyBytes[len(legacyBytes)-1] == ' ' {
  113. t.Log("removing trailing space from legacy output")
  114. legacyBytes = legacyBytes[:len(legacyBytes)-1]
  115. } else if slices.Contains([]string{"codellama-70b-instruct.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl"}, match) {
  116. t.Skip("legacy outputs cannot be compared to messages outputs")
  117. }
  118. if diff := cmp.Diff(legacyBytes, actual.Bytes()); diff != "" {
  119. t.Errorf("mismatch (-got +want):\n%s", diff)
  120. }
  121. })
  122. }
  123. })
  124. }
  125. }
  126. func TestParse(t *testing.T) {
  127. cases := []struct {
  128. template string
  129. vars []string
  130. }{
  131. {"{{ .Prompt }}", []string{"prompt", "response"}},
  132. {"{{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system"}},
  133. {"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
  134. {"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
  135. {"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
  136. {"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}},
  137. {`{{- if .Messages }}
  138. {{- if .System }}<|im_start|>system
  139. {{ .System }}<|im_end|>
  140. {{ end }}
  141. {{- range .Messages }}<|im_start|>{{ .Role }}
  142. {{ .Content }}<|im_end|>
  143. {{ end }}<|im_start|>assistant
  144. {{ else -}}
  145. {{ if .System }}<|im_start|>system
  146. {{ .System }}<|im_end|>
  147. {{ end }}{{ if .Prompt }}<|im_start|>user
  148. {{ .Prompt }}<|im_end|>
  149. {{ end }}<|im_start|>assistant
  150. {{ .Response }}<|im_end|>
  151. {{- end -}}`, []string{"content", "messages", "prompt", "response", "role", "system"}},
  152. }
  153. for _, tt := range cases {
  154. t.Run("", func(t *testing.T) {
  155. tmpl, err := Parse(tt.template)
  156. if err != nil {
  157. t.Fatal(err)
  158. }
  159. if diff := cmp.Diff(tmpl.Vars(), tt.vars); diff != "" {
  160. t.Errorf("mismatch (-got +want):\n%s", diff)
  161. }
  162. })
  163. }
  164. }
  165. func TestExecuteWithMessages(t *testing.T) {
  166. type template struct {
  167. name string
  168. template string
  169. }
  170. cases := []struct {
  171. name string
  172. templates []template
  173. values Values
  174. expected string
  175. }{
  176. {
  177. "mistral",
  178. []template{
  179. {"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
  180. {"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
  181. {"messages", `{{- range $index, $_ := .Messages }}
  182. {{- if eq .Role "user" }}[INST] {{ if and (eq $index 0) $.System }}{{ $.System }}{{ "\n\n" }}
  183. {{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
  184. {{- end }}
  185. {{- end }}`},
  186. },
  187. Values{
  188. Messages: []api.Message{
  189. {Role: "user", Content: "Hello friend!"},
  190. {Role: "assistant", Content: "Hello human!"},
  191. {Role: "user", Content: "What is your name?"},
  192. },
  193. },
  194. `[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
  195. },
  196. {
  197. "mistral system",
  198. []template{
  199. {"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
  200. {"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
  201. {"messages", `
  202. {{- range $index, $_ := .Messages }}
  203. {{- if eq .Role "user" }}[INST] {{ if and (eq $index 0) $.System }}{{ $.System }}{{ "\n\n" }}
  204. {{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
  205. {{- end }}
  206. {{- end }}`},
  207. },
  208. Values{
  209. Messages: []api.Message{
  210. {Role: "system", Content: "You are a helpful assistant!"},
  211. {Role: "user", Content: "Hello friend!"},
  212. {Role: "assistant", Content: "Hello human!"},
  213. {Role: "user", Content: "What is your name?"},
  214. },
  215. },
  216. `[INST] You are a helpful assistant!
  217. Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
  218. },
  219. {
  220. "chatml",
  221. []template{
  222. // this does not have a "no response" test because it's impossible to render the same output
  223. {"response", `{{ if .System }}<|im_start|>system
  224. {{ .System }}<|im_end|>
  225. {{ end }}{{ if .Prompt }}<|im_start|>user
  226. {{ .Prompt }}<|im_end|>
  227. {{ end }}<|im_start|>assistant
  228. {{ .Response }}<|im_end|>
  229. `},
  230. {"messages", `
  231. {{- range $index, $_ := .Messages }}
  232. {{- if and (eq .Role "user") (eq $index 0) $.System }}<|im_start|>system
  233. {{ $.System }}<|im_end|>{{ "\n" }}
  234. {{- end }}<|im_start|>{{ .Role }}
  235. {{ .Content }}<|im_end|>{{ "\n" }}
  236. {{- end }}<|im_start|>assistant
  237. `},
  238. },
  239. Values{
  240. Messages: []api.Message{
  241. {Role: "system", Content: "You are a helpful assistant!"},
  242. {Role: "user", Content: "Hello friend!"},
  243. {Role: "assistant", Content: "Hello human!"},
  244. {Role: "user", Content: "What is your name?"},
  245. },
  246. },
  247. `<|im_start|>system
  248. You are a helpful assistant!<|im_end|>
  249. <|im_start|>user
  250. Hello friend!<|im_end|>
  251. <|im_start|>assistant
  252. Hello human!<|im_end|>
  253. <|im_start|>user
  254. What is your name?<|im_end|>
  255. <|im_start|>assistant
  256. `,
  257. },
  258. {
  259. "moondream",
  260. []template{
  261. // this does not have a "no response" test because it's impossible to render the same output
  262. {"response", `{{ if .Prompt }}Question: {{ .Prompt }}
  263. {{ end }}Answer: {{ .Response }}
  264. `},
  265. {"messages", `
  266. {{- range .Messages }}
  267. {{- if eq .Role "user" }}Question: {{ .Content }}{{ "\n\n" }}
  268. {{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ "\n\n" }}
  269. {{- end }}
  270. {{- end }}Answer: `},
  271. },
  272. Values{
  273. Messages: []api.Message{
  274. {Role: "user", Content: "What's in this image?", Images: []api.ImageData{[]byte("")}},
  275. {Role: "assistant", Content: "It's a hot dog."},
  276. {Role: "user", Content: "What's in _this_ image?"},
  277. {Role: "user", Images: []api.ImageData{[]byte("")}},
  278. {Role: "user", Content: "Is it a hot dog?"},
  279. },
  280. },
  281. `Question: [img-0] What's in this image?
  282. Answer: It's a hot dog.
  283. Question: What's in _this_ image?
  284. [img-1]
  285. Is it a hot dog?
  286. Answer: `,
  287. },
  288. }
  289. for _, tt := range cases {
  290. t.Run(tt.name, func(t *testing.T) {
  291. for _, ttt := range tt.templates {
  292. t.Run(ttt.name, func(t *testing.T) {
  293. tmpl, err := Parse(ttt.template)
  294. if err != nil {
  295. t.Fatal(err)
  296. }
  297. var b bytes.Buffer
  298. if err := tmpl.Execute(&b, tt.values); err != nil {
  299. t.Fatal(err)
  300. }
  301. if diff := cmp.Diff(b.String(), tt.expected); diff != "" {
  302. t.Errorf("mismatch (-got +want):\n%s", diff)
  303. }
  304. })
  305. }
  306. })
  307. }
  308. }