template_test.go 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. package template
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding/json"
  6. "io"
  7. "os"
  8. "path/filepath"
  9. "slices"
  10. "strings"
  11. "testing"
  12. "github.com/google/go-cmp/cmp"
  13. "github.com/ollama/ollama/api"
  14. "github.com/ollama/ollama/llm"
  15. )
  16. func TestNamed(t *testing.T) {
  17. f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
  18. if err != nil {
  19. t.Fatal(err)
  20. }
  21. defer f.Close()
  22. scanner := bufio.NewScanner(f)
  23. for scanner.Scan() {
  24. var ss map[string]string
  25. if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil {
  26. t.Fatal(err)
  27. }
  28. for k, v := range ss {
  29. t.Run(k, func(t *testing.T) {
  30. kv := llm.KV{"tokenizer.chat_template": v}
  31. s := kv.ChatTemplate()
  32. r, err := Named(s)
  33. if err != nil {
  34. t.Fatal(err)
  35. }
  36. if r.Name != k {
  37. t.Errorf("expected %q, got %q", k, r.Name)
  38. }
  39. var b bytes.Buffer
  40. if _, err := io.Copy(&b, r.Reader()); err != nil {
  41. t.Fatal(err)
  42. }
  43. tmpl, err := Parse(b.String())
  44. if err != nil {
  45. t.Fatal(err)
  46. }
  47. if tmpl.Tree.Root.String() == "" {
  48. t.Errorf("empty %s template", k)
  49. }
  50. })
  51. }
  52. }
  53. }
  54. func TestTemplate(t *testing.T) {
  55. cases := make(map[string][]api.Message)
  56. for _, mm := range [][]api.Message{
  57. {
  58. {Role: "user", Content: "Hello, how are you?"},
  59. },
  60. {
  61. {Role: "user", Content: "Hello, how are you?"},
  62. {Role: "assistant", Content: "I'm doing great. How can I help you today?"},
  63. {Role: "user", Content: "I'd like to show off how chat templating works!"},
  64. },
  65. {
  66. {Role: "system", Content: "You are a helpful assistant."},
  67. {Role: "user", Content: "Hello, how are you?"},
  68. {Role: "assistant", Content: "I'm doing great. How can I help you today?"},
  69. {Role: "user", Content: "I'd like to show off how chat templating works!"},
  70. },
  71. } {
  72. var roles []string
  73. for _, m := range mm {
  74. roles = append(roles, m.Role)
  75. }
  76. cases[strings.Join(roles, "-")] = mm
  77. }
  78. matches, err := filepath.Glob("*.gotmpl")
  79. if err != nil {
  80. t.Fatal(err)
  81. }
  82. for _, match := range matches {
  83. t.Run(match, func(t *testing.T) {
  84. bts, err := os.ReadFile(match)
  85. if err != nil {
  86. t.Fatal(err)
  87. }
  88. tmpl, err := Parse(string(bts))
  89. if err != nil {
  90. t.Fatal(err)
  91. }
  92. for n, tt := range cases {
  93. var actual bytes.Buffer
  94. t.Run(n, func(t *testing.T) {
  95. if err := tmpl.Execute(&actual, Values{Messages: tt}); err != nil {
  96. t.Fatal(err)
  97. }
  98. expect, err := os.ReadFile(filepath.Join("testdata", match, n))
  99. if err != nil {
  100. t.Fatal(err)
  101. }
  102. if diff := cmp.Diff(actual.Bytes(), expect); diff != "" {
  103. t.Errorf("mismatch (-got +want):\n%s", diff)
  104. }
  105. })
  106. t.Run("legacy", func(t *testing.T) {
  107. t.Skip("legacy outputs are currently default outputs")
  108. var legacy bytes.Buffer
  109. if err := tmpl.Execute(&legacy, Values{Messages: tt, forceLegacy: true}); err != nil {
  110. t.Fatal(err)
  111. }
  112. legacyBytes := legacy.Bytes()
  113. if slices.Contains([]string{"chatqa.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && legacyBytes[len(legacyBytes)-1] == ' ' {
  114. t.Log("removing trailing space from legacy output")
  115. legacyBytes = legacyBytes[:len(legacyBytes)-1]
  116. } else if slices.Contains([]string{"codellama-70b-instruct.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl"}, match) {
  117. t.Skip("legacy outputs cannot be compared to messages outputs")
  118. }
  119. if diff := cmp.Diff(legacyBytes, actual.Bytes()); diff != "" {
  120. t.Errorf("mismatch (-got +want):\n%s", diff)
  121. }
  122. })
  123. }
  124. })
  125. }
  126. }
  127. func TestParse(t *testing.T) {
  128. cases := []struct {
  129. template string
  130. vars []string
  131. }{
  132. {"{{ .Prompt }}", []string{"prompt", "response"}},
  133. {"{{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system"}},
  134. {"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
  135. {"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
  136. {"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
  137. {`{{- range .Messages }}
  138. {{- if eq .Role "system" }}SYSTEM:
  139. {{- else if eq .Role "user" }}USER:
  140. {{- else if eq .Role "assistant" }}ASSISTANT:
  141. {{- end }} {{ .Content }}
  142. {{- end }}`, []string{"content", "messages", "role"}},
  143. {`{{- if .Messages }}
  144. {{- range .Messages }}<|im_start|>{{ .Role }}
  145. {{ .Content }}<|im_end|>
  146. {{ end }}<|im_start|>assistant
  147. {{ else -}}
  148. {{ if .System }}<|im_start|>system
  149. {{ .System }}<|im_end|>
  150. {{ end }}{{ if .Prompt }}<|im_start|>user
  151. {{ .Prompt }}<|im_end|>
  152. {{ end }}<|im_start|>assistant
  153. {{ .Response }}<|im_end|>
  154. {{- end -}}`, []string{"content", "messages", "prompt", "response", "role", "system"}},
  155. }
  156. for _, tt := range cases {
  157. t.Run("", func(t *testing.T) {
  158. tmpl, err := Parse(tt.template)
  159. if err != nil {
  160. t.Fatal(err)
  161. }
  162. if diff := cmp.Diff(tmpl.Vars(), tt.vars); diff != "" {
  163. t.Errorf("mismatch (-got +want):\n%s", diff)
  164. }
  165. })
  166. }
  167. }
  // TestExecuteWithMessages renders, for each case, several differently
  // written templates against the same Values fixture and requires every one
  // of them to produce the same expected string.
  //
  // Each template is parsed with Parse, executed into a bytes.Buffer with
  // tmpl.Execute, and the buffer content is compared to the case's expected
  // text with cmp.Diff. Only Values.Messages is populated by the fixtures.
  //
  // NOTE(review): in this copy of the file, blank lines inside the multi-line
  // raw string literals appear to have been dropped. For example, the
  // "mistral system" templates emit "\n\n" after the system prompt, yet the
  // expected literal shows only a single line break. This is presumably an
  // extraction artifact; verify the literals against the canonical file
  // before relying on them.
  168. func TestExecuteWithMessages(t *testing.T) {
  // template pairs a subtest label with the template source to render.
  169. type template struct {
  170. name string
  171. template string
  172. }
  // Each case holds: a case name, the templates to render, the Values passed
  // to Execute, and the single output all of those templates must produce.
  173. cases := []struct {
  174. name string
  175. templates []template
  176. values Values
  177. expected string
  178. }{
  179. {
  180. "mistral",
  181. []template{
  182. {"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
  183. {"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
  184. {"messages", `{{- range $index, $_ := .Messages }}
  185. {{- if eq .Role "user" }}[INST] {{ if and (eq $index 0) $.System }}{{ $.System }}{{ "\n\n" }}
  186. {{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
  187. {{- end }}
  188. {{- end }}`},
  189. },
  190. Values{
  191. Messages: []api.Message{
  192. {Role: "user", Content: "Hello friend!"},
  193. {Role: "assistant", Content: "Hello human!"},
  194. {Role: "user", Content: "What is your name?"},
  195. },
  196. },
  197. `[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
  198. },
  199. {
  200. "mistral system",
  201. []template{
  202. {"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
  203. {"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
  204. {"messages", `
  205. {{- range $index, $_ := .Messages }}
  206. {{- if eq .Role "user" }}[INST] {{ if and (eq $index 0) $.System }}{{ $.System }}{{ "\n\n" }}
  207. {{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
  208. {{- end }}
  209. {{- end }}`},
  210. },
  211. Values{
  212. Messages: []api.Message{
  213. {Role: "system", Content: "You are a helpful assistant!"},
  214. {Role: "user", Content: "Hello friend!"},
  215. {Role: "assistant", Content: "Hello human!"},
  216. {Role: "user", Content: "What is your name?"},
  217. },
  218. },
  219. `[INST] You are a helpful assistant!
  220. Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
  221. },
  222. {
  223. "chatml",
  224. []template{
  225. // this does not have a "no response" test because it's impossible to render the same output
  226. {"response", `{{ if .System }}<|im_start|>system
  227. {{ .System }}<|im_end|>
  228. {{ end }}{{ if .Prompt }}<|im_start|>user
  229. {{ .Prompt }}<|im_end|>
  230. {{ end }}<|im_start|>assistant
  231. {{ .Response }}<|im_end|>
  232. `},
  233. {"messages", `
  234. {{- range $index, $_ := .Messages }}
  235. {{- if and (eq .Role "user") (eq $index 0) $.System }}<|im_start|>system
  236. {{ $.System }}<|im_end|>{{ "\n" }}
  237. {{- end }}<|im_start|>{{ .Role }}
  238. {{ .Content }}<|im_end|>{{ "\n" }}
  239. {{- end }}<|im_start|>assistant
  240. `},
  241. },
  242. Values{
  243. Messages: []api.Message{
  244. {Role: "system", Content: "You are a helpful assistant!"},
  245. {Role: "user", Content: "Hello friend!"},
  246. {Role: "assistant", Content: "Hello human!"},
  247. {Role: "user", Content: "What is your name?"},
  248. },
  249. },
  250. `<|im_start|>system
  251. You are a helpful assistant!<|im_end|>
  252. <|im_start|>user
  253. Hello friend!<|im_end|>
  254. <|im_start|>assistant
  255. Hello human!<|im_end|>
  256. <|im_start|>user
  257. What is your name?<|im_end|>
  258. <|im_start|>assistant
  259. `,
  260. },
  261. {
  262. "moondream",
  263. []template{
  264. // this does not have a "no response" test because it's impossible to render the same output
  265. {"response", `{{ if .Prompt }}Question: {{ .Prompt }}
  266. {{ end }}Answer: {{ .Response }}
  267. `},
  268. {"messages", `
  269. {{- range .Messages }}
  270. {{- if eq .Role "user" }}Question: {{ .Content }}{{ "\n\n" }}
  271. {{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ "\n\n" }}
  272. {{- end }}
  273. {{- end }}Answer: `},
  274. },
  275. Values{
  276. Messages: []api.Message{
  277. {Role: "user", Content: "What's in this image?", Images: []api.ImageData{[]byte("")}},
  278. {Role: "assistant", Content: "It's a hot dog."},
  279. {Role: "user", Content: "What's in _this_ image?"},
  280. {Role: "user", Images: []api.ImageData{[]byte("")}},
  281. {Role: "user", Content: "Is it a hot dog?"},
  282. },
  283. },
  284. `Question: [img-0] What's in this image?
  285. Answer: It's a hot dog.
  286. Question: What's in _this_ image?
  287. [img-1]
  288. Is it a hot dog?
  289. Answer: `,
  290. },
  291. }
  // One subtest per case, with a nested subtest per template, so a failure is
  // reported as <case name>/<template name>.
  292. for _, tt := range cases {
  293. t.Run(tt.name, func(t *testing.T) {
  294. for _, ttt := range tt.templates {
  295. t.Run(ttt.name, func(t *testing.T) {
  296. tmpl, err := Parse(ttt.template)
  297. if err != nil {
  298. t.Fatal(err)
  299. }
  300. var b bytes.Buffer
  301. if err := tmpl.Execute(&b, tt.values); err != nil {
  302. t.Fatal(err)
  303. }
  // Every template in a case is held to the same expected output.
  304. if diff := cmp.Diff(b.String(), tt.expected); diff != "" {
  305. t.Errorf("mismatch (-got +want):\n%s", diff)
  306. }
  307. })
  308. }
  309. })
  310. }
  311. }