template_test.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. package template
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding/json"
  6. "io"
  7. "os"
  8. "path/filepath"
  9. "slices"
  10. "testing"
  11. "text/template"
  12. "github.com/ollama/ollama/api"
  13. "github.com/ollama/ollama/llm"
  14. )
  15. func TestNamed(t *testing.T) {
  16. f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
  17. if err != nil {
  18. t.Fatal(err)
  19. }
  20. defer f.Close()
  21. scanner := bufio.NewScanner(f)
  22. for scanner.Scan() {
  23. var ss map[string]string
  24. if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil {
  25. t.Fatal(err)
  26. }
  27. for k, v := range ss {
  28. t.Run(k, func(t *testing.T) {
  29. kv := llm.KV{"tokenizer.chat_template": v}
  30. s := kv.ChatTemplate()
  31. r, err := Named(s)
  32. if err != nil {
  33. t.Fatal(err)
  34. }
  35. if r.Name != k {
  36. t.Errorf("expected %q, got %q", k, r.Name)
  37. }
  38. var b bytes.Buffer
  39. if _, err := io.Copy(&b, r.Reader()); err != nil {
  40. t.Fatal(err)
  41. }
  42. tmpl, err := template.New(s).Parse(b.String())
  43. if err != nil {
  44. t.Fatal(err)
  45. }
  46. if tmpl.Tree.Root.String() == "" {
  47. t.Errorf("empty %s template", k)
  48. }
  49. })
  50. }
  51. }
  52. }
  53. func TestParse(t *testing.T) {
  54. cases := []struct {
  55. template string
  56. vars []string
  57. }{
  58. {"{{ .Prompt }}", []string{"prompt", "response"}},
  59. {"{{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system"}},
  60. {"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
  61. {"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
  62. {"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
  63. {"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}},
  64. }
  65. for _, tt := range cases {
  66. t.Run("", func(t *testing.T) {
  67. tmpl, err := Parse(tt.template)
  68. if err != nil {
  69. t.Fatal(err)
  70. }
  71. vars := tmpl.Vars()
  72. if !slices.Equal(tt.vars, vars) {
  73. t.Errorf("expected %v, got %v", tt.vars, vars)
  74. }
  75. })
  76. }
  77. }
  78. func TestExecuteWithMessages(t *testing.T) {
  79. cases := []struct {
  80. templates []string
  81. values Values
  82. expected string
  83. }{
  84. {
  85. []string{
  86. `[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `,
  87. `[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`,
  88. `{{- range .Messages }}
  89. {{- if eq .Role "user" }}[INST] {{ if and (isLastMessage $.Messages .) $.System }}{{ $.System }}{{ print "\n\n" }}
  90. {{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
  91. {{- end }}
  92. {{- end }}`,
  93. },
  94. Values{
  95. Messages: []api.Message{
  96. {Role: "user", Content: "Hello friend!"},
  97. {Role: "assistant", Content: "Hello human!"},
  98. {Role: "user", Content: "Yay!"},
  99. },
  100. },
  101. `[INST] Hello friend![/INST] Hello human![INST] Yay![/INST] `,
  102. },
  103. {
  104. []string{
  105. `[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `,
  106. `[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`,
  107. `
  108. {{- range .Messages }}
  109. {{- if eq .Role "user" }}[INST] {{ if and (isLastMessage $.Messages .) $.System }}{{ $.System }}{{ print "\n\n" }}
  110. {{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
  111. {{- end }}
  112. {{- end }}`,
  113. },
  114. Values{
  115. Messages: []api.Message{
  116. {Role: "system", Content: "You are a helpful assistant!"},
  117. {Role: "user", Content: "Hello friend!"},
  118. {Role: "assistant", Content: "Hello human!"},
  119. {Role: "user", Content: "Yay!"},
  120. },
  121. },
  122. `[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant!
  123. Yay![/INST] `,
  124. },
  125. {
  126. []string{
  127. `{{ if .System }}<|im_start|>system
  128. {{ .System }}<|im_end|>
  129. {{ end }}{{ if .Prompt }}<|im_start|>user
  130. {{ .Prompt }}<|im_end|>
  131. {{ end }}<|im_start|>assistant
  132. {{ .Response }}<|im_end|>
  133. `,
  134. `
  135. {{- range .Messages }}
  136. {{- if and (eq .Role "user") (isLastMessage $.Messages .) $.System }}<|im_start|>system
  137. {{ $.System }}<|im_end|>{{ print "\n" }}
  138. {{- end }}<|im_start|>{{ .Role }}
  139. {{ .Content }}<|im_end|>{{ print "\n" }}
  140. {{- end }}<|im_start|>assistant
  141. `,
  142. },
  143. Values{
  144. Messages: []api.Message{
  145. {Role: "system", Content: "You are a helpful assistant!"},
  146. {Role: "user", Content: "Hello friend!"},
  147. {Role: "assistant", Content: "Hello human!"},
  148. {Role: "user", Content: "Yay!"},
  149. },
  150. },
  151. `<|im_start|>user
  152. Hello friend!<|im_end|>
  153. <|im_start|>assistant
  154. Hello human!<|im_end|>
  155. <|im_start|>system
  156. You are a helpful assistant!<|im_end|>
  157. <|im_start|>user
  158. Yay!<|im_end|>
  159. <|im_start|>assistant
  160. `,
  161. },
  162. {
  163. []string{
  164. `{{ if .Prompt }}Question: {{ .Prompt }}
  165. {{ end }}Answer: {{ .Response }}
  166. `,
  167. `
  168. {{- range .Messages }}
  169. {{- if eq .Role "user" }}Question: {{ .Content }}{{ print "\n\n" }}
  170. {{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ print "\n\n" }}
  171. {{- end }}
  172. {{- end }}Answer: `,
  173. },
  174. Values{
  175. Messages: []api.Message{
  176. {Role: "user", Content: "What's in this image?", Images: []api.ImageData{[]byte("")}},
  177. {Role: "assistant", Content: "It's a hot dog."},
  178. {Role: "user", Content: "What's in _this_ image?"},
  179. {Role: "user", Images: []api.ImageData{[]byte("")}},
  180. {Role: "user", Content: "Is it a hot dog?"},
  181. },
  182. },
  183. `Question: [img-0] What's in this image?
  184. Answer: It's a hot dog.
  185. Question: What's in _this_ image?
  186. [img-1]
  187. Is it a hot dog?
  188. Answer: `,
  189. },
  190. }
  191. for _, tt := range cases {
  192. t.Run("", func(t *testing.T) {
  193. for _, tmpl := range tt.templates {
  194. t.Run("", func(t *testing.T) {
  195. tmpl, err := Parse(tmpl)
  196. if err != nil {
  197. t.Fatal(err)
  198. }
  199. var b bytes.Buffer
  200. if err := tmpl.Execute(&b, tt.values); err != nil {
  201. t.Fatal(err)
  202. }
  203. if b.String() != tt.expected {
  204. t.Errorf("expected\n%s,\ngot\n%s", tt.expected, b.String())
  205. }
  206. })
  207. }
  208. })
  209. }
  210. }