package kvcache

import (
	"math"
	"slices"
	"testing"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)
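
// testCase describes a single forward pass through the cache: the key/value
// data to store and its shape, the sequence ID and position of each token in
// the batch, and the expected cache contents, shape, and attention mask once
// the pass completes.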
type testCase struct {
	name          string
	in            []float32
	inShape       []int
	seqs          []int
	pos           []int32
	expected      []float32
	expectedShape []int
	expectedMask  []float32
}
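
// TestStore stores two consecutive batches for a single sequence in a causal
// cache and checks that keys accumulate in position order and that the mask
// lets each token attend only to itself and to earlier positions.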
func TestStore(t *testing.T) {
	backend := &testBackend{}
	cache := NewCausalCache(nil)
	defer cache.Close()

	cache.Init(backend, ml.DTypeF16, 1, 16, 16)

	tests := []testCase{
		{
			name:          "FirstBatch",
			in:            []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
			inShape:       []int{2, 3, 4},
			seqs:          []int{0, 0, 0, 0},
			pos:           []int32{0, 1, 2, 3},
			expected:      []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
			expectedShape: []int{2, 3, 4},
			expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
		},
		{
			name:          "SecondBatch",
			in:            []float32{115, 215, 125, 225, 135, 235},
			inShape:       []int{2, 3, 1},
			seqs:          []int{0},
			pos:           []int32{4},
			expected:      []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
			expectedShape: []int{2, 3, 5},
			expectedMask:  []float32{0, 0, 0, 0, 0},
		},
	}

	testCache(t, backend, cache, tests)
}
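
// TestSWA exercises the sliding-window cache with a window size of 1: the mask
// exposes only the current and immediately preceding position, and slots that
// fall outside the window are reused by the second batch.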
func TestSWA(t *testing.T) {
	backend := &testBackend{}
	cache := NewSWACache(1, nil)
	defer cache.Close()

	cache.Init(backend, ml.DTypeF16, 1, 16, 16)

	tests := []testCase{
		{
			name:          "FirstBatch",
			in:            []float32{1, 2, 3, 4},
			inShape:       []int{1, 1, 4},
			seqs:          []int{0, 0, 0, 0},
			pos:           []int32{0, 1, 2, 3},
			expected:      []float32{1, 2, 3, 4},
			expectedShape: []int{1, 1, 4},
			expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
		},
		{
			name:          "SecondBatch",
			in:            []float32{5, 6},
			inShape:       []int{1, 1, 2},
			seqs:          []int{0, 0},
			pos:           []int32{4, 5},
			expected:      []float32{5, 6, 3, 4},
			expectedShape: []int{1, 1, 4},
			expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
		},
	}

	testCache(t, backend, cache, tests)
}
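
// TestSequences interleaves two sequences in one causal cache and checks that
// the mask hides each sequence's entries from the other, even though both
// sequences share the same cache buffer.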
func TestSequences(t *testing.T) {
	backend := &testBackend{}
	cache := NewCausalCache(nil)
	defer cache.Close()

	cache.Init(backend, ml.DTypeF16, 1, 16, 16)

	tests := []testCase{
		{
			name:          "FirstBatch",
			in:            []float32{1, 2, 3, 4},
			inShape:       []int{1, 1, 4},
			seqs:          []int{0, 0, 1, 1},
			pos:           []int32{0, 1, 0, 1},
			expected:      []float32{1, 2, 3, 4},
			expectedShape: []int{1, 1, 4},
			expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
		},
		{
			name:          "SecondBatch",
			in:            []float32{5, 6},
			inShape:       []int{1, 1, 2},
			seqs:          []int{0, 1},
			pos:           []int32{2, 2},
			expected:      []float32{1, 2, 3, 4, 5, 6},
			expectedShape: []int{1, 1, 6},
			expectedMask:  []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
		},
	}

	testCache(t, backend, cache, tests)
}
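
// TestRemove removes the tail of sequence 0, then its first position, and
// verifies the resulting cache contents and masks. The cache is constructed
// with a shift callback that adds the shift tensor to the keys, so entries
// whose positions move appear with adjusted values in the expected output.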
func TestRemove(t *testing.T) {
	backend := &testBackend{}
	cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
		return key.Add(ctx, shift), nil
	})
	defer cache.Close()

	cache.Init(backend, ml.DTypeF16, 1, 16, 16)

	tests := []testCase{
		{
			name:          "FirstBatch",
			in:            []float32{1, 2, 3, 4},
			inShape:       []int{1, 1, 4},
			seqs:          []int{0, 0, 1, 1},
			pos:           []int32{0, 1, 0, 1},
			expected:      []float32{1, 2, 3, 4},
			expectedShape: []int{1, 1, 4},
			expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
		},
	}

	testCache(t, backend, cache, tests)

	err := cache.Remove(0, 1, math.MaxInt32)
	if err != nil {
		panic(err)
	}

	tests = []testCase{
		{
			name:          "RemoveEnd",
			in:            []float32{5, 6},
			inShape:       []int{1, 1, 2},
			seqs:          []int{0, 1},
			pos:           []int32{1, 2},
			expected:      []float32{1, 2, 3, 4, 5, 6},
			expectedShape: []int{1, 1, 6},
			expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
		},
	}

	testCache(t, backend, cache, tests)

	err = cache.Remove(0, 0, 1)
	if err != nil {
		panic(err)
	}

	tests = []testCase{
		{
			name:          "RemoveMiddle",
			in:            []float32{7, 8},
			inShape:       []int{1, 1, 2},
			seqs:          []int{0, 0},
			pos:           []int32{1, 2},
			expected:      []float32{7, 8, 3, 4, 4},
			expectedShape: []int{1, 1, 5},
			expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0},
		},
	}

	testCache(t, backend, cache, tests)
}
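
// TestDefrag fills the cache to capacity, punches holes in it with Remove, and
// then inserts another batch. The expected output reflects defragmentation:
// entries from the end of the cache are moved into the freed slots before the
// new tokens are appended.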
func TestDefrag(t *testing.T) {
	backend := &testBackend{}
	cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
		return key.Add(ctx, shift), nil
	})
	defer cache.Close()

	cache.Init(backend, ml.DTypeF16, 1, 16, 16)

	tests := []testCase{
		{
			name:          "FirstBatch",
			in:            []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
			inShape:       []int{1, 1, 16},
			seqs:          []int{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
			pos:           []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
			expected:      []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
			expectedShape: []int{1, 1, 16},
			expectedMask: []float32{
				0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)),
				0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)),
				0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)),
				0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)),
				0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)),
				0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)),
				0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)),
				0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)),
				0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)),
				0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)),
				0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)),
				0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)),
				0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)),
				0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)),
				0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)),
				0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
			},
		},
	}

	testCache(t, backend, cache, tests)

	err := cache.Remove(0, 2, 4)
	if err != nil {
		panic(err)
	}

	err = cache.Remove(0, 13, math.MaxInt32)
	if err != nil {
		panic(err)
	}

	tests = []testCase{
		{
			name:          "Defrag",
			in:            []float32{17, 18, 19},
			inShape:       []int{1, 1, 3},
			seqs:          []int{0, 0, 0},
			pos:           []int32{16, 17, 18},
			expected:      []float32{1, 2, 12, 13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 17, 18, 19},
			expectedShape: []int{1, 1, 16},
			expectedMask: []float32{
				0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)),
				0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)),
				0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
			},
		},
	}

	testCache(t, backend, cache, tests)
}
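
// TestCopy checks CopyPrefix: after copying the first two positions of
// sequence 0 into sequence 1, tokens appended to sequence 1 can attend to the
// copied prefix but not to the rest of sequence 0.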
func TestCopy(t *testing.T) {
	backend := &testBackend{}
	cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
	defer cache.Close()

	cache.Init(backend, ml.DTypeF16, 1, 16, 16)

	tests := []testCase{
		{
			name:          "FirstBatch",
			in:            []float32{1, 2, 3, 4},
			inShape:       []int{1, 1, 4},
			seqs:          []int{0, 0, 0, 0},
			pos:           []int32{0, 1, 2, 3},
			expected:      []float32{1, 2, 3, 4},
			expectedShape: []int{1, 1, 4},
			expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
		},
	}

	testCache(t, backend, cache, tests)

	cache.CopyPrefix(0, 1, 2)

	tests = []testCase{
		{
			name:          "Copy",
			in:            []float32{5, 6},
			inShape:       []int{1, 1, 2},
			seqs:          []int{1, 1},
			pos:           []int32{3, 4},
			expected:      []float32{1, 2, 3, 4, 5, 6},
			expectedShape: []int{1, 1, 6},
			expectedMask:  []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
		},
	}

	testCache(t, backend, cache, tests)
}
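
// testCache drives the cache through each test case: it starts a forward pass
// with the batch's positions and sequences, stores the input tensor as both
// key and value for layer 0, then compares the retrieved keys, their shape,
// and the attention mask against the expected values.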
func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			context := backend.NewContext()
			defer context.Close()

			err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs})
			if err != nil {
				panic(err)
			}

			cache.SetLayer(0)
			tensor, _ := context.FromFloatSlice(test.in, test.inShape...)
			cache.Put(context, tensor, tensor)

			out, _, mask := cache.Get(context)

			context.Forward(out, mask).Compute(out, mask)

			if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
				t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)
			}
		})
	}
}
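
// testBackend is a minimal ml.Backend stub for these tests; the cache only
// needs it to create contexts, so the remaining methods are unimplemented.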
type testBackend struct{}

func (b *testBackend) Config() ml.Config {
	panic("not implemented")
}

func (b *testBackend) Get(name string) ml.Tensor {
	panic("not implemented")
}

func (b *testBackend) NewContext() ml.Context {
	return &testContext{}
}

func (b *testBackend) NewContextSize(int) ml.Context {
	return &testContext{}
}

func (b *testBackend) SystemInfo() string {
	return "not implemented"
}
type testContext struct{}

func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
	total := 0

	if len(shape) > 0 {
		total = 1
		for _, s := range shape {
			total *= s
		}
	}

	return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
}

func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
	return c.Empty(dtype, shape...)
}

func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
	t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
	copy(t.data, s)
	return t, nil
}

func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
	f := make([]float32, len(s))
	for i := range f {
		f[i] = float32(s[i])
	}

	out, _ := c.FromFloatSlice(f, shape...)
	out.(*testTensor).dtype = ml.DTypeI32

	return out, nil
}

func (c *testContext) Input() ml.Context    { return c }
func (c *testContext) Output() ml.Context   { return c }
func (c *testContext) Layer(int) ml.Context { return c }

func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }

func (c *testContext) Compute(...ml.Tensor) {}

func (c *testContext) MaxGraphNodes() int {
	return 10
}

func (c *testContext) Close() {}
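
// testTensor is a dense float32 tensor backed by a flat slice. Only the
// operations the cache tests exercise (Add, View, Copy, and basic accessors)
// are implemented; everything else panics.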
type testTensor struct {
	dtype       ml.DType
	elementSize int
	data        []float32
	shape       []int
}

func (t *testTensor) Dim(n int) int {
	return t.shape[n]
}

func (t *testTensor) Stride(n int) int {
	stride := t.elementSize
	for i := range n {
		stride *= t.shape[i]
	}

	return stride
}

func (t *testTensor) Shape() []int {
	return t.shape
}

func (t *testTensor) DType() ml.DType {
	return t.dtype
}

func (t *testTensor) Bytes() []byte {
	panic("not implemented")
}

func (t *testTensor) Floats() []float32 {
	out := make([]float32, len(t.data))
	copy(out, t.data)
	return out
}
func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)

	for i := range out.data {
		out.data[i] = t.data[i] + t2.(*testTensor).data[i]
	}

	return out
}

func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Softmax(ctx ml.Context) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) LayerNorm(ctx ml.Context, weight, bias ml.Tensor, eps float32) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) RMSNorm(ctx ml.Context, weight ml.Tensor, eps float32) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) GELU(ctx ml.Context) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) SILU(ctx ml.Context) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
	panic("not implemented")
}
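
// View accepts either a single dimension or five values of which indices 0, 2,
// and 4 are used as the view's dimensions (the in-between values, presumably
// strides, are ignored). The returned tensor aliases t's data starting at the
// given byte offset.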
func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
	offset /= t.elementSize

	var s []int

	switch len(shape) {
	case 1:
		s = []int{shape[0]}
	case 5:
		s = []int{shape[0], shape[2], shape[4]}
	default:
		panic("unsupported number of dimensions")
	}

	context := &testContext{}

	view := context.Empty(t.dtype, s...).(*testTensor)
	view.data = t.data[offset : offset+len(view.data)]

	return view
}

func (t *testTensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	copy(t2.(*testTensor).data, t.data)
	return nil
}