prompt_test.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. package server
  2. import (
  3. "strings"
  4. "testing"
  5. "github.com/jmorganca/ollama/api"
  6. )
  7. func TestPrompt(t *testing.T) {
  8. tests := []struct {
  9. name string
  10. template string
  11. system string
  12. prompt string
  13. response string
  14. generate bool
  15. want string
  16. }{
  17. {
  18. name: "simple prompt",
  19. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  20. system: "You are a Wizard.",
  21. prompt: "What are the potion ingredients?",
  22. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
  23. },
  24. {
  25. name: "implicit response",
  26. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  27. system: "You are a Wizard.",
  28. prompt: "What are the potion ingredients?",
  29. response: "I don't know.",
  30. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know.",
  31. },
  32. {
  33. name: "response",
  34. template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
  35. system: "You are a Wizard.",
  36. prompt: "What are the potion ingredients?",
  37. response: "I don't know.",
  38. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
  39. },
  40. {
  41. name: "cut",
  42. template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
  43. system: "You are a Wizard.",
  44. prompt: "What are the potion ingredients?",
  45. response: "I don't know.",
  46. generate: true,
  47. want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.",
  48. },
  49. {
  50. name: "nocut",
  51. template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
  52. system: "You are a Wizard.",
  53. prompt: "What are the potion ingredients?",
  54. response: "I don't know.",
  55. want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.</assistant>",
  56. },
  57. }
  58. for _, tc := range tests {
  59. t.Run(tc.name, func(t *testing.T) {
  60. got, err := Prompt(tc.template, tc.system, tc.prompt, tc.response, tc.generate)
  61. if err != nil {
  62. t.Errorf("error = %v", err)
  63. }
  64. if got != tc.want {
  65. t.Errorf("got = %v, want %v", got, tc.want)
  66. }
  67. })
  68. }
  69. }
  70. func TestChatPrompt(t *testing.T) {
  71. tests := []struct {
  72. name string
  73. template string
  74. messages []api.Message
  75. window int
  76. want string
  77. }{
  78. {
  79. name: "simple prompt",
  80. template: "[INST] {{ .Prompt }} [/INST]",
  81. messages: []api.Message{
  82. {Role: "user", Content: "Hello"},
  83. },
  84. window: 1024,
  85. want: "[INST] Hello [/INST]",
  86. },
  87. {
  88. name: "with system message",
  89. template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
  90. messages: []api.Message{
  91. {Role: "system", Content: "You are a Wizard."},
  92. {Role: "user", Content: "Hello"},
  93. },
  94. window: 1024,
  95. want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
  96. },
  97. {
  98. name: "with response",
  99. template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}",
  100. messages: []api.Message{
  101. {Role: "system", Content: "You are a Wizard."},
  102. {Role: "user", Content: "Hello"},
  103. {Role: "assistant", Content: "I am?"},
  104. },
  105. window: 1024,
  106. want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST] I am?",
  107. },
  108. {
  109. name: "with implicit response",
  110. template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
  111. messages: []api.Message{
  112. {Role: "system", Content: "You are a Wizard."},
  113. {Role: "user", Content: "Hello"},
  114. {Role: "assistant", Content: "I am?"},
  115. },
  116. window: 1024,
  117. want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]I am?",
  118. },
  119. {
  120. name: "with conversation",
  121. template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
  122. messages: []api.Message{
  123. {Role: "system", Content: "You are a Wizard."},
  124. {Role: "user", Content: "What are the potion ingredients?"},
  125. {Role: "assistant", Content: "sugar"},
  126. {Role: "user", Content: "Anything else?"},
  127. },
  128. window: 1024,
  129. want: "[INST] <<SYS>>You are a Wizard.<</SYS>> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] ",
  130. },
  131. {
  132. name: "with truncation",
  133. template: "{{ .System }} {{ .Prompt }} {{ .Response }} ",
  134. messages: []api.Message{
  135. {Role: "system", Content: "You are a Wizard."},
  136. {Role: "user", Content: "Hello"},
  137. {Role: "assistant", Content: "I am?"},
  138. {Role: "user", Content: "Why is the sky blue?"},
  139. {Role: "assistant", Content: "The sky is blue from rayleigh scattering"},
  140. },
  141. window: 10,
  142. want: "You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering",
  143. },
  144. {
  145. name: "images",
  146. template: "{{ .System }} {{ .Prompt }}",
  147. messages: []api.Message{
  148. {Role: "system", Content: "You are a Wizard."},
  149. {Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
  150. },
  151. window: 1024,
  152. want: "You are a Wizard. Hello [img-0]",
  153. },
  154. {
  155. name: "images truncated",
  156. template: "{{ .System }} {{ .Prompt }}",
  157. messages: []api.Message{
  158. {Role: "system", Content: "You are a Wizard."},
  159. {Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
  160. },
  161. window: 1024,
  162. want: "You are a Wizard. Hello [img-1]",
  163. },
  164. {
  165. name: "empty list",
  166. template: "{{ .System }} {{ .Prompt }}",
  167. messages: []api.Message{},
  168. window: 1024,
  169. want: "",
  170. },
  171. {
  172. name: "empty prompt",
  173. template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
  174. messages: []api.Message{
  175. {Role: "user", Content: ""},
  176. },
  177. window: 1024,
  178. want: "",
  179. },
  180. }
  181. encode := func(s string) ([]int, error) {
  182. words := strings.Fields(s)
  183. return make([]int, len(words)), nil
  184. }
  185. for _, tc := range tests {
  186. t.Run(tc.name, func(t *testing.T) {
  187. got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
  188. if err != nil {
  189. t.Errorf("error = %v", err)
  190. }
  191. if got != tc.want {
  192. t.Errorf("got = %v, want %v", got, tc.want)
  193. }
  194. })
  195. }
  196. }