prompt_test.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. package server
  2. import (
  3. "bytes"
  4. "context"
  5. "image"
  6. "image/png"
  7. "testing"
  8. "github.com/google/go-cmp/cmp"
  9. "github.com/ollama/ollama/api"
  10. "github.com/ollama/ollama/template"
  11. )
  12. func TestChatPrompt(t *testing.T) {
  13. type expect struct {
  14. prompt string
  15. images [][]byte
  16. aspectRatioID int
  17. error error
  18. }
  19. tmpl, err := template.Parse(`
  20. {{- if .System }}{{ .System }} {{ end }}
  21. {{- if .Prompt }}{{ .Prompt }} {{ end }}
  22. {{- if .Response }}{{ .Response }} {{ end }}`)
  23. if err != nil {
  24. t.Fatal(err)
  25. }
  26. visionModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
  27. mllamaModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}, Config: ConfigV2{ModelFamilies: []string{"mllama"}}}
  28. createImg := func(width, height int) ([]byte, error) {
  29. img := image.NewRGBA(image.Rect(0, 0, 5, 5))
  30. var buf bytes.Buffer
  31. if err := png.Encode(&buf, img); err != nil {
  32. return nil, err
  33. }
  34. return buf.Bytes(), nil
  35. }
  36. imgBuf, err := createImg(5, 5)
  37. if err != nil {
  38. t.Fatal(err)
  39. }
  40. imgBuf2, err := createImg(6, 6)
  41. if err != nil {
  42. t.Fatal(err)
  43. }
  44. cases := []struct {
  45. name string
  46. model Model
  47. limit int
  48. msgs []api.Message
  49. expect
  50. }{
  51. {
  52. name: "messages",
  53. model: visionModel,
  54. limit: 64,
  55. msgs: []api.Message{
  56. {Role: "user", Content: "You're a test, Harry!"},
  57. {Role: "assistant", Content: "I-I'm a what?"},
  58. {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
  59. },
  60. expect: expect{
  61. prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
  62. },
  63. },
  64. {
  65. name: "truncate messages",
  66. model: visionModel,
  67. limit: 1,
  68. msgs: []api.Message{
  69. {Role: "user", Content: "You're a test, Harry!"},
  70. {Role: "assistant", Content: "I-I'm a what?"},
  71. {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
  72. },
  73. expect: expect{
  74. prompt: "A test. And a thumping good one at that, I'd wager. ",
  75. },
  76. },
  77. {
  78. name: "truncate messages with image",
  79. model: visionModel,
  80. limit: 64,
  81. msgs: []api.Message{
  82. {Role: "user", Content: "You're a test, Harry!"},
  83. {Role: "assistant", Content: "I-I'm a what?"},
  84. {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("something")}},
  85. },
  86. expect: expect{
  87. prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
  88. images: [][]byte{
  89. []byte("something"),
  90. },
  91. },
  92. },
  93. {
  94. name: "truncate messages with images",
  95. model: visionModel,
  96. limit: 64,
  97. msgs: []api.Message{
  98. {Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
  99. {Role: "assistant", Content: "I-I'm a what?"},
  100. {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
  101. },
  102. expect: expect{
  103. prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
  104. images: [][]byte{
  105. []byte("somethingelse"),
  106. },
  107. },
  108. },
  109. {
  110. name: "messages with images",
  111. model: visionModel,
  112. limit: 2048,
  113. msgs: []api.Message{
  114. {Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
  115. {Role: "assistant", Content: "I-I'm a what?"},
  116. {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
  117. },
  118. expect: expect{
  119. prompt: "[img-0] You're a test, Harry! I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
  120. images: [][]byte{
  121. []byte("something"),
  122. []byte("somethingelse"),
  123. },
  124. },
  125. },
  126. {
  127. name: "message with image tag",
  128. model: visionModel,
  129. limit: 2048,
  130. msgs: []api.Message{
  131. {Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
  132. {Role: "assistant", Content: "I-I'm a what?"},
  133. {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
  134. },
  135. expect: expect{
  136. prompt: "You're a test, Harry! [img-0] I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
  137. images: [][]byte{
  138. []byte("something"),
  139. []byte("somethingelse"),
  140. },
  141. },
  142. },
  143. {
  144. name: "messages with interleaved images",
  145. model: visionModel,
  146. limit: 2048,
  147. msgs: []api.Message{
  148. {Role: "user", Content: "You're a test, Harry!"},
  149. {Role: "user", Images: []api.ImageData{[]byte("something")}},
  150. {Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
  151. {Role: "assistant", Content: "I-I'm a what?"},
  152. {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
  153. },
  154. expect: expect{
  155. prompt: "You're a test, Harry!\n\n[img-0]\n\n[img-1] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
  156. images: [][]byte{
  157. []byte("something"),
  158. []byte("somethingelse"),
  159. },
  160. },
  161. },
  162. {
  163. name: "truncate message with interleaved images",
  164. model: visionModel,
  165. limit: 1024,
  166. msgs: []api.Message{
  167. {Role: "user", Content: "You're a test, Harry!"},
  168. {Role: "user", Images: []api.ImageData{[]byte("something")}},
  169. {Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
  170. {Role: "assistant", Content: "I-I'm a what?"},
  171. {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
  172. },
  173. expect: expect{
  174. prompt: "[img-0] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
  175. images: [][]byte{
  176. []byte("somethingelse"),
  177. },
  178. },
  179. },
  180. {
  181. name: "message with system prompt",
  182. model: visionModel,
  183. limit: 2048,
  184. msgs: []api.Message{
  185. {Role: "system", Content: "You are the Test Who Lived."},
  186. {Role: "user", Content: "You're a test, Harry!"},
  187. {Role: "assistant", Content: "I-I'm a what?"},
  188. {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
  189. },
  190. expect: expect{
  191. prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
  192. },
  193. },
  194. {
  195. name: "out of order system",
  196. model: visionModel,
  197. limit: 2048,
  198. msgs: []api.Message{
  199. {Role: "user", Content: "You're a test, Harry!"},
  200. {Role: "assistant", Content: "I-I'm a what?"},
  201. {Role: "system", Content: "You are the Test Who Lived."},
  202. {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
  203. },
  204. expect: expect{
  205. prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
  206. },
  207. },
  208. {
  209. name: "messages with mllama (no images)",
  210. model: mllamaModel,
  211. limit: 2048,
  212. msgs: []api.Message{
  213. {Role: "user", Content: "You're a test, Harry!"},
  214. {Role: "assistant", Content: "I-I'm a what?"},
  215. {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
  216. },
  217. expect: expect{
  218. prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
  219. },
  220. },
  221. {
  222. name: "messages with mllama single prompt",
  223. model: mllamaModel,
  224. limit: 2048,
  225. msgs: []api.Message{
  226. {Role: "user", Content: "How many hotdogs are in this image?", Images: []api.ImageData{imgBuf}},
  227. },
  228. expect: expect{
  229. prompt: "<|image|>How many hotdogs are in this image? ",
  230. images: [][]byte{imgBuf},
  231. aspectRatioID: 1,
  232. },
  233. },
  234. {
  235. name: "messages with mllama",
  236. model: mllamaModel,
  237. limit: 2048,
  238. msgs: []api.Message{
  239. {Role: "user", Content: "You're a test, Harry!"},
  240. {Role: "assistant", Content: "I-I'm a what?"},
  241. {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf}},
  242. },
  243. expect: expect{
  244. prompt: "You're a test, Harry! I-I'm a what? <|image|>A test. And a thumping good one at that, I'd wager. ",
  245. images: [][]byte{imgBuf},
  246. aspectRatioID: 1,
  247. },
  248. },
  249. {
  250. name: "multiple messages with mllama",
  251. model: mllamaModel,
  252. limit: 2048,
  253. msgs: []api.Message{
  254. {Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{imgBuf}},
  255. {Role: "assistant", Content: "I-I'm a what?"},
  256. {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf2}},
  257. },
  258. expect: expect{
  259. prompt: "You're a test, Harry! I-I'm a what? <|image|>A test. And a thumping good one at that, I'd wager. ",
  260. images: [][]byte{imgBuf2},
  261. aspectRatioID: 1,
  262. },
  263. },
  264. {
  265. name: "earlier image with mllama",
  266. model: mllamaModel,
  267. limit: 2048,
  268. msgs: []api.Message{
  269. {Role: "user", Content: "How many hotdogs are in this image?", Images: []api.ImageData{imgBuf}},
  270. {Role: "assistant", Content: "There are four hotdogs."},
  271. {Role: "user", Content: "Which ones have mustard?"},
  272. },
  273. expect: expect{
  274. prompt: "<|image|>How many hotdogs are in this image? There are four hotdogs. Which ones have mustard? ",
  275. images: [][]byte{imgBuf},
  276. aspectRatioID: 1,
  277. },
  278. },
  279. {
  280. name: "too many images with mllama",
  281. model: mllamaModel,
  282. limit: 2048,
  283. msgs: []api.Message{
  284. {Role: "user", Content: "You're a test, Harry!"},
  285. {Role: "assistant", Content: "I-I'm a what?"},
  286. {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf, imgBuf}},
  287. },
  288. expect: expect{
  289. error: errTooManyImages,
  290. },
  291. },
  292. }
  293. for _, tt := range cases {
  294. t.Run(tt.name, func(t *testing.T) {
  295. model := tt.model
  296. opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
  297. prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
  298. if tt.error == nil && err != nil {
  299. t.Fatal(err)
  300. } else if tt.error != nil && err != tt.error {
  301. t.Fatalf("expected err '%q', got '%q'", tt.error, err)
  302. }
  303. if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
  304. t.Errorf("mismatch (-got +want):\n%s", diff)
  305. }
  306. if len(images) != len(tt.images) {
  307. t.Fatalf("expected %d images, got %d", len(tt.images), len(images))
  308. }
  309. for i := range images {
  310. if images[i].ID != i {
  311. t.Errorf("expected ID %d, got %d", i, images[i].ID)
  312. }
  313. if len(model.Config.ModelFamilies) == 0 {
  314. if !bytes.Equal(images[i].Data, tt.images[i]) {
  315. t.Errorf("expected %q, got %q", tt.images[i], images[i].Data)
  316. }
  317. } else {
  318. if images[i].AspectRatioID != tt.aspectRatioID {
  319. t.Errorf("expected aspect ratio %d, got %d", tt.aspectRatioID, images[i].AspectRatioID)
  320. }
  321. }
  322. }
  323. })
  324. }
  325. }