cache_test.go 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. package ollamarunner
  2. import (
  3. "image"
  4. "testing"
  5. "time"
  6. "github.com/ollama/ollama/model/input"
  7. )
  8. func TestCountCommon(t *testing.T) {
  9. imgA := image.NewRGBA(image.Rect(0, 0, 100, 100))
  10. imgB := image.NewRGBA(image.Rect(0, 0, 50, 50))
  11. imgC := image.NewRGBA(image.Rect(50, 50, 100, 100))
  12. tests := []struct {
  13. name string
  14. t1 []input.Input
  15. t2 []input.Input
  16. expected int32
  17. }{
  18. {
  19. name: "Equal",
  20. t1: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
  21. t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
  22. expected: 3,
  23. },
  24. {
  25. name: "Prefix",
  26. t1: []input.Input{{Token: 1}},
  27. t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
  28. expected: 1,
  29. },
  30. {
  31. name: "Image Prefix",
  32. t1: []input.Input{{Multimodal: imgA, MultimodalHash: 1}},
  33. t2: []input.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
  34. expected: 1,
  35. },
  36. {
  37. name: "Mixed",
  38. t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
  39. t2: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
  40. expected: 2,
  41. },
  42. {
  43. name: "Mixed, Same Length",
  44. t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
  45. t2: []input.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
  46. expected: 1,
  47. },
  48. {
  49. name: "Empty",
  50. t1: []input.Input{},
  51. t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
  52. expected: 0,
  53. },
  54. {
  55. name: "Both Empty",
  56. t1: []input.Input{},
  57. t2: []input.Input{},
  58. expected: 0,
  59. },
  60. }
  61. for _, tt := range tests {
  62. t.Run(tt.name, func(t *testing.T) {
  63. result := countCommonPrefix(tt.t1, tt.t2)
  64. if result != tt.expected {
  65. t.Errorf("countCommonPrefix(%v, %v): have %v; want %v", tt.t1, tt.t2, result, tt.expected)
  66. }
  67. })
  68. }
  69. }
  70. func TestFindCacheSlot(t *testing.T) {
  71. type expected struct {
  72. result int
  73. len int32
  74. }
  75. tests := []struct {
  76. name string
  77. cache InputCache
  78. prompt []input.Input
  79. longest expected
  80. best expected
  81. }{
  82. {
  83. name: "Empty",
  84. cache: InputCache{slots: []InputCacheSlot{
  85. {
  86. Id: 0,
  87. Inputs: []input.Input{},
  88. InUse: false,
  89. lastUsed: time.Time{},
  90. },
  91. {
  92. Id: 1,
  93. Inputs: []input.Input{},
  94. InUse: false,
  95. lastUsed: time.Time{},
  96. },
  97. }},
  98. prompt: []input.Input{{Token: 1}},
  99. longest: expected{result: 0, len: 0},
  100. best: expected{result: 0, len: 0},
  101. },
  102. {
  103. name: "Extend",
  104. cache: InputCache{slots: []InputCacheSlot{
  105. {
  106. Id: 0,
  107. Inputs: []input.Input{{Token: 1}},
  108. InUse: false,
  109. lastUsed: time.Now().Add(-time.Second),
  110. },
  111. {
  112. Id: 1,
  113. Inputs: []input.Input{{Token: 1}, {Token: 2}},
  114. InUse: false,
  115. lastUsed: time.Now().Add(-2 * time.Second),
  116. },
  117. }},
  118. prompt: []input.Input{{Token: 1}, {Token: 2}},
  119. longest: expected{result: 1, len: 2},
  120. best: expected{result: 1, len: 2},
  121. },
  122. {
  123. name: "New",
  124. cache: InputCache{slots: []InputCacheSlot{
  125. {
  126. Id: 0,
  127. Inputs: []input.Input{{Token: 1}, {Token: 2}},
  128. InUse: false,
  129. lastUsed: time.Now().Add(-time.Second),
  130. },
  131. {
  132. Id: 1,
  133. Inputs: []input.Input{},
  134. InUse: false,
  135. lastUsed: time.Time{},
  136. },
  137. }},
  138. prompt: []input.Input{{Token: 2}},
  139. longest: expected{result: 0, len: 0},
  140. best: expected{result: 1, len: 0},
  141. },
  142. {
  143. name: "Fork",
  144. cache: InputCache{
  145. slots: []InputCacheSlot{
  146. {
  147. Id: 0,
  148. Inputs: []input.Input{{Token: 1}, {Token: 2}},
  149. InUse: false,
  150. lastUsed: time.Now().Add(-time.Second),
  151. },
  152. {
  153. Id: 1,
  154. Inputs: []input.Input{},
  155. InUse: false,
  156. lastUsed: time.Time{},
  157. },
  158. },
  159. },
  160. prompt: []input.Input{{Token: 1}},
  161. longest: expected{result: 0, len: 1},
  162. best: expected{result: 1, len: 1},
  163. },
  164. {
  165. name: "Evict",
  166. cache: InputCache{slots: []InputCacheSlot{
  167. {
  168. Id: 0,
  169. Inputs: []input.Input{{Token: 1}},
  170. InUse: false,
  171. lastUsed: time.Now().Add(-time.Second),
  172. },
  173. {
  174. Id: 1,
  175. Inputs: []input.Input{{Token: 1}, {Token: 2}},
  176. InUse: false,
  177. lastUsed: time.Now().Add(-2 * time.Second),
  178. },
  179. }},
  180. prompt: []input.Input{{Token: 2}, {Token: 3}},
  181. longest: expected{result: 0, len: 0},
  182. best: expected{result: 1, len: 0},
  183. },
  184. {
  185. name: "In use",
  186. cache: InputCache{slots: []InputCacheSlot{
  187. {
  188. Id: 0,
  189. Inputs: []input.Input{{Token: 1}, {Token: 2}},
  190. InUse: true,
  191. lastUsed: time.Now().Add(-time.Second),
  192. },
  193. {
  194. Id: 1,
  195. Inputs: []input.Input{{Token: 1}},
  196. InUse: false,
  197. lastUsed: time.Now().Add(-2 * time.Second),
  198. },
  199. }},
  200. prompt: []input.Input{{Token: 1}, {Token: 2}},
  201. longest: expected{result: 1, len: 1},
  202. best: expected{result: 1, len: 2},
  203. },
  204. }
  205. for _, tt := range tests {
  206. t.Run("Longest-"+tt.name, func(t *testing.T) {
  207. result, resultLen, err := tt.cache.findLongestCacheSlot(tt.prompt)
  208. if err != nil {
  209. t.Errorf("findLongestCacheSlot: err %v", err)
  210. } else if result.Id != tt.longest.result || resultLen != tt.longest.len {
  211. t.Errorf("findLongestCacheSlot: slot have %v, want %v len have %v, want %v",
  212. result.Id, tt.longest.result, resultLen, tt.longest.len)
  213. }
  214. })
  215. }
  216. for _, tt := range tests {
  217. t.Run("Best-"+tt.name, func(t *testing.T) {
  218. result, resultLen, err := tt.cache.findBestCacheSlot(tt.prompt)
  219. if err != nil {
  220. t.Errorf("findBestCacheSlot: err %v", err)
  221. } else if result.Id != tt.best.result || resultLen != tt.best.len {
  222. t.Errorf("findBestCacheSlot: slot have %v, want %v len have %v, want %v",
  223. result.Id, tt.best.result, resultLen, tt.best.len)
  224. }
  225. })
  226. }
  227. }
  228. func TestShiftDiscard(t *testing.T) {
  229. tests := []struct {
  230. name string
  231. numCtx int32
  232. numKeep int32
  233. inputLen int32
  234. expected int32
  235. }{
  236. {
  237. name: "Shift",
  238. numCtx: 2048,
  239. numKeep: 5,
  240. inputLen: 2048,
  241. expected: 1021,
  242. },
  243. {
  244. name: "Max Keep",
  245. numCtx: 2048,
  246. numKeep: 2047,
  247. inputLen: 2048,
  248. expected: 1,
  249. },
  250. {
  251. name: "No Keep",
  252. numCtx: 2048,
  253. numKeep: 0,
  254. inputLen: 2048,
  255. expected: 1024,
  256. },
  257. {
  258. name: "Truncate",
  259. numCtx: 2048,
  260. numKeep: 5,
  261. inputLen: 5000,
  262. expected: 3973,
  263. },
  264. {
  265. name: "Truncate Keep",
  266. numCtx: 2048,
  267. numKeep: 2047,
  268. inputLen: 5000,
  269. expected: 2953,
  270. },
  271. {
  272. name: "No Op",
  273. numCtx: 2048,
  274. numKeep: 5,
  275. inputLen: 512,
  276. expected: 0,
  277. },
  278. }
  279. for _, tt := range tests {
  280. t.Run(tt.name, func(t *testing.T) {
  281. c := InputCache{numCtx: tt.numCtx}
  282. result := c.ShiftDiscard(tt.inputLen, tt.numKeep)
  283. if result != tt.expected {
  284. t.Errorf("shiftDiscard(ctx: %v, keep: %v input: %v): have %v; want %v", tt.numCtx, tt.numKeep, tt.inputLen, result, tt.expected)
  285. }
  286. })
  287. }
  288. }