images_test.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. package server
  2. import (
  3. "strings"
  4. "testing"
  5. "github.com/jmorganca/ollama/api"
  6. )
  7. func TestPrompt(t *testing.T) {
  8. tests := []struct {
  9. name string
  10. template string
  11. vars PromptVars
  12. want string
  13. wantErr bool
  14. }{
  15. {
  16. name: "System Prompt",
  17. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  18. vars: PromptVars{
  19. System: "You are a Wizard.",
  20. Prompt: "What are the potion ingredients?",
  21. },
  22. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
  23. },
  24. {
  25. name: "System Prompt with Response",
  26. template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
  27. vars: PromptVars{
  28. System: "You are a Wizard.",
  29. Prompt: "What are the potion ingredients?",
  30. Response: "I don't know.",
  31. },
  32. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
  33. },
  34. {
  35. name: "Conditional Logic Nodes",
  36. template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
  37. vars: PromptVars{
  38. First: true,
  39. System: "You are a Wizard.",
  40. Prompt: "What are the potion ingredients?",
  41. Response: "I don't know.",
  42. },
  43. want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
  44. },
  45. }
  46. for _, tt := range tests {
  47. t.Run(tt.name, func(t *testing.T) {
  48. got, err := Prompt(tt.template, tt.vars)
  49. if (err != nil) != tt.wantErr {
  50. t.Errorf("Prompt() error = %v, wantErr %v", err, tt.wantErr)
  51. return
  52. }
  53. if got != tt.want {
  54. t.Errorf("Prompt() got = %v, want %v", got, tt.want)
  55. }
  56. })
  57. }
  58. }
  59. func TestModel_PreResponsePrompt(t *testing.T) {
  60. tests := []struct {
  61. name string
  62. template string
  63. vars PromptVars
  64. want string
  65. wantErr bool
  66. }{
  67. {
  68. name: "No Response in Template",
  69. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  70. vars: PromptVars{
  71. System: "You are a Wizard.",
  72. Prompt: "What are the potion ingredients?",
  73. },
  74. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
  75. },
  76. {
  77. name: "Response in Template",
  78. template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
  79. vars: PromptVars{
  80. System: "You are a Wizard.",
  81. Prompt: "What are the potion ingredients?",
  82. },
  83. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] ",
  84. },
  85. {
  86. name: "Response in Template with Trailing Formatting",
  87. template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
  88. vars: PromptVars{
  89. Prompt: "What are the potion ingredients?",
  90. },
  91. want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
  92. },
  93. {
  94. name: "Response in Template with Alternative Formatting",
  95. template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
  96. vars: PromptVars{
  97. Prompt: "What are the potion ingredients?",
  98. },
  99. want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
  100. },
  101. }
  102. for _, tt := range tests {
  103. m := Model{Template: tt.template}
  104. t.Run(tt.name, func(t *testing.T) {
  105. got, err := m.PreResponsePrompt(tt.vars)
  106. if (err != nil) != tt.wantErr {
  107. t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
  108. return
  109. }
  110. if got != tt.want {
  111. t.Errorf("PreResponsePrompt() got = %v, want %v", got, tt.want)
  112. }
  113. })
  114. }
  115. }
  116. func TestModel_PostResponsePrompt(t *testing.T) {
  117. tests := []struct {
  118. name string
  119. template string
  120. vars PromptVars
  121. want string
  122. wantErr bool
  123. }{
  124. {
  125. name: "No Response in Template",
  126. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  127. vars: PromptVars{
  128. Response: "I don't know.",
  129. },
  130. want: "I don't know.",
  131. },
  132. {
  133. name: "Response in Template",
  134. template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
  135. vars: PromptVars{
  136. Response: "I don't know.",
  137. },
  138. want: "I don't know.",
  139. },
  140. {
  141. name: "Response in Template with Trailing Formatting",
  142. template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
  143. vars: PromptVars{
  144. Response: "I don't know.",
  145. },
  146. want: "I don't know.<|im_end|>",
  147. },
  148. {
  149. name: "Response in Template with Alternative Formatting",
  150. template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
  151. vars: PromptVars{
  152. Response: "I don't know.",
  153. },
  154. want: "I don't know.<|im_end|>",
  155. },
  156. }
  157. for _, tt := range tests {
  158. m := Model{Template: tt.template}
  159. t.Run(tt.name, func(t *testing.T) {
  160. got, err := m.PostResponseTemplate(tt.vars)
  161. if (err != nil) != tt.wantErr {
  162. t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
  163. return
  164. }
  165. if got != tt.want {
  166. t.Errorf("PostResponseTemplate() got = %v, want %v", got, tt.want)
  167. }
  168. })
  169. }
  170. }
  171. func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) {
  172. tests := []struct {
  173. name string
  174. template string
  175. preVars PromptVars
  176. postVars PromptVars
  177. want string
  178. wantErr bool
  179. }{
  180. {
  181. name: "Response in Template",
  182. template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
  183. preVars: PromptVars{
  184. Prompt: "What are the potion ingredients?",
  185. },
  186. postVars: PromptVars{
  187. Prompt: "What are the potion ingredients?",
  188. Response: "Sugar.",
  189. },
  190. want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSugar.<|im_end|>",
  191. },
  192. {
  193. name: "No Response in Template",
  194. template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n",
  195. preVars: PromptVars{
  196. Prompt: "What are the potion ingredients?",
  197. },
  198. postVars: PromptVars{
  199. Prompt: "What are the potion ingredients?",
  200. Response: "Spice.",
  201. },
  202. want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSpice.",
  203. },
  204. }
  205. for _, tt := range tests {
  206. m := Model{Template: tt.template}
  207. t.Run(tt.name, func(t *testing.T) {
  208. pre, err := m.PreResponsePrompt(tt.preVars)
  209. if (err != nil) != tt.wantErr {
  210. t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
  211. return
  212. }
  213. post, err := m.PostResponseTemplate(tt.postVars)
  214. if err != nil {
  215. t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
  216. return
  217. }
  218. result := pre + post
  219. if result != tt.want {
  220. t.Errorf("Prompt() got = %v, want %v", result, tt.want)
  221. }
  222. })
  223. }
  224. }
  225. func TestChat(t *testing.T) {
  226. tests := []struct {
  227. name string
  228. template string
  229. msgs []api.Message
  230. want string
  231. wantErr string
  232. }{
  233. {
  234. name: "Single Message",
  235. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  236. msgs: []api.Message{
  237. {
  238. Role: "system",
  239. Content: "You are a Wizard.",
  240. },
  241. {
  242. Role: "user",
  243. Content: "What are the potion ingredients?",
  244. },
  245. },
  246. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
  247. },
  248. {
  249. name: "First Message",
  250. template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
  251. msgs: []api.Message{
  252. {
  253. Role: "system",
  254. Content: "You are a Wizard.",
  255. },
  256. {
  257. Role: "user",
  258. Content: "What are the potion ingredients?",
  259. },
  260. {
  261. Role: "assistant",
  262. Content: "eye of newt",
  263. },
  264. {
  265. Role: "user",
  266. Content: "Anything else?",
  267. },
  268. },
  269. want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]",
  270. },
  271. {
  272. name: "Message History",
  273. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  274. msgs: []api.Message{
  275. {
  276. Role: "system",
  277. Content: "You are a Wizard.",
  278. },
  279. {
  280. Role: "user",
  281. Content: "What are the potion ingredients?",
  282. },
  283. {
  284. Role: "assistant",
  285. Content: "sugar",
  286. },
  287. {
  288. Role: "user",
  289. Content: "Anything else?",
  290. },
  291. },
  292. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
  293. },
  294. {
  295. name: "Assistant Only",
  296. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  297. msgs: []api.Message{
  298. {
  299. Role: "assistant",
  300. Content: "everything nice",
  301. },
  302. },
  303. want: "[INST] [/INST]everything nice",
  304. },
  305. {
  306. name: "Invalid Role",
  307. msgs: []api.Message{
  308. {
  309. Role: "not-a-role",
  310. Content: "howdy",
  311. },
  312. },
  313. wantErr: "invalid role: not-a-role",
  314. },
  315. }
  316. for _, tt := range tests {
  317. m := Model{
  318. Template: tt.template,
  319. }
  320. t.Run(tt.name, func(t *testing.T) {
  321. got, _, err := m.ChatPrompt(tt.msgs)
  322. if tt.wantErr != "" {
  323. if err == nil {
  324. t.Errorf("ChatPrompt() expected error, got nil")
  325. }
  326. if !strings.Contains(err.Error(), tt.wantErr) {
  327. t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
  328. }
  329. }
  330. if got != tt.want {
  331. t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
  332. }
  333. })
  334. }
  335. }