causal_test.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521
  1. package kvcache
import (
	"fmt"
	"math"
	"slices"
	"testing"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)
  9. type testCase struct {
  10. name string
  11. in []float32
  12. inShape []int
  13. seqs []int
  14. pos []int32
  15. expected []float32
  16. expectedShape []int
  17. expectedMask []float32
  18. }
  19. func TestStore(t *testing.T) {
  20. backend := &testBackend{}
  21. cache := NewCausalCache(nil)
  22. defer cache.Close()
  23. cache.Init(backend, ml.DTypeF16, 16)
  24. tests := []testCase{
  25. {
  26. name: "FirstBatch",
  27. in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
  28. inShape: []int{2, 3, 4},
  29. seqs: []int{0, 0, 0, 0},
  30. pos: []int32{0, 1, 2, 3},
  31. expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
  32. expectedShape: []int{2, 3, 4},
  33. expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
  34. },
  35. {
  36. name: "SecondBatch",
  37. in: []float32{115, 215, 125, 225, 135, 235},
  38. inShape: []int{2, 3, 1},
  39. seqs: []int{0},
  40. pos: []int32{4},
  41. expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
  42. expectedShape: []int{2, 3, 5},
  43. expectedMask: []float32{0, 0, 0, 0, 0},
  44. },
  45. }
  46. testCache(t, backend, cache, tests)
  47. }
  48. func TestSWA(t *testing.T) {
  49. backend := &testBackend{}
  50. cache := NewSWACache(1, nil)
  51. defer cache.Close()
  52. cache.Init(backend, ml.DTypeF32, 16)
  53. tests := []testCase{
  54. {
  55. name: "SlidingWindow",
  56. in: []float32{1, 2, 3, 4},
  57. inShape: []int{1, 1, 4},
  58. seqs: []int{0, 0, 0, 0},
  59. pos: []int32{0, 1, 2, 3},
  60. expected: []float32{1, 2, 3, 4},
  61. expectedShape: []int{1, 1, 4},
  62. expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
  63. },
  64. }
  65. testCache(t, backend, cache, tests)
  66. }
  67. func TestSequences(t *testing.T) {
  68. backend := &testBackend{}
  69. cache := NewCausalCache(nil)
  70. defer cache.Close()
  71. cache.Init(backend, ml.DTypeF16, 16)
  72. tests := []testCase{
  73. {
  74. name: "FirstBatch",
  75. in: []float32{1, 2, 3, 4},
  76. inShape: []int{1, 1, 4},
  77. seqs: []int{0, 0, 1, 1},
  78. pos: []int32{0, 1, 0, 1},
  79. expected: []float32{1, 2, 3, 4},
  80. expectedShape: []int{1, 1, 4},
  81. expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
  82. },
  83. {
  84. name: "SecondBatch",
  85. in: []float32{5, 6},
  86. inShape: []int{1, 1, 2},
  87. seqs: []int{0, 1},
  88. pos: []int32{2, 2},
  89. expected: []float32{1, 2, 3, 4, 5, 6},
  90. expectedShape: []int{1, 1, 6},
  91. expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
  92. },
  93. }
  94. testCache(t, backend, cache, tests)
  95. }
  96. func TestRemove(t *testing.T) {
  97. backend := &testBackend{}
  98. cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
  99. return key.Add(ctx, shift), nil
  100. })
  101. defer cache.Close()
  102. cache.Init(backend, ml.DTypeF16, 16)
  103. tests := []testCase{
  104. {
  105. name: "FirstBatch",
  106. in: []float32{1, 2, 3, 4},
  107. inShape: []int{1, 1, 4},
  108. seqs: []int{0, 0, 1, 1},
  109. pos: []int32{0, 1, 0, 1},
  110. expected: []float32{1, 2, 3, 4},
  111. expectedShape: []int{1, 1, 4},
  112. expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
  113. },
  114. }
  115. testCache(t, backend, cache, tests)
  116. err := cache.Remove(0, 1, math.MaxInt32)
  117. if err != nil {
  118. panic(err)
  119. }
  120. tests = []testCase{
  121. {
  122. name: "RemoveEnd",
  123. in: []float32{5, 6},
  124. inShape: []int{1, 1, 2},
  125. seqs: []int{0, 1},
  126. pos: []int32{1, 2},
  127. expected: []float32{1, 2, 3, 4, 5, 6},
  128. expectedShape: []int{1, 1, 6},
  129. expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
  130. },
  131. }
  132. testCache(t, backend, cache, tests)
  133. err = cache.Remove(0, 0, 1)
  134. if err != nil {
  135. panic(err)
  136. }
  137. tests = []testCase{
  138. {
  139. name: "RemoveMiddle",
  140. in: []float32{7, 8},
  141. inShape: []int{1, 1, 2},
  142. seqs: []int{0, 0},
  143. pos: []int32{1, 2},
  144. expected: []float32{7, 8, 3, 4, 4},
  145. expectedShape: []int{1, 1, 5},
  146. expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0},
  147. },
  148. }
  149. testCache(t, backend, cache, tests)
  150. }
  151. func TestDefrag(t *testing.T) {
  152. backend := &testBackend{}
  153. cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
  154. return key.Add(ctx, shift), nil
  155. })
  156. defer cache.Close()
  157. cache.Init(backend, ml.DTypeF16, 16)
  158. tests := []testCase{
  159. {
  160. name: "FirstBatch",
  161. in: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
  162. inShape: []int{1, 1, 16},
  163. seqs: []int{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
  164. pos: []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
  165. expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
  166. expectedShape: []int{1, 1, 16},
  167. expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 
float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
  168. },
  169. }
  170. testCache(t, backend, cache, tests)
  171. err := cache.Remove(0, 2, 4)
  172. if err != nil {
  173. panic(err)
  174. }
  175. err = cache.Remove(0, 13, math.MaxInt32)
  176. if err != nil {
  177. panic(err)
  178. }
  179. tests = []testCase{
  180. {
  181. name: "Defrag",
  182. in: []float32{17, 18, 19},
  183. inShape: []int{1, 1, 3},
  184. seqs: []int{0, 0, 0},
  185. pos: []int32{16, 17, 18},
  186. expected: []float32{1, 2, 12, 13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 17, 18, 19},
  187. expectedShape: []int{1, 1, 16},
  188. expectedMask: []float32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
  189. },
  190. }
  191. testCache(t, backend, cache, tests)
  192. }
  193. func TestCopy(t *testing.T) {
  194. backend := &testBackend{}
  195. cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
  196. defer cache.Close()
  197. cache.Init(backend, ml.DTypeF16, 16)
  198. tests := []testCase{
  199. {
  200. name: "FirstBatch",
  201. in: []float32{1, 2, 3, 4},
  202. inShape: []int{1, 1, 4},
  203. seqs: []int{0, 0, 0, 0},
  204. pos: []int32{0, 1, 2, 3},
  205. expected: []float32{1, 2, 3, 4},
  206. expectedShape: []int{1, 1, 4},
  207. expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
  208. },
  209. }
  210. testCache(t, backend, cache, tests)
  211. cache.CopyPrefix(0, 1, 2)
  212. tests = []testCase{
  213. {
  214. name: "Copy",
  215. in: []float32{5, 6},
  216. inShape: []int{1, 1, 2},
  217. seqs: []int{1, 1},
  218. pos: []int32{3, 4},
  219. expected: []float32{1, 2, 3, 4, 5, 6},
  220. expectedShape: []int{1, 1, 6},
  221. expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
  222. },
  223. }
  224. testCache(t, backend, cache, tests)
  225. }
  226. func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
  227. for _, test := range tests {
  228. t.Run(test.name, func(t *testing.T) {
  229. context := backend.NewContext()
  230. defer context.Close()
  231. err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs})
  232. if err != nil {
  233. panic(err)
  234. }
  235. cache.SetLayer(0)
  236. tensor, _ := context.FromFloatSlice(test.in, test.inShape...)
  237. cache.Put(context, tensor, tensor)
  238. out, _, mask := cache.Get(context)
  239. context.Forward(out, mask).Compute(out, mask)
  240. if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
  241. t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)
  242. }
  243. })
  244. }
  245. }
  246. type testBackend struct{}
  247. func (b *testBackend) Config() ml.Config {
  248. panic("not implemented")
  249. }
  250. func (b *testBackend) Get(name string) ml.Tensor {
  251. panic("not implemented")
  252. }
  253. func (b *testBackend) NewContext() ml.Context {
  254. return &testContext{}
  255. }
  256. func (b *testBackend) NewContextSize(int) ml.Context {
  257. return &testContext{}
  258. }
  259. func (b *testBackend) SystemInfo() string {
  260. return "not implemented"
  261. }
  262. type testContext struct{}
  263. func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
  264. total := 0
  265. if len(shape) > 0 {
  266. total = 1
  267. for _, s := range shape {
  268. total *= s
  269. }
  270. }
  271. return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
  272. }
  273. func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
  274. return c.Empty(dtype, shape...)
  275. }
  276. func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
  277. t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
  278. copy(t.data, s)
  279. return t, nil
  280. }
  281. func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
  282. f := make([]float32, len(s))
  283. for i := range f {
  284. f[i] = float32(s[i])
  285. }
  286. out, _ := c.FromFloatSlice(f, shape...)
  287. out.(*testTensor).dtype = ml.DTypeI32
  288. return out, nil
  289. }
  290. func (c *testContext) Input() ml.Context { return c }
  291. func (c *testContext) Output() ml.Context { return c }
  292. func (c *testContext) Layer(int) ml.Context { return c }
  293. func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
  294. func (c *testContext) Compute(...ml.Tensor) {}
  295. func (c *testContext) MaxGraphNodes() int {
  296. return 10
  297. }
  298. func (c *testContext) Close() {}
  299. type testTensor struct {
  300. dtype ml.DType
  301. elementSize int
  302. data []float32
  303. shape []int
  304. }
  305. func (t *testTensor) Dim(n int) int {
  306. return t.shape[n]
  307. }
  308. func (t *testTensor) Stride(n int) int {
  309. stride := t.elementSize
  310. for i := range n {
  311. stride *= t.shape[i]
  312. }
  313. return stride
  314. }
  315. func (t *testTensor) Shape() []int {
  316. return t.shape
  317. }
  318. func (t *testTensor) DType() ml.DType {
  319. return t.dtype
  320. }
  321. func (t *testTensor) Bytes() []byte {
  322. panic("not implemented")
  323. }
  324. func (t *testTensor) Floats() []float32 {
  325. out := make([]float32, len(t.data))
  326. copy(out, t.data)
  327. return out
  328. }
  329. func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  330. out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
  331. for i := range out.data {
  332. out.data[i] = t.data[i] + t2.(*testTensor).data[i]
  333. }
  334. return out
  335. }
  336. func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  337. panic("not implemented")
  338. }
  339. func (t *testTensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  340. panic("not implemented")
  341. }
  342. func (t *testTensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  343. panic("not implemented")
  344. }
  345. func (t *testTensor) Softmax(ctx ml.Context) ml.Tensor {
  346. panic("not implemented")
  347. }
  348. func (t *testTensor) LayerNorm(ctx ml.Context, weight, bias ml.Tensor, eps float32) ml.Tensor {
  349. panic("not implemented")
  350. }
  351. func (t *testTensor) RMSNorm(ctx ml.Context, weight ml.Tensor, eps float32) ml.Tensor {
  352. panic("not implemented")
  353. }
  354. func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
  355. panic("not implemented")
  356. }
  357. func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
  358. panic("not implemented")
  359. }
  360. func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
  361. panic("not implemented")
  362. }
  363. func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor {
  364. panic("not implemented")
  365. }
  366. func (t *testTensor) GELU(ctx ml.Context) ml.Tensor {
  367. panic("not implemented")
  368. }
  369. func (t *testTensor) SILU(ctx ml.Context) ml.Tensor {
  370. panic("not implemented")
  371. }
  372. func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
  373. panic("not implemented")
  374. }
  375. func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
  376. offset /= t.elementSize
  377. var s []int
  378. switch len(shape) {
  379. case 1:
  380. s = []int{shape[0]}
  381. case 5:
  382. s = []int{shape[0], shape[2], shape[4]}
  383. default:
  384. panic("unsupported number of dimensions")
  385. }
  386. context := &testContext{}
  387. view := context.Empty(t.dtype, s...).(*testTensor)
  388. view.data = t.data[offset : offset+len(view.data)]
  389. return view
  390. }
  391. func (t *testTensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
  392. panic("not implemented")
  393. }
  394. func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor {
  395. panic("not implemented")
  396. }
  397. func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
  398. panic("not implemented")
  399. }
  400. func (t *testTensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
  401. panic("not implemented")
  402. }
  403. func (t *testTensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
  404. panic("not implemented")
  405. }
  406. func (t *testTensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
  407. panic("not implemented")
  408. }
  409. func (t *testTensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  410. panic("not implemented")
  411. }
  412. func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  413. copy(t2.(*testTensor).data, t.data)
  414. return nil
  415. }