causal_test.go 17 KB


  1. package kvcache
  2. import (
  3. "math"
  4. "slices"
  5. "testing"
  6. "github.com/ollama/ollama/ml"
  7. )
  // testCase describes one forward step against a cache: the batch that is
// stored, and the tensor and attention mask the cache is expected to return.
type testCase struct {
	// name labels the subtest run by testCache.
	name string

	// in is the flattened input tensor; inShape gives its dimensions.
	in      []float32
	inShape []int

	// seqs and pos give the sequence ID and position of each batch entry.
	seqs []int
	pos  []int32

	// expected and expectedShape describe the tensor Get returns after the
	// batch is stored; expectedMask is the flattened mask Get returns.
	expected      []float32
	expectedShape []int
	expectedMask  []float32
}
  18. func TestStore(t *testing.T) {
  19. backend := &testBackend{}
  20. cache := NewCausalCache(nil)
  21. defer cache.Close()
  22. cache.Init(backend, ml.DTypeF16, 16)
  23. tests := []testCase{
  24. {
  25. name: "FirstBatch",
  26. in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
  27. inShape: []int{2, 3, 4},
  28. seqs: []int{0, 0, 0, 0},
  29. pos: []int32{0, 1, 2, 3},
  30. expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
  31. expectedShape: []int{2, 3, 4},
  32. expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
  33. },
  34. {
  35. name: "SecondBatch",
  36. in: []float32{115, 215, 125, 225, 135, 235},
  37. inShape: []int{2, 3, 1},
  38. seqs: []int{0},
  39. pos: []int32{4},
  40. expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
  41. expectedShape: []int{2, 3, 5},
  42. expectedMask: []float32{0, 0, 0, 0, 0},
  43. },
  44. }
  45. testCache(t, backend, cache, tests)
  46. }
  47. func TestSWA(t *testing.T) {
  48. backend := &testBackend{}
  49. cache := NewSWACache(1, nil)
  50. defer cache.Close()
  51. cache.Init(backend, ml.DTypeF32, 16)
  52. tests := []testCase{
  53. {
  54. name: "SlidingWindow",
  55. in: []float32{1, 2, 3, 4},
  56. inShape: []int{1, 1, 4},
  57. seqs: []int{0, 0, 0, 0},
  58. pos: []int32{0, 1, 2, 3},
  59. expected: []float32{1, 2, 3, 4},
  60. expectedShape: []int{1, 1, 4},
  61. expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
  62. },
  63. }
  64. testCache(t, backend, cache, tests)
  65. }
  66. func TestSequences(t *testing.T) {
  67. backend := &testBackend{}
  68. cache := NewCausalCache(nil)
  69. defer cache.Close()
  70. cache.Init(backend, ml.DTypeF16, 16)
  71. tests := []testCase{
  72. {
  73. name: "FirstBatch",
  74. in: []float32{1, 2, 3, 4},
  75. inShape: []int{1, 1, 4},
  76. seqs: []int{0, 0, 1, 1},
  77. pos: []int32{0, 1, 0, 1},
  78. expected: []float32{1, 2, 3, 4},
  79. expectedShape: []int{1, 1, 4},
  80. expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
  81. },
  82. {
  83. name: "SecondBatch",
  84. in: []float32{5, 6},
  85. inShape: []int{1, 1, 2},
  86. seqs: []int{0, 1},
  87. pos: []int32{2, 2},
  88. expected: []float32{1, 2, 3, 4, 5, 6},
  89. expectedShape: []int{1, 1, 6},
  90. expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
  91. },
  92. }
  93. testCache(t, backend, cache, tests)
  94. }
  95. func TestRemove(t *testing.T) {
  96. backend := &testBackend{}
  97. cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
  98. return key.Add(ctx, shift), nil
  99. })
  100. defer cache.Close()
  101. cache.Init(backend, ml.DTypeF16, 16)
  102. tests := []testCase{
  103. {
  104. name: "FirstBatch",
  105. in: []float32{1, 2, 3, 4},
  106. inShape: []int{1, 1, 4},
  107. seqs: []int{0, 0, 1, 1},
  108. pos: []int32{0, 1, 0, 1},
  109. expected: []float32{1, 2, 3, 4},
  110. expectedShape: []int{1, 1, 4},
  111. expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
  112. },
  113. }
  114. testCache(t, backend, cache, tests)
  115. err := cache.Remove(0, 1, math.MaxInt32)
  116. if err != nil {
  117. panic(err)
  118. }
  119. tests = []testCase{
  120. {
  121. name: "RemoveEnd",
  122. in: []float32{5, 6},
  123. inShape: []int{1, 1, 2},
  124. seqs: []int{0, 1},
  125. pos: []int32{1, 2},
  126. expected: []float32{1, 2, 3, 4, 5, 6},
  127. expectedShape: []int{1, 1, 6},
  128. expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
  129. },
  130. }
  131. testCache(t, backend, cache, tests)
  132. err = cache.Remove(0, 0, 1)
  133. if err != nil {
  134. panic(err)
  135. }
  136. tests = []testCase{
  137. {
  138. name: "RemoveMiddle",
  139. in: []float32{7, 8},
  140. inShape: []int{1, 1, 2},
  141. seqs: []int{0, 0},
  142. pos: []int32{1, 2},
  143. expected: []float32{7, 8, 3, 4, 4},
  144. expectedShape: []int{1, 1, 5},
  145. expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0},
  146. },
  147. }
  148. testCache(t, backend, cache, tests)
  149. }
  150. func TestDefrag(t *testing.T) {
  151. backend := &testBackend{}
  152. cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
  153. return key.Add(ctx, shift), nil
  154. })
  155. defer cache.Close()
  156. cache.Init(backend, ml.DTypeF16, 16)
  157. tests := []testCase{
  158. {
  159. name: "FirstBatch",
  160. in: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
  161. inShape: []int{1, 1, 16},
  162. seqs: []int{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
  163. pos: []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
  164. expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
  165. expectedShape: []int{1, 1, 16},
  166. expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 
float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
  167. },
  168. }
  169. testCache(t, backend, cache, tests)
  170. err := cache.Remove(0, 2, 4)
  171. if err != nil {
  172. panic(err)
  173. }
  174. err = cache.Remove(0, 13, math.MaxInt32)
  175. if err != nil {
  176. panic(err)
  177. }
  178. tests = []testCase{
  179. {
  180. name: "Defrag",
  181. in: []float32{17, 18, 19},
  182. inShape: []int{1, 1, 3},
  183. seqs: []int{0, 0, 0},
  184. pos: []int32{16, 17, 18},
  185. expected: []float32{1, 2, 12, 13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 17, 18, 19},
  186. expectedShape: []int{1, 1, 16},
  187. expectedMask: []float32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
  188. },
  189. }
  190. testCache(t, backend, cache, tests)
  191. }
  192. func TestCopy(t *testing.T) {
  193. backend := &testBackend{}
  194. cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
  195. defer cache.Close()
  196. cache.Init(backend, ml.DTypeF16, 16)
  197. tests := []testCase{
  198. {
  199. name: "FirstBatch",
  200. in: []float32{1, 2, 3, 4},
  201. inShape: []int{1, 1, 4},
  202. seqs: []int{0, 0, 0, 0},
  203. pos: []int32{0, 1, 2, 3},
  204. expected: []float32{1, 2, 3, 4},
  205. expectedShape: []int{1, 1, 4},
  206. expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
  207. },
  208. }
  209. testCache(t, backend, cache, tests)
  210. cache.CopyPrefix(0, 1, 2)
  211. tests = []testCase{
  212. {
  213. name: "Copy",
  214. in: []float32{5, 6},
  215. inShape: []int{1, 1, 2},
  216. seqs: []int{1, 1},
  217. pos: []int32{3, 4},
  218. expected: []float32{1, 2, 3, 4, 5, 6},
  219. expectedShape: []int{1, 1, 6},
  220. expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
  221. },
  222. }
  223. testCache(t, backend, cache, tests)
  224. }
  225. func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
  226. for _, test := range tests {
  227. t.Run(test.name, func(t *testing.T) {
  228. context := backend.NewContext()
  229. defer context.Close()
  230. err := cache.StartForward(context, test.pos, test.seqs)
  231. if err != nil {
  232. panic(err)
  233. }
  234. cache.SetLayer(0)
  235. tensor, _ := context.FromFloatSlice(test.in, test.inShape...)
  236. cache.Put(context, tensor, tensor)
  237. out, _, mask := cache.Get(context)
  238. context.Forward(out)
  239. context.Forward(mask)
  240. context.Compute(out, mask)
  241. if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
  242. t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)
  243. }
  244. })
  245. }
  246. }
  247. type testBackend struct{}
  248. func (b *testBackend) Config() ml.Config {
  249. panic("not implemented")
  250. }
  251. func (b *testBackend) Get(name string) ml.Tensor {
  252. panic("not implemented")
  253. }
  254. func (b *testBackend) NewContext() ml.Context {
  255. return &testContext{}
  256. }
  257. func (b *testBackend) SystemInfo() string {
  258. return "not implemented"
  259. }
  260. type testContext struct{}
  261. func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
  262. total := 0
  263. if len(shape) > 0 {
  264. total = 1
  265. for _, s := range shape {
  266. total *= s
  267. }
  268. }
  269. return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
  270. }
  271. func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
  272. t := c.Zeros(ml.DTypeF32, shape...).(*testTensor)
  273. copy(t.data, s)
  274. return t, nil
  275. }
  276. func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
  277. f := make([]float32, len(s))
  278. for i := range f {
  279. f[i] = float32(s[i])
  280. }
  281. out, _ := c.FromFloatSlice(f, shape...)
  282. out.(*testTensor).dtype = ml.DTypeI32
  283. return out, nil
  284. }
  285. func (c *testContext) Forward(ml.Tensor) {}
  286. func (c *testContext) Compute(...ml.Tensor) {}
  287. func (c *testContext) MaxTensors() int {
  288. return 10
  289. }
  290. func (c *testContext) Close() {}
  291. type testTensor struct {
  292. dtype ml.DType
  293. elementSize int
  294. data []float32
  295. shape []int
  296. }
  297. func (t *testTensor) Dim(n int) int {
  298. return t.shape[n]
  299. }
  300. func (t *testTensor) Stride(n int) int {
  301. stride := t.elementSize
  302. for i := range n {
  303. stride *= t.shape[i]
  304. }
  305. return stride
  306. }
  307. func (t *testTensor) Shape() []int {
  308. return t.shape
  309. }
  310. func (t *testTensor) DType() ml.DType {
  311. return t.dtype
  312. }
  313. func (t *testTensor) Bytes() []byte {
  314. panic("not implemented")
  315. }
  316. func (t *testTensor) Floats() []float32 {
  317. out := make([]float32, len(t.data))
  318. copy(out, t.data)
  319. return out
  320. }
  321. func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  322. out := ctx.Zeros(t.DType(), t.Shape()...).(*testTensor)
  323. for i := range out.data {
  324. out.data[i] = t.data[i] + t2.(*testTensor).data[i]
  325. }
  326. return out
  327. }
  328. func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  329. panic("not implemented")
  330. }
  331. func (t *testTensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  332. panic("not implemented")
  333. }
  334. func (t *testTensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  335. panic("not implemented")
  336. }
  337. func (t *testTensor) Softmax(ctx ml.Context) ml.Tensor {
  338. panic("not implemented")
  339. }
  340. func (t *testTensor) LayerNorm(ctx ml.Context, weight, bias ml.Tensor, eps float32) ml.Tensor {
  341. panic("not implemented")
  342. }
  343. func (t *testTensor) RMSNorm(ctx ml.Context, weight ml.Tensor, eps float32) ml.Tensor {
  344. panic("not implemented")
  345. }
  346. func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
  347. panic("not implemented")
  348. }
  349. func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
  350. panic("not implemented")
  351. }
  352. func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim uint32, base, scale float32) ml.Tensor {
  353. panic("not implemented")
  354. }
  355. func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor {
  356. panic("not implemented")
  357. }
  358. func (t *testTensor) GELU(ctx ml.Context) ml.Tensor {
  359. panic("not implemented")
  360. }
  361. func (t *testTensor) SILU(ctx ml.Context) ml.Tensor {
  362. panic("not implemented")
  363. }
  364. func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
  365. panic("not implemented")
  366. }
  367. func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
  368. offset /= t.elementSize
  369. var s []int
  370. switch len(shape) {
  371. case 1:
  372. s = []int{shape[0]}
  373. case 5:
  374. s = []int{shape[0], shape[2], shape[4]}
  375. default:
  376. panic("unsupported number of dimensions")
  377. }
  378. context := &testContext{}
  379. view := context.Zeros(t.dtype, s...).(*testTensor)
  380. view.data = t.data[offset : offset+len(view.data)]
  381. return view
  382. }
  383. func (t *testTensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
  384. panic("not implemented")
  385. }
  386. func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor {
  387. panic("not implemented")
  388. }
  389. func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
  390. panic("not implemented")
  391. }
  392. func (t *testTensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
  393. panic("not implemented")
  394. }
  395. func (t *testTensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
  396. panic("not implemented")
  397. }
  398. func (t *testTensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
  399. panic("not implemented")
  400. }
  401. func (t *testTensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  402. panic("not implemented")
  403. }
  404. func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  405. copy(t2.(*testTensor).data, t.data)
  406. return nil
  407. }