// causal_test.go
  1. package kvcache
  2. import (
  3. "math"
  4. "slices"
  5. "testing"
  6. "github.com/ollama/ollama/ml"
  7. )
// testCase describes one cache forward pass: the input tensor data and
// shape, the per-entry sequence/position metadata, and the expected cache
// contents, shape, and attention mask after the pass.
type testCase struct {
	name          string    // subtest name
	in            []float32 // flattened input tensor data
	inShape       []int     // shape of in; last dim matches len(seqs) — batch appears to be the final axis
	seqs          []int     // sequence ID for each batch entry
	pos           []int32   // position of each batch entry within its sequence
	expected      []float32 // expected flattened cache contents after Put
	expectedShape []int     // expected shape of the tensor returned by Get
	expectedMask  []float32 // expected attention mask (0 = visible, -Inf = masked)
}
  18. func TestStore(t *testing.T) {
  19. backend := &testBackend{}
  20. cache := NewCausalCache(nil)
  21. defer cache.Close()
  22. cache.Init(backend, ml.DTypeF16, 16)
  23. tests := []testCase{
  24. {
  25. name: "FirstBatch",
  26. in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
  27. inShape: []int{2, 3, 4},
  28. seqs: []int{0, 0, 0, 0},
  29. pos: []int32{0, 1, 2, 3},
  30. expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
  31. expectedShape: []int{2, 3, 4},
  32. expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
  33. },
  34. {
  35. name: "SecondBatch",
  36. in: []float32{115, 215, 125, 225, 135, 235},
  37. inShape: []int{2, 3, 1},
  38. seqs: []int{0},
  39. pos: []int32{4},
  40. expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
  41. expectedShape: []int{2, 3, 5},
  42. expectedMask: []float32{0, 0, 0, 0, 0},
  43. },
  44. }
  45. testCache(t, backend, cache, tests)
  46. }
  47. func TestSWA(t *testing.T) {
  48. backend := &testBackend{}
  49. cache := NewSWACache(1, nil)
  50. defer cache.Close()
  51. cache.Init(backend, ml.DTypeF32, 16)
  52. tests := []testCase{
  53. {
  54. name: "SlidingWindow",
  55. in: []float32{1, 2, 3, 4},
  56. inShape: []int{1, 1, 4},
  57. seqs: []int{0, 0, 0, 0},
  58. pos: []int32{0, 1, 2, 3},
  59. expected: []float32{1, 2, 3, 4},
  60. expectedShape: []int{1, 1, 4},
  61. expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
  62. },
  63. }
  64. testCache(t, backend, cache, tests)
  65. }
  66. func TestSequences(t *testing.T) {
  67. backend := &testBackend{}
  68. cache := NewCausalCache(nil)
  69. defer cache.Close()
  70. cache.Init(backend, ml.DTypeF16, 16)
  71. tests := []testCase{
  72. {
  73. name: "FirstBatch",
  74. in: []float32{1, 2, 3, 4},
  75. inShape: []int{1, 1, 4},
  76. seqs: []int{0, 0, 1, 1},
  77. pos: []int32{0, 1, 0, 1},
  78. expected: []float32{1, 2, 3, 4},
  79. expectedShape: []int{1, 1, 4},
  80. expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
  81. },
  82. {
  83. name: "SecondBatch",
  84. in: []float32{5, 6},
  85. inShape: []int{1, 1, 2},
  86. seqs: []int{0, 1},
  87. pos: []int32{2, 2},
  88. expected: []float32{1, 2, 3, 4, 5, 6},
  89. expectedShape: []int{1, 1, 6},
  90. expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
  91. },
  92. }
  93. testCache(t, backend, cache, tests)
  94. }
  95. func TestRemove(t *testing.T) {
  96. backend := &testBackend{}
  97. cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
  98. return key.Add(ctx, shift), nil
  99. })
  100. defer cache.Close()
  101. cache.Init(backend, ml.DTypeF16, 16)
  102. tests := []testCase{
  103. {
  104. name: "FirstBatch",
  105. in: []float32{1, 2, 3, 4},
  106. inShape: []int{1, 1, 4},
  107. seqs: []int{0, 0, 1, 1},
  108. pos: []int32{0, 1, 0, 1},
  109. expected: []float32{1, 2, 3, 4},
  110. expectedShape: []int{1, 1, 4},
  111. expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
  112. },
  113. }
  114. testCache(t, backend, cache, tests)
  115. err := cache.Remove(0, 1, math.MaxInt32)
  116. if err != nil {
  117. panic(err)
  118. }
  119. tests = []testCase{
  120. {
  121. name: "RemoveEnd",
  122. in: []float32{5, 6},
  123. inShape: []int{1, 1, 2},
  124. seqs: []int{0, 1},
  125. pos: []int32{1, 2},
  126. expected: []float32{1, 2, 3, 4, 5, 6},
  127. expectedShape: []int{1, 1, 6},
  128. expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
  129. },
  130. }
  131. testCache(t, backend, cache, tests)
  132. err = cache.Remove(0, 0, 1)
  133. if err != nil {
  134. panic(err)
  135. }
  136. tests = []testCase{
  137. {
  138. name: "RemoveMiddle",
  139. in: []float32{7, 8},
  140. inShape: []int{1, 1, 2},
  141. seqs: []int{0, 0},
  142. pos: []int32{1, 2},
  143. expected: []float32{7, 8, 3, 4, 4},
  144. expectedShape: []int{1, 1, 5},
  145. expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0},
  146. },
  147. }
  148. testCache(t, backend, cache, tests)
  149. }
  150. func TestDefrag(t *testing.T) {
  151. backend := &testBackend{}
  152. cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
  153. return key.Add(ctx, shift), nil
  154. })
  155. defer cache.Close()
  156. cache.Init(backend, ml.DTypeF16, 16)
  157. tests := []testCase{
  158. {
  159. name: "FirstBatch",
  160. in: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
  161. inShape: []int{1, 1, 16},
  162. seqs: []int{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
  163. pos: []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
  164. expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
  165. expectedShape: []int{1, 1, 16},
  166. expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 
float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
  167. },
  168. }
  169. testCache(t, backend, cache, tests)
  170. err := cache.Remove(0, 2, 4)
  171. if err != nil {
  172. panic(err)
  173. }
  174. err = cache.Remove(0, 13, math.MaxInt32)
  175. if err != nil {
  176. panic(err)
  177. }
  178. tests = []testCase{
  179. {
  180. name: "Defrag",
  181. in: []float32{17, 18, 19},
  182. inShape: []int{1, 1, 3},
  183. seqs: []int{0, 0, 0},
  184. pos: []int32{16, 17, 18},
  185. expected: []float32{1, 2, 12, 13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 17, 18, 19},
  186. expectedShape: []int{1, 1, 16},
  187. expectedMask: []float32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
  188. },
  189. }
  190. testCache(t, backend, cache, tests)
  191. }
  192. func TestCopy(t *testing.T) {
  193. backend := &testBackend{}
  194. cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
  195. defer cache.Close()
  196. cache.Init(backend, ml.DTypeF16, 16)
  197. tests := []testCase{
  198. {
  199. name: "FirstBatch",
  200. in: []float32{1, 2, 3, 4},
  201. inShape: []int{1, 1, 4},
  202. seqs: []int{0, 0, 0, 0},
  203. pos: []int32{0, 1, 2, 3},
  204. expected: []float32{1, 2, 3, 4},
  205. expectedShape: []int{1, 1, 4},
  206. expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
  207. },
  208. }
  209. testCache(t, backend, cache, tests)
  210. cache.CopyPrefix(0, 1, 2)
  211. tests = []testCase{
  212. {
  213. name: "Copy",
  214. in: []float32{5, 6},
  215. inShape: []int{1, 1, 2},
  216. seqs: []int{1, 1},
  217. pos: []int32{3, 4},
  218. expected: []float32{1, 2, 3, 4, 5, 6},
  219. expectedShape: []int{1, 1, 6},
  220. expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
  221. },
  222. }
  223. testCache(t, backend, cache, tests)
  224. }
  225. func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
  226. for _, test := range tests {
  227. t.Run(test.name, func(t *testing.T) {
  228. context := backend.NewContext()
  229. defer context.Close()
  230. err := cache.StartForward(context, test.pos, test.seqs)
  231. if err != nil {
  232. panic(err)
  233. }
  234. cache.SetLayer(0)
  235. tensor, _ := context.FromFloatSlice(test.in, test.inShape...)
  236. cache.Put(context, tensor, tensor)
  237. out, _, mask := cache.Get(context)
  238. context.Forward(out, mask).Compute(out, mask)
  239. if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
  240. t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)
  241. }
  242. })
  243. }
  244. }
// testBackend is a minimal ml.Backend stub for cache tests. Only
// NewContext is functional; the other methods are never called here.
type testBackend struct{}

func (b *testBackend) Config() ml.Config {
	panic("not implemented")
}

func (b *testBackend) Get(name string) ml.Tensor {
	panic("not implemented")
}

// NewContext returns a fresh in-memory test context.
func (b *testBackend) NewContext() ml.Context {
	return &testContext{}
}

func (b *testBackend) SystemInfo() string {
	return "not implemented"
}
  258. type testContext struct{}
  259. func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
  260. total := 0
  261. if len(shape) > 0 {
  262. total = 1
  263. for _, s := range shape {
  264. total *= s
  265. }
  266. }
  267. return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
  268. }
  269. func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
  270. return c.Empty(dtype, shape...)
  271. }
  272. func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
  273. t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
  274. copy(t.data, s)
  275. return t, nil
  276. }
  277. func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
  278. f := make([]float32, len(s))
  279. for i := range f {
  280. f[i] = float32(s[i])
  281. }
  282. out, _ := c.FromFloatSlice(f, shape...)
  283. out.(*testTensor).dtype = ml.DTypeI32
  284. return out, nil
  285. }
  286. func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
  287. func (c *testContext) Compute(...ml.Tensor) {}
  288. func (c *testContext) MaxTensors() int {
  289. return 10
  290. }
  291. func (c *testContext) Close() {}
// testTensor is an in-memory ml.Tensor backed by a flat float32 slice.
type testTensor struct {
	dtype       ml.DType // reported element type (data is always float32)
	elementSize int      // bytes per element, used by Stride and View
	data        []float32
	shape       []int
}
  298. func (t *testTensor) Dim(n int) int {
  299. return t.shape[n]
  300. }
  301. func (t *testTensor) Stride(n int) int {
  302. stride := t.elementSize
  303. for i := range n {
  304. stride *= t.shape[i]
  305. }
  306. return stride
  307. }
  308. func (t *testTensor) Shape() []int {
  309. return t.shape
  310. }
  311. func (t *testTensor) DType() ml.DType {
  312. return t.dtype
  313. }
  314. func (t *testTensor) Bytes() []byte {
  315. panic("not implemented")
  316. }
  317. func (t *testTensor) Floats() []float32 {
  318. out := make([]float32, len(t.data))
  319. copy(out, t.data)
  320. return out
  321. }
  322. func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  323. out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
  324. for i := range out.data {
  325. out.data[i] = t.data[i] + t2.(*testTensor).data[i]
  326. }
  327. return out
  328. }
// The remaining ml.Tensor operations are not exercised by the cache tests
// and are stubbed out.

func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Softmax(ctx ml.Context) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) LayerNorm(ctx ml.Context, weight, bias ml.Tensor, eps float32) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) RMSNorm(ctx ml.Context, weight ml.Tensor, eps float32) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim uint32, base, scale float32) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) GELU(ctx ml.Context) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) SILU(ctx ml.Context) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
	panic("not implemented")
}
  368. func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
  369. offset /= t.elementSize
  370. var s []int
  371. switch len(shape) {
  372. case 1:
  373. s = []int{shape[0]}
  374. case 5:
  375. s = []int{shape[0], shape[2], shape[4]}
  376. default:
  377. panic("unsupported number of dimensions")
  378. }
  379. context := &testContext{}
  380. view := context.Empty(t.dtype, s...).(*testTensor)
  381. view.data = t.data[offset : offset+len(view.data)]
  382. return view
  383. }
// Further unused ml.Tensor operations, stubbed out.

func (t *testTensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	panic("not implemented")
}
  405. func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  406. copy(t2.(*testTensor).data, t.data)
  407. return nil
  408. }