template_test.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. package template
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding/json"
  6. "io"
  7. "os"
  8. "path/filepath"
  9. "slices"
  10. "testing"
  11. "text/template"
  12. "github.com/ollama/ollama/api"
  13. "github.com/ollama/ollama/llm"
  14. )
  15. func TestNamed(t *testing.T) {
  16. f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
  17. if err != nil {
  18. t.Fatal(err)
  19. }
  20. defer f.Close()
  21. scanner := bufio.NewScanner(f)
  22. for scanner.Scan() {
  23. var ss map[string]string
  24. if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil {
  25. t.Fatal(err)
  26. }
  27. for k, v := range ss {
  28. t.Run(k, func(t *testing.T) {
  29. kv := llm.KV{"tokenizer.chat_template": v}
  30. s := kv.ChatTemplate()
  31. r, err := Named(s)
  32. if err != nil {
  33. t.Fatal(err)
  34. }
  35. if r.Name != k {
  36. t.Errorf("expected %q, got %q", k, r.Name)
  37. }
  38. var b bytes.Buffer
  39. if _, err := io.Copy(&b, r.Reader()); err != nil {
  40. t.Fatal(err)
  41. }
  42. tmpl, err := template.New(s).Parse(b.String())
  43. if err != nil {
  44. t.Fatal(err)
  45. }
  46. if tmpl.Tree.Root.String() == "" {
  47. t.Errorf("empty %s template", k)
  48. }
  49. })
  50. }
  51. }
  52. }
  53. func TestParse(t *testing.T) {
  54. cases := []struct {
  55. template string
  56. vars []string
  57. }{
  58. {"{{ .Prompt }}", []string{"prompt", "response"}},
  59. {"{{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system"}},
  60. {"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
  61. {"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
  62. {"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
  63. {"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}},
  64. }
  65. for _, tt := range cases {
  66. t.Run("", func(t *testing.T) {
  67. tmpl, err := Parse(tt.template)
  68. if err != nil {
  69. t.Fatal(err)
  70. }
  71. vars := tmpl.Vars()
  72. if !slices.Equal(tt.vars, vars) {
  73. t.Errorf("expected %v, got %v", tt.vars, vars)
  74. }
  75. })
  76. }
  77. }
// TestExecuteWithMessages renders several template variants against the same
// chat history and requires each of them to produce the same expected text.
//
// Every case supplies a list of named templates, one Values holding the
// messages, and a single expected string. Each template is parsed with Parse,
// executed into a buffer with tt.values, and the buffer is compared to
// tt.expected byte for byte, so whitespace inside the raw string literals
// below is significant.
func TestExecuteWithMessages(t *testing.T) {
	// template is one named template variant within a case; name is used as
	// the subtest name and template is the source passed to Parse.
	type template struct {
		name     string
		template string
	}
	cases := []struct {
		// name is the outer subtest name.
		name string
		// templates are the variants that must all render to expected.
		templates []template
		// values is the data passed to Execute for every variant.
		values Values
		// expected is the exact output required from every variant.
		expected string
	}{
		{
			"mistral",
			[]template{
				{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
				{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
				{"messages", `{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{- end }}`},
			},
			Values{
				Messages: []api.Message{
					{Role: "user", Content: "Hello friend!"},
					{Role: "assistant", Content: "Hello human!"},
					{Role: "user", Content: "What is your name?"},
				},
			},
			`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
		},
		{
			"mistral system",
			[]template{
				{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
				{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
				{"messages", `
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{- end }}`},
			},
			Values{
				Messages: []api.Message{
					{Role: "system", Content: "You are a helpful assistant!"},
					{Role: "user", Content: "Hello friend!"},
					{Role: "assistant", Content: "Hello human!"},
					{Role: "user", Content: "What is your name?"},
				},
			},
			// NOTE(review): all three templates above emit "\n\n" after the
			// system text, but the expected literal below shows a single line
			// break. Confirm a blank line has not been lost from this literal.
			`[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant!
What is your name?[/INST] `,
		},
		{
			"chatml",
			[]template{
				// this does not have a "no response" test because it's impossible to render the same output
				{"response", `{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
`},
				{"messages", `
{{- range $index, $_ := .Messages }}
{{- if and (eq .Role "user") (eq (len (slice $.Messages $index)) 1) $.System }}<|im_start|>system
{{ $.System }}<|im_end|>{{ "\n" }}
{{- end }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>{{ "\n" }}
{{- end }}<|im_start|>assistant
`},
			},
			Values{
				Messages: []api.Message{
					{Role: "system", Content: "You are a helpful assistant!"},
					{Role: "user", Content: "Hello friend!"},
					{Role: "assistant", Content: "Hello human!"},
					{Role: "user", Content: "What is your name?"},
				},
			},
			`<|im_start|>user
Hello friend!<|im_end|>
<|im_start|>assistant
Hello human!<|im_end|>
<|im_start|>system
You are a helpful assistant!<|im_end|>
<|im_start|>user
What is your name?<|im_end|>
<|im_start|>assistant
`,
		},
		{
			"moondream",
			[]template{
				// this does not have a "no response" test because it's impossible to render the same output
				// NOTE(review): the "messages" template ends each turn with
				// "\n\n", while the "response" template and the expected text
				// below show single line breaks. Confirm blank lines have not
				// been lost from these raw literals.
				{"response", `{{ if .Prompt }}Question: {{ .Prompt }}
{{ end }}Answer: {{ .Response }}
`},
				{"messages", `
{{- range .Messages }}
{{- if eq .Role "user" }}Question: {{ .Content }}{{ "\n\n" }}
{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ "\n\n" }}
{{- end }}
{{- end }}Answer: `},
			},
			Values{
				Messages: []api.Message{
					{Role: "user", Content: "What's in this image?", Images: []api.ImageData{[]byte("")}},
					{Role: "assistant", Content: "It's a hot dog."},
					{Role: "user", Content: "What's in _this_ image?"},
					{Role: "user", Images: []api.ImageData{[]byte("")}},
					{Role: "user", Content: "Is it a hot dog?"},
				},
			},
			`Question: [img-0] What's in this image?
Answer: It's a hot dog.
Question: What's in _this_ image?
[img-1]
Is it a hot dog?
Answer: `,
		},
	}
	// One outer subtest per case, one inner subtest per template variant.
	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			for _, ttt := range tt.templates {
				t.Run(ttt.name, func(t *testing.T) {
					tmpl, err := Parse(ttt.template)
					if err != nil {
						t.Fatal(err)
					}
					var b bytes.Buffer
					if err := tmpl.Execute(&b, tt.values); err != nil {
						t.Fatal(err)
					}
					// Exact comparison: any whitespace difference fails the variant.
					if b.String() != tt.expected {
						t.Errorf("expected\n%s,\ngot\n%s", tt.expected, b.String())
					}
				})
			}
		})
	}
}