images_test.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423
  1. package server
  2. import (
  3. "bytes"
  4. "strings"
  5. "testing"
  6. "github.com/jmorganca/ollama/api"
  7. )
  8. func TestPrompt(t *testing.T) {
  9. tests := []struct {
  10. name string
  11. template string
  12. vars PromptVars
  13. want string
  14. wantErr bool
  15. }{
  16. {
  17. name: "System Prompt",
  18. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  19. vars: PromptVars{
  20. System: "You are a Wizard.",
  21. Prompt: "What are the potion ingredients?",
  22. },
  23. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
  24. },
  25. {
  26. name: "System Prompt with Response",
  27. template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
  28. vars: PromptVars{
  29. System: "You are a Wizard.",
  30. Prompt: "What are the potion ingredients?",
  31. Response: "I don't know.",
  32. },
  33. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
  34. },
  35. {
  36. name: "Conditional Logic Nodes",
  37. template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
  38. vars: PromptVars{
  39. First: true,
  40. System: "You are a Wizard.",
  41. Prompt: "What are the potion ingredients?",
  42. Response: "I don't know.",
  43. },
  44. want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
  45. },
  46. }
  47. for _, tt := range tests {
  48. t.Run(tt.name, func(t *testing.T) {
  49. got, err := Prompt(tt.template, tt.vars)
  50. if (err != nil) != tt.wantErr {
  51. t.Errorf("Prompt() error = %v, wantErr %v", err, tt.wantErr)
  52. return
  53. }
  54. if got != tt.want {
  55. t.Errorf("Prompt() got = %v, want %v", got, tt.want)
  56. }
  57. })
  58. }
  59. }
  60. func TestModel_PreResponsePrompt(t *testing.T) {
  61. tests := []struct {
  62. name string
  63. template string
  64. vars PromptVars
  65. want string
  66. wantErr bool
  67. }{
  68. {
  69. name: "No Response in Template",
  70. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  71. vars: PromptVars{
  72. System: "You are a Wizard.",
  73. Prompt: "What are the potion ingredients?",
  74. },
  75. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
  76. },
  77. {
  78. name: "Response in Template",
  79. template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
  80. vars: PromptVars{
  81. System: "You are a Wizard.",
  82. Prompt: "What are the potion ingredients?",
  83. },
  84. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] ",
  85. },
  86. {
  87. name: "Response in Template with Trailing Formatting",
  88. template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
  89. vars: PromptVars{
  90. Prompt: "What are the potion ingredients?",
  91. },
  92. want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
  93. },
  94. {
  95. name: "Response in Template with Alternative Formatting",
  96. template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
  97. vars: PromptVars{
  98. Prompt: "What are the potion ingredients?",
  99. },
  100. want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
  101. },
  102. }
  103. for _, tt := range tests {
  104. m := Model{Template: tt.template}
  105. t.Run(tt.name, func(t *testing.T) {
  106. got, err := m.PreResponsePrompt(tt.vars)
  107. if (err != nil) != tt.wantErr {
  108. t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
  109. return
  110. }
  111. if got != tt.want {
  112. t.Errorf("PreResponsePrompt() got = %v, want %v", got, tt.want)
  113. }
  114. })
  115. }
  116. }
  117. func TestModel_PostResponsePrompt(t *testing.T) {
  118. tests := []struct {
  119. name string
  120. template string
  121. vars PromptVars
  122. want string
  123. wantErr bool
  124. }{
  125. {
  126. name: "No Response in Template",
  127. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  128. vars: PromptVars{
  129. Response: "I don't know.",
  130. },
  131. want: "I don't know.",
  132. },
  133. {
  134. name: "Response in Template",
  135. template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
  136. vars: PromptVars{
  137. Response: "I don't know.",
  138. },
  139. want: "I don't know.",
  140. },
  141. {
  142. name: "Response in Template with Trailing Formatting",
  143. template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
  144. vars: PromptVars{
  145. Response: "I don't know.",
  146. },
  147. want: "I don't know.<|im_end|>",
  148. },
  149. {
  150. name: "Response in Template with Alternative Formatting",
  151. template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
  152. vars: PromptVars{
  153. Response: "I don't know.",
  154. },
  155. want: "I don't know.<|im_end|>",
  156. },
  157. }
  158. for _, tt := range tests {
  159. m := Model{Template: tt.template}
  160. t.Run(tt.name, func(t *testing.T) {
  161. got, err := m.PostResponseTemplate(tt.vars)
  162. if (err != nil) != tt.wantErr {
  163. t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
  164. return
  165. }
  166. if got != tt.want {
  167. t.Errorf("PostResponseTemplate() got = %v, want %v", got, tt.want)
  168. }
  169. })
  170. }
  171. }
  172. func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) {
  173. tests := []struct {
  174. name string
  175. template string
  176. preVars PromptVars
  177. postVars PromptVars
  178. want string
  179. wantErr bool
  180. }{
  181. {
  182. name: "Response in Template",
  183. template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
  184. preVars: PromptVars{
  185. Prompt: "What are the potion ingredients?",
  186. },
  187. postVars: PromptVars{
  188. Prompt: "What are the potion ingredients?",
  189. Response: "Sugar.",
  190. },
  191. want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSugar.<|im_end|>",
  192. },
  193. {
  194. name: "No Response in Template",
  195. template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n",
  196. preVars: PromptVars{
  197. Prompt: "What are the potion ingredients?",
  198. },
  199. postVars: PromptVars{
  200. Prompt: "What are the potion ingredients?",
  201. Response: "Spice.",
  202. },
  203. want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSpice.",
  204. },
  205. }
  206. for _, tt := range tests {
  207. m := Model{Template: tt.template}
  208. t.Run(tt.name, func(t *testing.T) {
  209. pre, err := m.PreResponsePrompt(tt.preVars)
  210. if (err != nil) != tt.wantErr {
  211. t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
  212. return
  213. }
  214. post, err := m.PostResponseTemplate(tt.postVars)
  215. if err != nil {
  216. t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
  217. return
  218. }
  219. result := pre + post
  220. if result != tt.want {
  221. t.Errorf("Prompt() got = %v, want %v", result, tt.want)
  222. }
  223. })
  224. }
  225. }
  226. func chatHistoryEqual(a, b ChatHistory) bool {
  227. if len(a.Prompts) != len(b.Prompts) {
  228. return false
  229. }
  230. if len(a.CurrentImages) != len(b.CurrentImages) {
  231. return false
  232. }
  233. for i, v := range a.Prompts {
  234. if v != b.Prompts[i] {
  235. return false
  236. }
  237. }
  238. for i, v := range a.CurrentImages {
  239. if !bytes.Equal(v, b.CurrentImages[i]) {
  240. return false
  241. }
  242. }
  243. return a.LastSystem == b.LastSystem
  244. }
  245. func TestChat(t *testing.T) {
  246. tests := []struct {
  247. name string
  248. model Model
  249. msgs []api.Message
  250. want ChatHistory
  251. wantErr string
  252. }{
  253. {
  254. name: "Single Message",
  255. model: Model{
  256. Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  257. },
  258. msgs: []api.Message{
  259. {
  260. Role: "system",
  261. Content: "You are a Wizard.",
  262. },
  263. {
  264. Role: "user",
  265. Content: "What are the potion ingredients?",
  266. },
  267. },
  268. want: ChatHistory{
  269. Prompts: []PromptVars{
  270. {
  271. System: "You are a Wizard.",
  272. Prompt: "What are the potion ingredients?",
  273. First: true,
  274. },
  275. },
  276. LastSystem: "You are a Wizard.",
  277. },
  278. },
  279. {
  280. name: "Message History",
  281. model: Model{
  282. Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  283. },
  284. msgs: []api.Message{
  285. {
  286. Role: "system",
  287. Content: "You are a Wizard.",
  288. },
  289. {
  290. Role: "user",
  291. Content: "What are the potion ingredients?",
  292. },
  293. {
  294. Role: "assistant",
  295. Content: "sugar",
  296. },
  297. {
  298. Role: "user",
  299. Content: "Anything else?",
  300. },
  301. },
  302. want: ChatHistory{
  303. Prompts: []PromptVars{
  304. {
  305. System: "You are a Wizard.",
  306. Prompt: "What are the potion ingredients?",
  307. Response: "sugar",
  308. First: true,
  309. },
  310. {
  311. Prompt: "Anything else?",
  312. },
  313. },
  314. LastSystem: "You are a Wizard.",
  315. },
  316. },
  317. {
  318. name: "Assistant Only",
  319. model: Model{
  320. Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  321. },
  322. msgs: []api.Message{
  323. {
  324. Role: "assistant",
  325. Content: "everything nice",
  326. },
  327. },
  328. want: ChatHistory{
  329. Prompts: []PromptVars{
  330. {
  331. Response: "everything nice",
  332. First: true,
  333. },
  334. },
  335. },
  336. },
  337. {
  338. name: "Last system message is preserved from modelfile",
  339. model: Model{
  340. Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  341. System: "You are Mojo Jojo.",
  342. },
  343. msgs: []api.Message{
  344. {
  345. Role: "user",
  346. Content: "hi",
  347. },
  348. },
  349. want: ChatHistory{
  350. Prompts: []PromptVars{
  351. {
  352. System: "You are Mojo Jojo.",
  353. Prompt: "hi",
  354. First: true,
  355. },
  356. },
  357. LastSystem: "You are Mojo Jojo.",
  358. },
  359. },
  360. {
  361. name: "Last system message is preserved from messages",
  362. model: Model{
  363. Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  364. System: "You are Mojo Jojo.",
  365. },
  366. msgs: []api.Message{
  367. {
  368. Role: "system",
  369. Content: "You are Professor Utonium.",
  370. },
  371. },
  372. want: ChatHistory{
  373. Prompts: []PromptVars{
  374. {
  375. System: "You are Professor Utonium.",
  376. First: true,
  377. },
  378. },
  379. LastSystem: "You are Professor Utonium.",
  380. },
  381. },
  382. {
  383. name: "Invalid Role",
  384. msgs: []api.Message{
  385. {
  386. Role: "not-a-role",
  387. Content: "howdy",
  388. },
  389. },
  390. wantErr: "invalid role: not-a-role",
  391. },
  392. }
  393. for _, tt := range tests {
  394. t.Run(tt.name, func(t *testing.T) {
  395. got, err := tt.model.ChatPrompts(tt.msgs)
  396. if tt.wantErr != "" {
  397. if err == nil {
  398. t.Errorf("ChatPrompt() expected error, got nil")
  399. }
  400. if !strings.Contains(err.Error(), tt.wantErr) {
  401. t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
  402. }
  403. return
  404. }
  405. if !chatHistoryEqual(*got, tt.want) {
  406. t.Errorf("ChatPrompt() got = %#v, want %#v", got, tt.want)
  407. }
  408. })
  409. }
  410. }