images_test.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442
  1. package server
  2. import (
  3. "bytes"
  4. "strings"
  5. "testing"
  6. "github.com/jmorganca/ollama/api"
  7. )
  8. func TestPrompt(t *testing.T) {
  9. tests := []struct {
  10. name string
  11. template string
  12. vars PromptVars
  13. want string
  14. wantErr bool
  15. }{
  16. {
  17. name: "System Prompt",
  18. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  19. vars: PromptVars{
  20. System: "You are a Wizard.",
  21. Prompt: "What are the potion ingredients?",
  22. },
  23. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
  24. },
  25. {
  26. name: "System Prompt with Response",
  27. template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
  28. vars: PromptVars{
  29. System: "You are a Wizard.",
  30. Prompt: "What are the potion ingredients?",
  31. Response: "I don't know.",
  32. },
  33. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
  34. },
  35. {
  36. name: "Conditional Logic Nodes",
  37. template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
  38. vars: PromptVars{
  39. First: true,
  40. System: "You are a Wizard.",
  41. Prompt: "What are the potion ingredients?",
  42. Response: "I don't know.",
  43. },
  44. want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
  45. },
  46. }
  47. for _, tt := range tests {
  48. t.Run(tt.name, func(t *testing.T) {
  49. got, err := Prompt(tt.template, tt.vars)
  50. if (err != nil) != tt.wantErr {
  51. t.Errorf("Prompt() error = %v, wantErr %v", err, tt.wantErr)
  52. return
  53. }
  54. if got != tt.want {
  55. t.Errorf("Prompt() got = %v, want %v", got, tt.want)
  56. }
  57. })
  58. }
  59. }
  60. func TestModel_PreResponsePrompt(t *testing.T) {
  61. tests := []struct {
  62. name string
  63. template string
  64. vars PromptVars
  65. want string
  66. wantErr bool
  67. }{
  68. {
  69. name: "No Response in Template",
  70. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  71. vars: PromptVars{
  72. System: "You are a Wizard.",
  73. Prompt: "What are the potion ingredients?",
  74. },
  75. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
  76. },
  77. {
  78. name: "Response in Template",
  79. template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
  80. vars: PromptVars{
  81. System: "You are a Wizard.",
  82. Prompt: "What are the potion ingredients?",
  83. },
  84. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] ",
  85. },
  86. {
  87. name: "Response in Template with Trailing Formatting",
  88. template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
  89. vars: PromptVars{
  90. Prompt: "What are the potion ingredients?",
  91. },
  92. want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
  93. },
  94. {
  95. name: "Response in Template with Alternative Formatting",
  96. template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
  97. vars: PromptVars{
  98. Prompt: "What are the potion ingredients?",
  99. },
  100. want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
  101. },
  102. }
  103. for _, tt := range tests {
  104. m := Model{Template: tt.template}
  105. t.Run(tt.name, func(t *testing.T) {
  106. got, err := m.PreResponsePrompt(tt.vars)
  107. if (err != nil) != tt.wantErr {
  108. t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
  109. return
  110. }
  111. if got != tt.want {
  112. t.Errorf("PreResponsePrompt() got = %v, want %v", got, tt.want)
  113. }
  114. })
  115. }
  116. }
  117. func TestModel_PostResponsePrompt(t *testing.T) {
  118. tests := []struct {
  119. name string
  120. template string
  121. vars PromptVars
  122. want string
  123. wantErr bool
  124. }{
  125. {
  126. name: "No Response in Template",
  127. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  128. vars: PromptVars{
  129. Response: "I don't know.",
  130. },
  131. want: "I don't know.",
  132. },
  133. {
  134. name: "Response in Template",
  135. template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
  136. vars: PromptVars{
  137. Response: "I don't know.",
  138. },
  139. want: "I don't know.",
  140. },
  141. {
  142. name: "Response in Template with Trailing Formatting",
  143. template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
  144. vars: PromptVars{
  145. Response: "I don't know.",
  146. },
  147. want: "I don't know.<|im_end|>",
  148. },
  149. {
  150. name: "Response in Template with Alternative Formatting",
  151. template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
  152. vars: PromptVars{
  153. Response: "I don't know.",
  154. },
  155. want: "I don't know.<|im_end|>",
  156. },
  157. }
  158. for _, tt := range tests {
  159. m := Model{Template: tt.template}
  160. t.Run(tt.name, func(t *testing.T) {
  161. got, err := m.PostResponseTemplate(tt.vars)
  162. if (err != nil) != tt.wantErr {
  163. t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
  164. return
  165. }
  166. if got != tt.want {
  167. t.Errorf("PostResponseTemplate() got = %v, want %v", got, tt.want)
  168. }
  169. })
  170. }
  171. }
  172. func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) {
  173. tests := []struct {
  174. name string
  175. template string
  176. preVars PromptVars
  177. postVars PromptVars
  178. want string
  179. wantErr bool
  180. }{
  181. {
  182. name: "Response in Template",
  183. template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
  184. preVars: PromptVars{
  185. Prompt: "What are the potion ingredients?",
  186. },
  187. postVars: PromptVars{
  188. Prompt: "What are the potion ingredients?",
  189. Response: "Sugar.",
  190. },
  191. want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSugar.<|im_end|>",
  192. },
  193. {
  194. name: "No Response in Template",
  195. template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n",
  196. preVars: PromptVars{
  197. Prompt: "What are the potion ingredients?",
  198. },
  199. postVars: PromptVars{
  200. Prompt: "What are the potion ingredients?",
  201. Response: "Spice.",
  202. },
  203. want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSpice.",
  204. },
  205. }
  206. for _, tt := range tests {
  207. m := Model{Template: tt.template}
  208. t.Run(tt.name, func(t *testing.T) {
  209. pre, err := m.PreResponsePrompt(tt.preVars)
  210. if (err != nil) != tt.wantErr {
  211. t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
  212. return
  213. }
  214. post, err := m.PostResponseTemplate(tt.postVars)
  215. if err != nil {
  216. t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
  217. return
  218. }
  219. result := pre + post
  220. if result != tt.want {
  221. t.Errorf("Prompt() got = %v, want %v", result, tt.want)
  222. }
  223. })
  224. }
  225. }
  226. func chatHistoryEqual(a, b ChatHistory) bool {
  227. if len(a.Prompts) != len(b.Prompts) {
  228. return false
  229. }
  230. for i, v := range a.Prompts {
  231. if v.First != b.Prompts[i].First {
  232. return false
  233. }
  234. if v.Response != b.Prompts[i].Response {
  235. return false
  236. }
  237. if v.Prompt != b.Prompts[i].Prompt {
  238. return false
  239. }
  240. if v.System != b.Prompts[i].System {
  241. return false
  242. }
  243. if len(v.Images) != len(b.Prompts[i].Images) {
  244. return false
  245. }
  246. for j, img := range v.Images {
  247. if img.ID != b.Prompts[i].Images[j].ID {
  248. return false
  249. }
  250. if !bytes.Equal(img.Data, b.Prompts[i].Images[j].Data) {
  251. return false
  252. }
  253. }
  254. }
  255. return a.LastSystem == b.LastSystem
  256. }
  257. func TestChat(t *testing.T) {
  258. tests := []struct {
  259. name string
  260. model Model
  261. msgs []api.Message
  262. want ChatHistory
  263. wantErr string
  264. }{
  265. {
  266. name: "Single Message",
  267. model: Model{
  268. Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  269. },
  270. msgs: []api.Message{
  271. {
  272. Role: "system",
  273. Content: "You are a Wizard.",
  274. },
  275. {
  276. Role: "user",
  277. Content: "What are the potion ingredients?",
  278. },
  279. },
  280. want: ChatHistory{
  281. Prompts: []PromptVars{
  282. {
  283. System: "You are a Wizard.",
  284. Prompt: "What are the potion ingredients?",
  285. First: true,
  286. },
  287. },
  288. LastSystem: "You are a Wizard.",
  289. },
  290. },
  291. {
  292. name: "Message History",
  293. model: Model{
  294. Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  295. },
  296. msgs: []api.Message{
  297. {
  298. Role: "system",
  299. Content: "You are a Wizard.",
  300. },
  301. {
  302. Role: "user",
  303. Content: "What are the potion ingredients?",
  304. },
  305. {
  306. Role: "assistant",
  307. Content: "sugar",
  308. },
  309. {
  310. Role: "user",
  311. Content: "Anything else?",
  312. },
  313. },
  314. want: ChatHistory{
  315. Prompts: []PromptVars{
  316. {
  317. System: "You are a Wizard.",
  318. Prompt: "What are the potion ingredients?",
  319. Response: "sugar",
  320. First: true,
  321. },
  322. {
  323. Prompt: "Anything else?",
  324. },
  325. },
  326. LastSystem: "You are a Wizard.",
  327. },
  328. },
  329. {
  330. name: "Assistant Only",
  331. model: Model{
  332. Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  333. },
  334. msgs: []api.Message{
  335. {
  336. Role: "assistant",
  337. Content: "everything nice",
  338. },
  339. },
  340. want: ChatHistory{
  341. Prompts: []PromptVars{
  342. {
  343. Response: "everything nice",
  344. First: true,
  345. },
  346. },
  347. },
  348. },
  349. {
  350. name: "Last system message is preserved from modelfile",
  351. model: Model{
  352. Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  353. System: "You are Mojo Jojo.",
  354. },
  355. msgs: []api.Message{
  356. {
  357. Role: "user",
  358. Content: "hi",
  359. },
  360. },
  361. want: ChatHistory{
  362. Prompts: []PromptVars{
  363. {
  364. System: "You are Mojo Jojo.",
  365. Prompt: "hi",
  366. First: true,
  367. },
  368. },
  369. LastSystem: "You are Mojo Jojo.",
  370. },
  371. },
  372. {
  373. name: "Last system message is preserved from messages",
  374. model: Model{
  375. Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  376. System: "You are Mojo Jojo.",
  377. },
  378. msgs: []api.Message{
  379. {
  380. Role: "system",
  381. Content: "You are Professor Utonium.",
  382. },
  383. },
  384. want: ChatHistory{
  385. Prompts: []PromptVars{
  386. {
  387. System: "You are Professor Utonium.",
  388. First: true,
  389. },
  390. },
  391. LastSystem: "You are Professor Utonium.",
  392. },
  393. },
  394. {
  395. name: "Invalid Role",
  396. msgs: []api.Message{
  397. {
  398. Role: "not-a-role",
  399. Content: "howdy",
  400. },
  401. },
  402. wantErr: "invalid role: not-a-role",
  403. },
  404. }
  405. for _, tt := range tests {
  406. t.Run(tt.name, func(t *testing.T) {
  407. got, err := tt.model.ChatPrompts(tt.msgs)
  408. if tt.wantErr != "" {
  409. if err == nil {
  410. t.Errorf("ChatPrompt() expected error, got nil")
  411. }
  412. if !strings.Contains(err.Error(), tt.wantErr) {
  413. t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
  414. }
  415. return
  416. }
  417. if !chatHistoryEqual(*got, tt.want) {
  418. t.Errorf("ChatPrompt() got = %#v, want %#v", got, tt.want)
  419. }
  420. })
  421. }
  422. }