prompt_test.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. package server
  2. import (
  3. "strings"
  4. "testing"
  5. "github.com/ollama/ollama/api"
  6. "github.com/ollama/ollama/template"
  7. )
  8. func TestPrompt(t *testing.T) {
  9. tests := []struct {
  10. name string
  11. template string
  12. system string
  13. prompt string
  14. response string
  15. generate bool
  16. want string
  17. }{
  18. {
  19. name: "simple prompt",
  20. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  21. system: "You are a Wizard.",
  22. prompt: "What are the potion ingredients?",
  23. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
  24. },
  25. {
  26. name: "implicit response",
  27. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  28. system: "You are a Wizard.",
  29. prompt: "What are the potion ingredients?",
  30. response: "I don't know.",
  31. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know.",
  32. },
  33. {
  34. name: "response",
  35. template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
  36. system: "You are a Wizard.",
  37. prompt: "What are the potion ingredients?",
  38. response: "I don't know.",
  39. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
  40. },
  41. {
  42. name: "cut",
  43. template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
  44. system: "You are a Wizard.",
  45. prompt: "What are the potion ingredients?",
  46. response: "I don't know.",
  47. generate: true,
  48. want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.",
  49. },
  50. {
  51. name: "nocut",
  52. template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
  53. system: "You are a Wizard.",
  54. prompt: "What are the potion ingredients?",
  55. response: "I don't know.",
  56. want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.</assistant>",
  57. },
  58. }
  59. for _, tc := range tests {
  60. t.Run(tc.name, func(t *testing.T) {
  61. tmpl, err := template.Parse(tc.template)
  62. if err != nil {
  63. t.Fatal(err)
  64. }
  65. got, err := Prompt(tmpl, tc.system, tc.prompt, tc.response, tc.generate)
  66. if err != nil {
  67. t.Errorf("error = %v", err)
  68. }
  69. if got != tc.want {
  70. t.Errorf("got = %v, want %v", got, tc.want)
  71. }
  72. })
  73. }
  74. }
  75. func TestChatPrompt(t *testing.T) {
  76. tests := []struct {
  77. name string
  78. template string
  79. messages []api.Message
  80. window int
  81. want string
  82. }{
  83. {
  84. name: "simple prompt",
  85. template: "[INST] {{ .Prompt }} [/INST]",
  86. messages: []api.Message{
  87. {Role: "user", Content: "Hello"},
  88. },
  89. window: 1024,
  90. want: "[INST] Hello [/INST]",
  91. },
  92. {
  93. name: "with system message",
  94. template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
  95. messages: []api.Message{
  96. {Role: "system", Content: "You are a Wizard."},
  97. {Role: "user", Content: "Hello"},
  98. },
  99. window: 1024,
  100. want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
  101. },
  102. {
  103. name: "with response",
  104. template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}",
  105. messages: []api.Message{
  106. {Role: "system", Content: "You are a Wizard."},
  107. {Role: "user", Content: "Hello"},
  108. {Role: "assistant", Content: "I am?"},
  109. },
  110. window: 1024,
  111. want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST] I am?",
  112. },
  113. {
  114. name: "with implicit response",
  115. template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
  116. messages: []api.Message{
  117. {Role: "system", Content: "You are a Wizard."},
  118. {Role: "user", Content: "Hello"},
  119. {Role: "assistant", Content: "I am?"},
  120. },
  121. window: 1024,
  122. want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]I am?",
  123. },
  124. {
  125. name: "with conversation",
  126. template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
  127. messages: []api.Message{
  128. {Role: "system", Content: "You are a Wizard."},
  129. {Role: "user", Content: "What are the potion ingredients?"},
  130. {Role: "assistant", Content: "sugar"},
  131. {Role: "user", Content: "Anything else?"},
  132. },
  133. window: 1024,
  134. want: "[INST] <<SYS>>You are a Wizard.<</SYS>> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] ",
  135. },
  136. {
  137. name: "with truncation",
  138. template: "{{ .System }} {{ .Prompt }} {{ .Response }} ",
  139. messages: []api.Message{
  140. {Role: "system", Content: "You are a Wizard."},
  141. {Role: "user", Content: "Hello"},
  142. {Role: "assistant", Content: "I am?"},
  143. {Role: "user", Content: "Why is the sky blue?"},
  144. {Role: "assistant", Content: "The sky is blue from rayleigh scattering"},
  145. },
  146. window: 10,
  147. want: "You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering",
  148. },
  149. {
  150. name: "images",
  151. template: "{{ .System }} {{ .Prompt }}",
  152. messages: []api.Message{
  153. {Role: "system", Content: "You are a Wizard."},
  154. {Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
  155. },
  156. window: 1024,
  157. want: "You are a Wizard. [img-0] Hello",
  158. },
  159. {
  160. name: "images truncated",
  161. template: "{{ .System }} {{ .Prompt }}",
  162. messages: []api.Message{
  163. {Role: "system", Content: "You are a Wizard."},
  164. {Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
  165. },
  166. window: 1024,
  167. want: "You are a Wizard. [img-0] [img-1] Hello",
  168. },
  169. {
  170. name: "empty list",
  171. template: "{{ .System }} {{ .Prompt }}",
  172. messages: []api.Message{},
  173. window: 1024,
  174. want: "",
  175. },
  176. {
  177. name: "empty prompt",
  178. template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
  179. messages: []api.Message{
  180. {Role: "user", Content: ""},
  181. },
  182. window: 1024,
  183. want: "",
  184. },
  185. }
  186. encode := func(s string) ([]int, error) {
  187. words := strings.Fields(s)
  188. return make([]int, len(words)), nil
  189. }
  190. for _, tc := range tests {
  191. t.Run(tc.name, func(t *testing.T) {
  192. tmpl, err := template.Parse(tc.template)
  193. if err != nil {
  194. t.Fatal(err)
  195. }
  196. got, err := ChatPrompt(tmpl, tc.messages, tc.window, encode)
  197. if err != nil {
  198. t.Errorf("error = %v", err)
  199. }
  200. if got != tc.want {
  201. t.Errorf("got: %q, want: %q", got, tc.want)
  202. }
  203. })
  204. }
  205. }