images_test.go 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. package server
  2. import (
  3. "bytes"
  4. "strings"
  5. "testing"
  6. "github.com/jmorganca/ollama/api"
  7. )
  8. func TestPrompt(t *testing.T) {
  9. tests := []struct {
  10. name string
  11. template string
  12. vars PromptVars
  13. want string
  14. wantErr bool
  15. }{
  16. {
  17. name: "System Prompt",
  18. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  19. vars: PromptVars{
  20. System: "You are a Wizard.",
  21. Prompt: "What are the potion ingredients?",
  22. },
  23. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
  24. },
  25. {
  26. name: "System Prompt with Response",
  27. template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
  28. vars: PromptVars{
  29. System: "You are a Wizard.",
  30. Prompt: "What are the potion ingredients?",
  31. Response: "I don't know.",
  32. },
  33. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
  34. },
  35. {
  36. name: "Conditional Logic Nodes",
  37. template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
  38. vars: PromptVars{
  39. First: true,
  40. System: "You are a Wizard.",
  41. Prompt: "What are the potion ingredients?",
  42. Response: "I don't know.",
  43. },
  44. want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
  45. },
  46. }
  47. for _, tt := range tests {
  48. t.Run(tt.name, func(t *testing.T) {
  49. got, err := Prompt(tt.template, tt.vars)
  50. if (err != nil) != tt.wantErr {
  51. t.Errorf("Prompt() error = %v, wantErr %v", err, tt.wantErr)
  52. return
  53. }
  54. if got != tt.want {
  55. t.Errorf("Prompt() got = %v, want %v", got, tt.want)
  56. }
  57. })
  58. }
  59. }
  60. func TestModel_PreResponsePrompt(t *testing.T) {
  61. tests := []struct {
  62. name string
  63. template string
  64. vars PromptVars
  65. want string
  66. wantErr bool
  67. }{
  68. {
  69. name: "No Response in Template",
  70. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  71. vars: PromptVars{
  72. System: "You are a Wizard.",
  73. Prompt: "What are the potion ingredients?",
  74. },
  75. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
  76. },
  77. {
  78. name: "Response in Template",
  79. template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
  80. vars: PromptVars{
  81. System: "You are a Wizard.",
  82. Prompt: "What are the potion ingredients?",
  83. },
  84. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] ",
  85. },
  86. {
  87. name: "Response in Template with Trailing Formatting",
  88. template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
  89. vars: PromptVars{
  90. Prompt: "What are the potion ingredients?",
  91. },
  92. want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
  93. },
  94. {
  95. name: "Response in Template with Alternative Formatting",
  96. template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
  97. vars: PromptVars{
  98. Prompt: "What are the potion ingredients?",
  99. },
  100. want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
  101. },
  102. }
  103. for _, tt := range tests {
  104. m := Model{Template: tt.template}
  105. t.Run(tt.name, func(t *testing.T) {
  106. got, err := m.PreResponsePrompt(tt.vars)
  107. if (err != nil) != tt.wantErr {
  108. t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
  109. return
  110. }
  111. if got != tt.want {
  112. t.Errorf("PreResponsePrompt() got = %v, want %v", got, tt.want)
  113. }
  114. })
  115. }
  116. }
  117. func TestModel_PostResponsePrompt(t *testing.T) {
  118. tests := []struct {
  119. name string
  120. template string
  121. vars PromptVars
  122. want string
  123. wantErr bool
  124. }{
  125. {
  126. name: "No Response in Template",
  127. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  128. vars: PromptVars{
  129. Response: "I don't know.",
  130. },
  131. want: "I don't know.",
  132. },
  133. {
  134. name: "Response in Template",
  135. template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
  136. vars: PromptVars{
  137. Response: "I don't know.",
  138. },
  139. want: "I don't know.",
  140. },
  141. {
  142. name: "Response in Template with Trailing Formatting",
  143. template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
  144. vars: PromptVars{
  145. Response: "I don't know.",
  146. },
  147. want: "I don't know.<|im_end|>",
  148. },
  149. {
  150. name: "Response in Template with Alternative Formatting",
  151. template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
  152. vars: PromptVars{
  153. Response: "I don't know.",
  154. },
  155. want: "I don't know.<|im_end|>",
  156. },
  157. }
  158. for _, tt := range tests {
  159. m := Model{Template: tt.template}
  160. t.Run(tt.name, func(t *testing.T) {
  161. got, err := m.PostResponseTemplate(tt.vars)
  162. if (err != nil) != tt.wantErr {
  163. t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
  164. return
  165. }
  166. if got != tt.want {
  167. t.Errorf("PostResponseTemplate() got = %v, want %v", got, tt.want)
  168. }
  169. })
  170. }
  171. }
  172. func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) {
  173. tests := []struct {
  174. name string
  175. template string
  176. preVars PromptVars
  177. postVars PromptVars
  178. want string
  179. wantErr bool
  180. }{
  181. {
  182. name: "Response in Template",
  183. template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
  184. preVars: PromptVars{
  185. Prompt: "What are the potion ingredients?",
  186. },
  187. postVars: PromptVars{
  188. Prompt: "What are the potion ingredients?",
  189. Response: "Sugar.",
  190. },
  191. want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSugar.<|im_end|>",
  192. },
  193. {
  194. name: "No Response in Template",
  195. template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n",
  196. preVars: PromptVars{
  197. Prompt: "What are the potion ingredients?",
  198. },
  199. postVars: PromptVars{
  200. Prompt: "What are the potion ingredients?",
  201. Response: "Spice.",
  202. },
  203. want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSpice.",
  204. },
  205. }
  206. for _, tt := range tests {
  207. m := Model{Template: tt.template}
  208. t.Run(tt.name, func(t *testing.T) {
  209. pre, err := m.PreResponsePrompt(tt.preVars)
  210. if (err != nil) != tt.wantErr {
  211. t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
  212. return
  213. }
  214. post, err := m.PostResponseTemplate(tt.postVars)
  215. if err != nil {
  216. t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
  217. return
  218. }
  219. result := pre + post
  220. if result != tt.want {
  221. t.Errorf("Prompt() got = %v, want %v", result, tt.want)
  222. }
  223. })
  224. }
  225. }
  226. func chatHistoryEqual(a, b ChatHistory) bool {
  227. if len(a.Prompts) != len(b.Prompts) {
  228. return false
  229. }
  230. if len(a.CurrentImages) != len(b.CurrentImages) {
  231. return false
  232. }
  233. for i, v := range a.Prompts {
  234. if v != b.Prompts[i] {
  235. return false
  236. }
  237. }
  238. for i, v := range a.CurrentImages {
  239. if !bytes.Equal(v, b.CurrentImages[i]) {
  240. return false
  241. }
  242. }
  243. return a.LastSystem == b.LastSystem
  244. }
  245. func TestChat(t *testing.T) {
  246. tests := []struct {
  247. name string
  248. template string
  249. msgs []api.Message
  250. want ChatHistory
  251. wantErr string
  252. }{
  253. {
  254. name: "Single Message",
  255. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  256. msgs: []api.Message{
  257. {
  258. Role: "system",
  259. Content: "You are a Wizard.",
  260. },
  261. {
  262. Role: "user",
  263. Content: "What are the potion ingredients?",
  264. },
  265. },
  266. want: ChatHistory{
  267. Prompts: []PromptVars{
  268. {
  269. System: "You are a Wizard.",
  270. Prompt: "What are the potion ingredients?",
  271. First: true,
  272. },
  273. },
  274. LastSystem: "You are a Wizard.",
  275. },
  276. },
  277. {
  278. name: "Message History",
  279. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  280. msgs: []api.Message{
  281. {
  282. Role: "system",
  283. Content: "You are a Wizard.",
  284. },
  285. {
  286. Role: "user",
  287. Content: "What are the potion ingredients?",
  288. },
  289. {
  290. Role: "assistant",
  291. Content: "sugar",
  292. },
  293. {
  294. Role: "user",
  295. Content: "Anything else?",
  296. },
  297. },
  298. want: ChatHistory{
  299. Prompts: []PromptVars{
  300. {
  301. System: "You are a Wizard.",
  302. Prompt: "What are the potion ingredients?",
  303. Response: "sugar",
  304. First: true,
  305. },
  306. {
  307. Prompt: "Anything else?",
  308. },
  309. },
  310. LastSystem: "You are a Wizard.",
  311. },
  312. },
  313. {
  314. name: "Assistant Only",
  315. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  316. msgs: []api.Message{
  317. {
  318. Role: "assistant",
  319. Content: "everything nice",
  320. },
  321. },
  322. want: ChatHistory{
  323. Prompts: []PromptVars{
  324. {
  325. Response: "everything nice",
  326. First: true,
  327. },
  328. },
  329. },
  330. },
  331. {
  332. name: "Invalid Role",
  333. msgs: []api.Message{
  334. {
  335. Role: "not-a-role",
  336. Content: "howdy",
  337. },
  338. },
  339. wantErr: "invalid role: not-a-role",
  340. },
  341. }
  342. for _, tt := range tests {
  343. m := Model{
  344. Template: tt.template,
  345. }
  346. t.Run(tt.name, func(t *testing.T) {
  347. got, err := m.ChatPrompts(tt.msgs)
  348. if tt.wantErr != "" {
  349. if err == nil {
  350. t.Errorf("ChatPrompt() expected error, got nil")
  351. }
  352. if !strings.Contains(err.Error(), tt.wantErr) {
  353. t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
  354. }
  355. return
  356. }
  357. if !chatHistoryEqual(*got, tt.want) {
  358. t.Errorf("ChatPrompt() got = %#v, want %#v", got, tt.want)
  359. }
  360. })
  361. }
  362. }