template_test.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. package template
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding/json"
  6. "io"
  7. "os"
  8. "path/filepath"
  9. "slices"
  10. "strings"
  11. "testing"
  12. "github.com/google/go-cmp/cmp"
  13. "github.com/ollama/ollama/api"
  14. "github.com/ollama/ollama/llm"
  15. )
  16. func TestNamed(t *testing.T) {
  17. f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
  18. if err != nil {
  19. t.Fatal(err)
  20. }
  21. defer f.Close()
  22. scanner := bufio.NewScanner(f)
  23. for scanner.Scan() {
  24. var ss map[string]string
  25. if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil {
  26. t.Fatal(err)
  27. }
  28. for k, v := range ss {
  29. t.Run(k, func(t *testing.T) {
  30. kv := llm.KV{"tokenizer.chat_template": v}
  31. s := kv.ChatTemplate()
  32. r, err := Named(s)
  33. if err != nil {
  34. t.Fatal(err)
  35. }
  36. if r.Name != k {
  37. t.Errorf("expected %q, got %q", k, r.Name)
  38. }
  39. var b bytes.Buffer
  40. if _, err := io.Copy(&b, r.Reader()); err != nil {
  41. t.Fatal(err)
  42. }
  43. tmpl, err := Parse(b.String())
  44. if err != nil {
  45. t.Fatal(err)
  46. }
  47. if tmpl.Tree.Root.String() == "" {
  48. t.Errorf("empty %s template", k)
  49. }
  50. })
  51. }
  52. }
  53. }
  54. func TestTemplate(t *testing.T) {
  55. cases := make(map[string][]api.Message)
  56. for _, mm := range [][]api.Message{
  57. {
  58. {Role: "user", Content: "Hello, how are you?"},
  59. },
  60. {
  61. {Role: "user", Content: "Hello, how are you?"},
  62. {Role: "assistant", Content: "I'm doing great. How can I help you today?"},
  63. {Role: "user", Content: "I'd like to show off how chat templating works!"},
  64. },
  65. {
  66. {Role: "system", Content: "You are a helpful assistant."},
  67. {Role: "user", Content: "Hello, how are you?"},
  68. {Role: "assistant", Content: "I'm doing great. How can I help you today?"},
  69. {Role: "user", Content: "I'd like to show off how chat templating works!"},
  70. },
  71. } {
  72. var roles []string
  73. for _, m := range mm {
  74. roles = append(roles, m.Role)
  75. }
  76. cases[strings.Join(roles, "-")] = mm
  77. }
  78. matches, err := filepath.Glob("*.gotmpl")
  79. if err != nil {
  80. t.Fatal(err)
  81. }
  82. for _, match := range matches {
  83. t.Run(match, func(t *testing.T) {
  84. bts, err := os.ReadFile(match)
  85. if err != nil {
  86. t.Fatal(err)
  87. }
  88. tmpl, err := Parse(string(bts))
  89. if err != nil {
  90. t.Fatal(err)
  91. }
  92. for n, tt := range cases {
  93. t.Run(n, func(t *testing.T) {
  94. var actual bytes.Buffer
  95. if err := tmpl.Execute(&actual, Values{Messages: tt}); err != nil {
  96. t.Fatal(err)
  97. }
  98. expect, err := os.ReadFile(filepath.Join("testdata", match, n))
  99. if err != nil {
  100. t.Fatal(err)
  101. }
  102. if diff := cmp.Diff(actual.Bytes(), expect); diff != "" {
  103. t.Errorf("mismatch (-got +want):\n%s", diff)
  104. }
  105. })
  106. }
  107. })
  108. }
  109. }
  110. func TestParse(t *testing.T) {
  111. cases := []struct {
  112. template string
  113. vars []string
  114. }{
  115. {"{{ .Prompt }}", []string{"prompt", "response"}},
  116. {"{{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system"}},
  117. {"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
  118. {"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
  119. {"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
  120. {"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}},
  121. }
  122. for _, tt := range cases {
  123. t.Run("", func(t *testing.T) {
  124. tmpl, err := Parse(tt.template)
  125. if err != nil {
  126. t.Fatal(err)
  127. }
  128. vars := tmpl.Vars()
  129. if !slices.Equal(tt.vars, vars) {
  130. t.Errorf("expected %v, got %v", tt.vars, vars)
  131. }
  132. })
  133. }
  134. }
  135. func TestExecuteWithMessages(t *testing.T) {
  136. type template struct {
  137. name string
  138. template string
  139. }
  140. cases := []struct {
  141. name string
  142. templates []template
  143. values Values
  144. expected string
  145. }{
  146. {
  147. "mistral",
  148. []template{
  149. {"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
  150. {"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
  151. {"messages", `{{- range $index, $_ := .Messages }}
  152. {{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
  153. {{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
  154. {{- end }}
  155. {{- end }}`},
  156. },
  157. Values{
  158. Messages: []api.Message{
  159. {Role: "user", Content: "Hello friend!"},
  160. {Role: "assistant", Content: "Hello human!"},
  161. {Role: "user", Content: "What is your name?"},
  162. },
  163. },
  164. `[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
  165. },
  166. {
  167. "mistral system",
  168. []template{
  169. {"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
  170. {"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
  171. {"messages", `
  172. {{- range $index, $_ := .Messages }}
  173. {{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
  174. {{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
  175. {{- end }}
  176. {{- end }}`},
  177. },
  178. Values{
  179. Messages: []api.Message{
  180. {Role: "system", Content: "You are a helpful assistant!"},
  181. {Role: "user", Content: "Hello friend!"},
  182. {Role: "assistant", Content: "Hello human!"},
  183. {Role: "user", Content: "What is your name?"},
  184. },
  185. },
  186. `[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant!
  187. What is your name?[/INST] `,
  188. },
  189. {
  190. "chatml",
  191. []template{
  192. // this does not have a "no response" test because it's impossible to render the same output
  193. {"response", `{{ if .System }}<|im_start|>system
  194. {{ .System }}<|im_end|>
  195. {{ end }}{{ if .Prompt }}<|im_start|>user
  196. {{ .Prompt }}<|im_end|>
  197. {{ end }}<|im_start|>assistant
  198. {{ .Response }}<|im_end|>
  199. `},
  200. {"messages", `
  201. {{- range $index, $_ := .Messages }}
  202. {{- if and (eq .Role "user") (eq (len (slice $.Messages $index)) 1) $.System }}<|im_start|>system
  203. {{ $.System }}<|im_end|>{{ "\n" }}
  204. {{- end }}<|im_start|>{{ .Role }}
  205. {{ .Content }}<|im_end|>{{ "\n" }}
  206. {{- end }}<|im_start|>assistant
  207. `},
  208. },
  209. Values{
  210. Messages: []api.Message{
  211. {Role: "system", Content: "You are a helpful assistant!"},
  212. {Role: "user", Content: "Hello friend!"},
  213. {Role: "assistant", Content: "Hello human!"},
  214. {Role: "user", Content: "What is your name?"},
  215. },
  216. },
  217. `<|im_start|>user
  218. Hello friend!<|im_end|>
  219. <|im_start|>assistant
  220. Hello human!<|im_end|>
  221. <|im_start|>system
  222. You are a helpful assistant!<|im_end|>
  223. <|im_start|>user
  224. What is your name?<|im_end|>
  225. <|im_start|>assistant
  226. `,
  227. },
  228. {
  229. "moondream",
  230. []template{
  231. // this does not have a "no response" test because it's impossible to render the same output
  232. {"response", `{{ if .Prompt }}Question: {{ .Prompt }}
  233. {{ end }}Answer: {{ .Response }}
  234. `},
  235. {"messages", `
  236. {{- range .Messages }}
  237. {{- if eq .Role "user" }}Question: {{ .Content }}{{ "\n\n" }}
  238. {{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ "\n\n" }}
  239. {{- end }}
  240. {{- end }}Answer: `},
  241. },
  242. Values{
  243. Messages: []api.Message{
  244. {Role: "user", Content: "What's in this image?", Images: []api.ImageData{[]byte("")}},
  245. {Role: "assistant", Content: "It's a hot dog."},
  246. {Role: "user", Content: "What's in _this_ image?"},
  247. {Role: "user", Images: []api.ImageData{[]byte("")}},
  248. {Role: "user", Content: "Is it a hot dog?"},
  249. },
  250. },
  251. `Question: [img-0] What's in this image?
  252. Answer: It's a hot dog.
  253. Question: What's in _this_ image?
  254. [img-1]
  255. Is it a hot dog?
  256. Answer: `,
  257. },
  258. }
  259. for _, tt := range cases {
  260. t.Run(tt.name, func(t *testing.T) {
  261. for _, ttt := range tt.templates {
  262. t.Run(ttt.name, func(t *testing.T) {
  263. tmpl, err := Parse(ttt.template)
  264. if err != nil {
  265. t.Fatal(err)
  266. }
  267. var b bytes.Buffer
  268. if err := tmpl.Execute(&b, tt.values); err != nil {
  269. t.Fatal(err)
  270. }
  271. if b.String() != tt.expected {
  272. t.Errorf("expected\n%s,\ngot\n%s", tt.expected, b.String())
  273. }
  274. })
  275. }
  276. })
  277. }
  278. }