cache_test.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. package ollamarunner
  2. import (
  3. "image"
  4. "testing"
  5. "time"
  6. "github.com/ollama/ollama/model/input"
  7. )
  8. func TestCountCommon(t *testing.T) {
  9. imgA := image.NewRGBA(image.Rect(0, 0, 100, 100))
  10. imgB := image.NewRGBA(image.Rect(0, 0, 50, 50))
  11. imgC := image.NewRGBA(image.Rect(50, 50, 100, 100))
  12. tests := []struct {
  13. name string
  14. t1 []input.Input
  15. t2 []input.Input
  16. expected int32
  17. }{
  18. {
  19. name: "Equal",
  20. t1: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
  21. t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
  22. expected: 3,
  23. },
  24. {
  25. name: "Prefix",
  26. t1: []input.Input{{Token: 1}},
  27. t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
  28. expected: 1,
  29. },
  30. {
  31. name: "Image Prefix",
  32. t1: []input.Input{{Multimodal: imgA, MultimodalHash: 1}},
  33. t2: []input.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
  34. expected: 1,
  35. },
  36. {
  37. name: "Mixed",
  38. t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
  39. t2: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
  40. expected: 2,
  41. },
  42. {
  43. name: "Mixed, Same Length",
  44. t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
  45. t2: []input.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
  46. expected: 1,
  47. },
  48. {
  49. name: "Empty",
  50. t1: []input.Input{},
  51. t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
  52. expected: 0,
  53. },
  54. {
  55. name: "Both Empty",
  56. t1: []input.Input{},
  57. t2: []input.Input{},
  58. expected: 0,
  59. },
  60. }
  61. for _, tt := range tests {
  62. t.Run(tt.name, func(t *testing.T) {
  63. result := countCommonPrefix(tt.t1, tt.t2)
  64. if result != tt.expected {
  65. t.Errorf("countCommonPrefix(%v, %v): have %v; want %v", tt.t1, tt.t2, result, tt.expected)
  66. }
  67. })
  68. }
  69. }
  70. func TestFindCacheSlot(t *testing.T) {
  71. type expected struct {
  72. result int
  73. len int32
  74. }
  75. tests := []struct {
  76. name string
  77. cache InputCache
  78. prompt []input.Input
  79. longest expected
  80. best expected
  81. }{
  82. {
  83. name: "Empty",
  84. cache: InputCache{slots: []InputCacheSlot{
  85. {
  86. Id: 0,
  87. Inputs: []input.Input{},
  88. InUse: false,
  89. lastUsed: time.Time{},
  90. },
  91. {
  92. Id: 1,
  93. Inputs: []input.Input{},
  94. InUse: false,
  95. lastUsed: time.Time{},
  96. },
  97. }},
  98. prompt: []input.Input{{Token: 1}},
  99. longest: expected{result: 0, len: 0},
  100. best: expected{result: 0, len: 0},
  101. },
  102. {
  103. name: "Extend",
  104. cache: InputCache{slots: []InputCacheSlot{
  105. {
  106. Id: 0,
  107. Inputs: []input.Input{{Token: 1}},
  108. InUse: false,
  109. lastUsed: time.Now().Add(-time.Second),
  110. },
  111. {
  112. Id: 1,
  113. Inputs: []input.Input{{Token: 1}, {Token: 2}},
  114. InUse: false,
  115. lastUsed: time.Now().Add(-2 * time.Second),
  116. },
  117. }},
  118. prompt: []input.Input{{Token: 1}, {Token: 2}},
  119. longest: expected{result: 1, len: 2},
  120. best: expected{result: 1, len: 2},
  121. },
  122. {
  123. name: "New",
  124. cache: InputCache{slots: []InputCacheSlot{
  125. {
  126. Id: 0,
  127. Inputs: []input.Input{{Token: 1}, {Token: 2}},
  128. InUse: false,
  129. lastUsed: time.Now().Add(-time.Second),
  130. },
  131. {
  132. Id: 1,
  133. Inputs: []input.Input{},
  134. InUse: false,
  135. lastUsed: time.Time{},
  136. },
  137. }},
  138. prompt: []input.Input{{Token: 2}},
  139. longest: expected{result: 0, len: 0},
  140. best: expected{result: 1, len: 0},
  141. },
  142. {
  143. name: "Fork",
  144. cache: InputCache{
  145. slots: []InputCacheSlot{
  146. {
  147. Id: 0,
  148. Inputs: []input.Input{{Token: 1}, {Token: 2}},
  149. InUse: false,
  150. lastUsed: time.Now().Add(-time.Second),
  151. },
  152. {
  153. Id: 1,
  154. Inputs: []input.Input{},
  155. InUse: false,
  156. lastUsed: time.Time{},
  157. },
  158. },
  159. },
  160. prompt: []input.Input{{Token: 1}},
  161. longest: expected{result: 0, len: 1},
  162. best: expected{result: 1, len: 1},
  163. },
  164. {
  165. name: "Evict",
  166. cache: InputCache{slots: []InputCacheSlot{
  167. {
  168. Id: 0,
  169. Inputs: []input.Input{{Token: 1}},
  170. InUse: false,
  171. lastUsed: time.Now().Add(-time.Second),
  172. },
  173. {
  174. Id: 1,
  175. Inputs: []input.Input{{Token: 1}, {Token: 2}},
  176. InUse: false,
  177. lastUsed: time.Now().Add(-2 * time.Second),
  178. },
  179. }},
  180. prompt: []input.Input{{Token: 2}, {Token: 3}},
  181. longest: expected{result: 0, len: 0},
  182. best: expected{result: 1, len: 0},
  183. },
  184. {
  185. name: "In use",
  186. cache: InputCache{slots: []InputCacheSlot{
  187. {
  188. Id: 0,
  189. Inputs: []input.Input{{Token: 1}, {Token: 2}},
  190. InUse: true,
  191. lastUsed: time.Now().Add(-time.Second),
  192. },
  193. {
  194. Id: 1,
  195. Inputs: []input.Input{{Token: 1}},
  196. InUse: false,
  197. lastUsed: time.Now().Add(-2 * time.Second),
  198. },
  199. }},
  200. prompt: []input.Input{{Token: 1}, {Token: 2}},
  201. longest: expected{result: 1, len: 1},
  202. best: expected{result: 1, len: 2},
  203. },
  204. }
  205. for _, tt := range tests {
  206. t.Run("Longest-"+tt.name, func(t *testing.T) {
  207. result, resultLen, err := tt.cache.findLongestCacheSlot(tt.prompt)
  208. if err != nil {
  209. t.Errorf("findLongestCacheSlot: err %v", err)
  210. } else if result.Id != tt.longest.result || resultLen != tt.longest.len {
  211. t.Errorf("findLongestCacheSlot: slot have %v, want %v len have %v, want %v",
  212. result.Id, tt.longest.result, resultLen, tt.longest.len)
  213. }
  214. })
  215. }
  216. for _, tt := range tests {
  217. t.Run("Best-"+tt.name, func(t *testing.T) {
  218. result, resultLen, err := tt.cache.findBestCacheSlot(tt.prompt)
  219. if err != nil {
  220. t.Errorf("findBestCacheSlot: err %v", err)
  221. } else if result.Id != tt.best.result || resultLen != tt.best.len {
  222. t.Errorf("findBestCacheSlot: slot have %v, want %v len have %v, want %v",
  223. result.Id, tt.best.result, resultLen, tt.best.len)
  224. }
  225. })
  226. }
  227. }
  228. func TestShiftDiscard(t *testing.T) {
  229. tests := []struct {
  230. name string
  231. numCtx int32
  232. numKeep int32
  233. inputLen int32
  234. expected int32
  235. }{
  236. {
  237. name: "Shift",
  238. numCtx: 2048,
  239. numKeep: 5,
  240. inputLen: 2048,
  241. expected: 1021,
  242. },
  243. {
  244. name: "Max Keep",
  245. numCtx: 2048,
  246. numKeep: 2047,
  247. inputLen: 2048,
  248. expected: 1,
  249. },
  250. {
  251. name: "No Keep",
  252. numCtx: 2048,
  253. numKeep: 0,
  254. inputLen: 2048,
  255. expected: 1024,
  256. },
  257. {
  258. name: "Truncate",
  259. numCtx: 2048,
  260. numKeep: 5,
  261. inputLen: 5000,
  262. expected: 3973,
  263. },
  264. {
  265. name: "Truncate Keep",
  266. numCtx: 2048,
  267. numKeep: 2047,
  268. inputLen: 5000,
  269. expected: 2953,
  270. },
  271. {
  272. name: "No Op",
  273. numCtx: 2048,
  274. numKeep: 5,
  275. inputLen: 512,
  276. expected: 0,
  277. },
  278. }
  279. for _, tt := range tests {
  280. t.Run(tt.name, func(t *testing.T) {
  281. c := InputCache{numCtx: tt.numCtx}
  282. result := c.ShiftDiscard(tt.inputLen, tt.numKeep)
  283. if result != tt.expected {
  284. t.Errorf("shiftDiscard(ctx: %v, keep: %v input: %v): have %v; want %v", tt.numCtx, tt.numKeep, tt.inputLen, result, tt.expected)
  285. }
  286. })
  287. }
  288. }
  289. func TestLoadCacheSlot(t *testing.T) {
  290. tests := []struct {
  291. name string
  292. cache InputCache
  293. prompt []input.Input
  294. wantErr bool
  295. expectedSlotId int
  296. expectedPrompt int // expected length of remaining prompt
  297. }{
  298. {
  299. name: "Basic cache hit - single user",
  300. cache: InputCache{
  301. multiUserCache: false,
  302. slots: []InputCacheSlot{
  303. {
  304. Id: 0,
  305. Inputs: []input.Input{{Token: 1}, {Token: 2}},
  306. InUse: false,
  307. lastUsed: time.Now().Add(-time.Second),
  308. },
  309. {
  310. Id: 1,
  311. Inputs: []input.Input{},
  312. InUse: false,
  313. lastUsed: time.Now().Add(-2 * time.Second),
  314. },
  315. },
  316. },
  317. prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
  318. wantErr: false,
  319. expectedSlotId: 0,
  320. expectedPrompt: 1, // Only token 3 remains
  321. },
  322. {
  323. name: "Basic cache hit - multi user",
  324. cache: InputCache{
  325. multiUserCache: true,
  326. slots: []InputCacheSlot{
  327. {
  328. Id: 0,
  329. Inputs: []input.Input{{Token: 1}, {Token: 2}},
  330. InUse: false,
  331. lastUsed: time.Now().Add(-time.Second),
  332. },
  333. {
  334. Id: 1,
  335. Inputs: []input.Input{},
  336. InUse: false,
  337. lastUsed: time.Now().Add(-2 * time.Second),
  338. },
  339. },
  340. },
  341. prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
  342. wantErr: false,
  343. expectedSlotId: 0,
  344. expectedPrompt: 1, // Only token 3 remains
  345. },
  346. {
  347. name: "Exact match - leave one input",
  348. cache: InputCache{
  349. multiUserCache: false,
  350. slots: []InputCacheSlot{
  351. {
  352. Id: 0,
  353. Inputs: []input.Input{{Token: 1}, {Token: 2}},
  354. InUse: false,
  355. lastUsed: time.Now().Add(-time.Second),
  356. },
  357. },
  358. },
  359. prompt: []input.Input{{Token: 1}, {Token: 2}},
  360. wantErr: false,
  361. expectedSlotId: 0,
  362. expectedPrompt: 1, // Should leave 1 token for sampling
  363. },
  364. {
  365. name: "No available slots",
  366. cache: InputCache{
  367. multiUserCache: false,
  368. slots: []InputCacheSlot{
  369. {
  370. Id: 0,
  371. Inputs: []input.Input{{Token: 1}, {Token: 2}},
  372. InUse: true,
  373. lastUsed: time.Now().Add(-time.Second),
  374. },
  375. },
  376. },
  377. prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
  378. wantErr: true,
  379. expectedSlotId: -1,
  380. expectedPrompt: -1,
  381. },
  382. }
  383. for _, tt := range tests {
  384. t.Run(tt.name, func(t *testing.T) {
  385. slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
  386. // Check error state
  387. if (err != nil) != tt.wantErr {
  388. t.Errorf("LoadCacheSlot() error = %v, wantErr %v", err, tt.wantErr)
  389. return
  390. }
  391. if tt.wantErr {
  392. return // Skip further checks if we expected an error
  393. }
  394. // Verify slot ID
  395. if slot.Id != tt.expectedSlotId {
  396. t.Errorf("LoadCacheSlot() slot ID = %v, expected %v", slot.Id, tt.expectedSlotId)
  397. }
  398. // Verify slot is now marked in use
  399. if !slot.InUse {
  400. t.Errorf("LoadCacheSlot() slot not marked InUse")
  401. }
  402. // Verify remaining prompt length
  403. if len(remainingPrompt) != tt.expectedPrompt {
  404. t.Errorf("LoadCacheSlot() remaining prompt length = %v, expected %v",
  405. len(remainingPrompt), tt.expectedPrompt)
  406. }
  407. })
  408. }
  409. }