prompt_test.go 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. package server
  2. import (
  3. "strings"
  4. "testing"
  5. "github.com/jmorganca/ollama/api"
  6. )
  7. func TestPrompt(t *testing.T) {
  8. tests := []struct {
  9. name string
  10. template string
  11. system string
  12. prompt string
  13. response string
  14. generate bool
  15. want string
  16. }{
  17. {
  18. name: "simple prompt",
  19. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  20. system: "You are a Wizard.",
  21. prompt: "What are the potion ingredients?",
  22. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
  23. },
  24. {
  25. name: "implicit response",
  26. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  27. system: "You are a Wizard.",
  28. prompt: "What are the potion ingredients?",
  29. response: "I don't know.",
  30. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know.",
  31. },
  32. {
  33. name: "response",
  34. template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
  35. system: "You are a Wizard.",
  36. prompt: "What are the potion ingredients?",
  37. response: "I don't know.",
  38. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
  39. },
  40. {
  41. name: "cut",
  42. template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
  43. system: "You are a Wizard.",
  44. prompt: "What are the potion ingredients?",
  45. response: "I don't know.",
  46. generate: true,
  47. want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.",
  48. },
  49. {
  50. name: "nocut",
  51. template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
  52. system: "You are a Wizard.",
  53. prompt: "What are the potion ingredients?",
  54. response: "I don't know.",
  55. want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.</assistant>",
  56. },
  57. }
  58. for _, tc := range tests {
  59. t.Run(tc.name, func(t *testing.T) {
  60. got, err := Prompt(tc.template, tc.system, tc.prompt, tc.response, tc.generate)
  61. if err != nil {
  62. t.Errorf("error = %v", err)
  63. }
  64. if got != tc.want {
  65. t.Errorf("got = %v, want %v", got, tc.want)
  66. }
  67. })
  68. }
  69. }
  70. func TestChatPrompt(t *testing.T) {
  71. tests := []struct {
  72. name string
  73. template string
  74. system string
  75. messages []api.Message
  76. window int
  77. want string
  78. }{
  79. {
  80. name: "simple prompt",
  81. template: "[INST] {{ .Prompt }} [/INST]",
  82. messages: []api.Message{
  83. {Role: "user", Content: "Hello"},
  84. },
  85. window: 1024,
  86. want: "[INST] Hello [/INST]",
  87. },
  88. {
  89. name: "with default system message",
  90. system: "You are a Wizard.",
  91. template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
  92. messages: []api.Message{
  93. {Role: "user", Content: "Hello"},
  94. },
  95. window: 1024,
  96. want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
  97. },
  98. {
  99. name: "with system message",
  100. template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
  101. messages: []api.Message{
  102. {Role: "system", Content: "You are a Wizard."},
  103. {Role: "user", Content: "Hello"},
  104. },
  105. window: 1024,
  106. want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
  107. },
  108. {
  109. name: "with response",
  110. template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}",
  111. messages: []api.Message{
  112. {Role: "system", Content: "You are a Wizard."},
  113. {Role: "user", Content: "Hello"},
  114. {Role: "assistant", Content: "I am?"},
  115. },
  116. window: 1024,
  117. want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST] I am?",
  118. },
  119. {
  120. name: "with implicit response",
  121. template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
  122. messages: []api.Message{
  123. {Role: "system", Content: "You are a Wizard."},
  124. {Role: "user", Content: "Hello"},
  125. {Role: "assistant", Content: "I am?"},
  126. },
  127. window: 1024,
  128. want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]I am?",
  129. },
  130. {
  131. name: "with conversation",
  132. template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
  133. messages: []api.Message{
  134. {Role: "system", Content: "You are a Wizard."},
  135. {Role: "user", Content: "What are the potion ingredients?"},
  136. {Role: "assistant", Content: "sugar"},
  137. {Role: "user", Content: "Anything else?"},
  138. },
  139. window: 1024,
  140. want: "[INST] <<SYS>>You are a Wizard.<</SYS>> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] ",
  141. },
  142. {
  143. name: "with truncation",
  144. template: "{{ .System }} {{ .Prompt }} {{ .Response }} ",
  145. messages: []api.Message{
  146. {Role: "system", Content: "You are a Wizard."},
  147. {Role: "user", Content: "Hello"},
  148. {Role: "assistant", Content: "I am?"},
  149. {Role: "user", Content: "Why is the sky blue?"},
  150. {Role: "assistant", Content: "The sky is blue from rayleigh scattering"},
  151. },
  152. window: 10,
  153. want: "You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering",
  154. },
  155. {
  156. name: "images",
  157. template: "{{ .System }} {{ .Prompt }}",
  158. messages: []api.Message{
  159. {Role: "system", Content: "You are a Wizard."},
  160. {Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
  161. },
  162. window: 1024,
  163. want: "You are a Wizard. Hello [img-0]",
  164. },
  165. {
  166. name: "images truncated",
  167. template: "{{ .System }} {{ .Prompt }}",
  168. messages: []api.Message{
  169. {Role: "system", Content: "You are a Wizard."},
  170. {Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
  171. },
  172. window: 1024,
  173. want: "You are a Wizard. Hello [img-1]",
  174. },
  175. {
  176. name: "empty list",
  177. template: "{{ .System }} {{ .Prompt }}",
  178. messages: []api.Message{},
  179. window: 1024,
  180. want: "",
  181. },
  182. {
  183. name: "empty list default system",
  184. system: "You are a Wizard.",
  185. template: "{{ .System }} {{ .Prompt }}",
  186. messages: []api.Message{},
  187. window: 1024,
  188. want: "You are a Wizard. ",
  189. },
  190. {
  191. name: "empty user message",
  192. system: "You are a Wizard.",
  193. template: "{{ .System }} {{ .Prompt }}",
  194. messages: []api.Message{
  195. {Role: "user", Content: ""},
  196. },
  197. window: 1024,
  198. want: "You are a Wizard. ",
  199. },
  200. {
  201. name: "empty prompt",
  202. template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
  203. messages: []api.Message{
  204. {Role: "user", Content: ""},
  205. },
  206. window: 1024,
  207. want: "",
  208. },
  209. }
  210. encode := func(s string) ([]int, error) {
  211. words := strings.Fields(s)
  212. return make([]int, len(words)), nil
  213. }
  214. for _, tc := range tests {
  215. t.Run(tc.name, func(t *testing.T) {
  216. got, err := ChatPrompt(tc.template, tc.system, tc.messages, tc.window, encode)
  217. if err != nil {
  218. t.Errorf("error = %v", err)
  219. }
  220. if got != tc.want {
  221. t.Errorf("got = %v, want %v", got, tc.want)
  222. }
  223. })
  224. }
  225. }