ggml.go 17 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715
  1. package ggml
  2. // #cgo CPPFLAGS: -I${SRCDIR}/ggml/include
  3. // #include <stdlib.h>
  4. // #include <stdint.h>
  5. // #include "ggml.h"
  6. // #include "ggml-cpu.h"
  7. // #include "ggml-backend.h"
  8. import "C"
import (
	"fmt"
	"io"
	"log/slog"
	"os"
	"runtime"
	"sync"
	"unsafe"

	"golang.org/x/sync/errgroup"

	"github.com/ollama/ollama/format"
	fs "github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/ml"
	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
)
  22. type device struct {
  23. d *C.struct_ggml_backend_device
  24. }
  25. func (d device) LogValue() slog.Value {
  26. var free, total uint64
  27. C.ggml_backend_dev_memory(d.d, (*C.size_t)(&free), (*C.size_t)(&total))
  28. kind := "unknown"
  29. switch C.ggml_backend_dev_type(d.d) {
  30. case C.GGML_BACKEND_DEVICE_TYPE_CPU:
  31. kind = "cpu"
  32. case C.GGML_BACKEND_DEVICE_TYPE_GPU:
  33. kind = "gpu"
  34. case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
  35. kind = "accel"
  36. }
  37. return slog.GroupValue(
  38. slog.String("name", C.GoString(C.ggml_backend_dev_name(d.d))),
  39. slog.String("description", C.GoString(C.ggml_backend_dev_description(d.d))),
  40. slog.String("kind", kind),
  41. slog.String("free", format.HumanBytes2(free)),
  42. slog.String("total", format.HumanBytes2(total)),
  43. )
  44. }
  45. var devices = sync.OnceValue(func() []device {
  46. ggml.OnceLoad()
  47. s := make([]device, C.ggml_backend_dev_count())
  48. for i := range s {
  49. s[i] = device{C.ggml_backend_dev_get(C.size_t(i))}
  50. }
  51. return s
  52. })
  53. type Backend struct {
  54. flashAttention bool
  55. meta *fs.GGML
  56. cpus, gpus []Context
  57. tensors map[string]*Context
  58. sched *C.struct_ggml_backend_sched
  59. }
  60. func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
  61. meta, n, err := fs.Decode(r, -1)
  62. if err != nil {
  63. return nil, err
  64. }
  65. slog.Info(
  66. "",
  67. "architecture", meta.KV().Architecture(),
  68. "file_type", meta.KV().FileType(),
  69. "name", meta.KV().String("general.name"),
  70. "description", meta.KV().String("general.description"),
  71. "num_tensors", len(meta.Tensors().Items()),
  72. "num_key_values", len(meta.KV()),
  73. )
  74. var cpus, gpus []Context
  75. for _, d := range devices() {
  76. switch C.ggml_backend_dev_type(d.d) {
  77. case C.GGML_BACKEND_DEVICE_TYPE_CPU,
  78. C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
  79. slog.Info("cpu", "device", d)
  80. cpus = append(cpus, Context{
  81. ctx: C.ggml_init(C.struct_ggml_init_params{
  82. mem_size: C.size_t(int(C.ggml_tensor_overhead()) * (len(meta.Tensors().Items()) + 1 + int(meta.KV().BlockCount())*2)),
  83. no_alloc: true,
  84. }),
  85. backend: C.ggml_backend_dev_init(d.d, nil),
  86. })
  87. case C.GGML_BACKEND_DEVICE_TYPE_GPU:
  88. slog.Info("gpu", "device", d)
  89. gpus = append(gpus, Context{
  90. ctx: C.ggml_init(C.struct_ggml_init_params{
  91. mem_size: C.size_t(int(C.ggml_tensor_overhead()) * (len(meta.Tensors().Items()) + 1 + int(meta.KV().BlockCount())*2)),
  92. no_alloc: true,
  93. }),
  94. backend: C.ggml_backend_dev_init(d.d, nil),
  95. })
  96. }
  97. }
  98. ctxFunc := func(s []Context) (*Context, error) {
  99. for _, e := range s {
  100. return &e, nil
  101. }
  102. return nil, fmt.Errorf("no devices available")
  103. }
  104. tensors := make(map[*fs.Tensor]*Context, len(meta.Tensors().Items()))
  105. for _, t := range meta.Tensors().Items() {
  106. c, err := ctxFunc(append(gpus, cpus...))
  107. if err != nil {
  108. return nil, err
  109. }
  110. func() {
  111. tt := C.ggml_new_tensor(c.ctx, t.Kind, C.int(len(t.Shape)), (*C.int64_t)(unsafe.Pointer(&t.Shape[0])))
  112. cname := C.CString(t.Name)
  113. defer C.free(unsafe.Pointer(cname))
  114. C.ggml_set_name(tt, cname)
  115. tensors[t] = c
  116. }()
  117. }
  118. for _, b := range append(gpus, cpus...) {
  119. C.ggml_backend_alloc_ctx_tensors(b.ctx, b.backend)
  120. }
  121. sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
  122. var g errgroup.Group
  123. for t, c := range tensors {
  124. g.Go(func() error {
  125. bts := make([]byte, t.Size())
  126. n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
  127. if err != nil {
  128. return err
  129. }
  130. if n != int(t.Size()) {
  131. return fmt.Errorf("expected %d bytes, got %d", t.Size(), n)
  132. }
  133. cname := C.CString(t.Name)
  134. defer C.free(unsafe.Pointer(cname))
  135. C.ggml_backend_tensor_set(C.ggml_get_tensor(c.ctx, cname), unsafe.Pointer(&bts[0]), 0, C.size_t(n))
  136. return nil
  137. })
  138. }
  139. if err := g.Wait(); err != nil {
  140. return nil, err
  141. }
  142. backends := make([]*C.struct_ggml_backend, len(gpus)+len(cpus))
  143. bufts := make([]*C.struct_ggml_backend_buffer_type, len(gpus)+len(cpus))
  144. for i, c := range append(gpus, cpus...) {
  145. backends[i] = c.backend
  146. bufts[i] = C.ggml_backend_get_default_buffer_type(c.backend)
  147. }
  148. return &Backend{
  149. flashAttention: params.FlashAttention,
  150. meta: meta,
  151. cpus: cpus,
  152. gpus: gpus,
  153. sched: C.ggml_backend_sched_new(
  154. (*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
  155. (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
  156. C.int(len(backends)),
  157. C.size_t(max(8192, len(meta.Tensors().Items())*5)),
  158. true,
  159. ),
  160. }, nil
  161. }
  162. func init() {
  163. ml.RegisterBackend("ggml", New)
  164. }
  165. func (b *Backend) Config() ml.Config {
  166. return b.meta.KV()
  167. }
  168. func (b *Backend) Get(name string) ml.Tensor {
  169. cname := C.CString(name)
  170. defer C.free(unsafe.Pointer(cname))
  171. for _, c := range append(b.gpus, b.cpus...) {
  172. if t := C.ggml_get_tensor(c.ctx, cname); t != nil {
  173. return &Tensor{b: b, t: t}
  174. }
  175. }
  176. return nil
  177. }
  178. func (b *Backend) NewContext() ml.Context {
  179. nodes := max(8192, len(b.meta.Tensors().Items())*5)
  180. c := C.ggml_init(C.struct_ggml_init_params{
  181. mem_buffer: nil,
  182. mem_size: C.size_t(nodes)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(nodes), false),
  183. no_alloc: true,
  184. })
  185. backends := make([]*C.struct_ggml_backend, len(b.gpus)+len(b.cpus))
  186. for i, c := range append(b.gpus, b.cpus...) {
  187. backends[i] = c.backend
  188. }
  189. return &Context{
  190. b: b,
  191. ctx: c,
  192. backend: backends[0],
  193. nodes: nodes,
  194. }
  195. }
  196. func (b *Backend) CacheConfig() ml.CacheConfig {
  197. if b.flashAttention {
  198. return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
  199. } else {
  200. return ml.CacheConfig{CachePadding: 32, PermutedV: true}
  201. }
  202. }
  203. type Context struct {
  204. b *Backend
  205. ctx *C.struct_ggml_context
  206. backend *C.struct_ggml_backend
  207. graph *C.struct_ggml_cgraph
  208. nodes int
  209. }
  210. func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
  211. if c.graph == nil {
  212. c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.nodes), false)
  213. }
  214. for _, tensor := range tensors {
  215. C.ggml_build_forward_expand(c.graph, tensor.(*Tensor).t)
  216. }
  217. return c
  218. }
  219. func (c *Context) Compute(tensors ...ml.Tensor) {
  220. C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph)
  221. C.ggml_backend_sched_reset(c.b.sched)
  222. needSync := true
  223. sync := func() {
  224. if needSync {
  225. C.ggml_backend_sched_synchronize(c.b.sched)
  226. needSync = false
  227. }
  228. }
  229. for _, t := range tensors {
  230. if C.ggml_nbytes(t.(*Tensor).t) > 0 {
  231. t.(*Tensor).sync = sync
  232. }
  233. }
  234. }
  235. func (c *Context) MaxTensors() int {
  236. return c.nodes
  237. }
  238. func shapeToGGML(shape []int) *C.int64_t {
  239. sh := make([]C.int64_t, len(shape))
  240. for i, s := range shape {
  241. sh[i] = (C.int64_t)(s)
  242. }
  243. return &sh[0]
  244. }
  245. func newTensor(ctx Context, dtype ml.DType, zero bool, shape []int) ml.Tensor {
  246. if len(shape) < 1 || len(shape) > 4 {
  247. panic("unsupported number of dimensions")
  248. }
  249. for _, dim := range shape {
  250. if dim < 1 {
  251. panic("invalid shape")
  252. }
  253. }
  254. var t *C.struct_ggml_tensor
  255. switch dtype {
  256. case ml.DTypeF32:
  257. t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
  258. case ml.DTypeF16:
  259. t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
  260. case ml.DTypeI32:
  261. t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
  262. default:
  263. panic("unsupported dtype")
  264. }
  265. b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
  266. C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
  267. if zero {
  268. C.ggml_set_zero(t)
  269. }
  270. return &Tensor{b: ctx.b, t: t}
  271. }
  272. func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
  273. return newTensor(c, dtype, false, shape)
  274. }
  275. func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
  276. return newTensor(c, dtype, true, shape)
  277. }
  278. func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) {
  279. n := len(s)
  280. if n == 0 {
  281. var shape C.int64_t = 0
  282. t := C.ggml_new_tensor(ctx.ctx, dtype, 1, &shape)
  283. return &Tensor{b: ctx.b, t: t}, nil
  284. }
  285. for _, v := range shape {
  286. n /= v
  287. }
  288. if n != 1 {
  289. return nil, fmt.Errorf("invalid shape %v for %d elements", shape, len(s))
  290. }
  291. t := C.ggml_new_tensor(ctx.ctx, dtype, C.int(len(shape)), shapeToGGML(shape))
  292. b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
  293. C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
  294. C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t))
  295. return &Tensor{b: ctx.b, t: t}, nil
  296. }
  297. func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
  298. return fromSlice(c, s, shape, C.GGML_TYPE_F32)
  299. }
  300. func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
  301. return fromSlice(c, s, shape, C.GGML_TYPE_I32)
  302. }
  303. func (c *Context) Close() {
  304. if c != nil {
  305. C.ggml_free(c.ctx)
  306. }
  307. }
  308. type Tensor struct {
  309. b *Backend
  310. t *C.struct_ggml_tensor
  311. sync func()
  312. }
  313. func (t *Tensor) LogValue() slog.Value {
  314. return slog.GroupValue(
  315. slog.String("name", C.GoString(C.ggml_get_name(t.t))),
  316. slog.String("type", C.GoString(C.ggml_type_name(t.t._type))),
  317. slog.Any("shape", t.Shape()),
  318. )
  319. }
  320. func (t *Tensor) Dim(n int) int {
  321. return int(t.t.ne[n])
  322. }
  323. func (t *Tensor) Stride(n int) int {
  324. return int(t.t.nb[n])
  325. }
  326. func (t *Tensor) Shape() []int {
  327. shape := make([]int, C.ggml_n_dims(t.t))
  328. for i := range shape {
  329. shape[i] = t.Dim(i)
  330. }
  331. return shape
  332. }
  333. func (t *Tensor) Bytes() (data []byte) {
  334. if t.sync != nil {
  335. data = make([]byte, C.ggml_nbytes(t.t))
  336. t.sync()
  337. C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
  338. }
  339. return
  340. }
  341. func (t *Tensor) Floats() (data []float32) {
  342. if t.sync != nil {
  343. data = make([]float32, C.ggml_nelements(t.t))
  344. t.sync()
  345. C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
  346. }
  347. return
  348. }
  349. func (t *Tensor) DType() ml.DType {
  350. switch t.t._type {
  351. case C.GGML_TYPE_F32:
  352. return ml.DTypeF32
  353. case C.GGML_TYPE_F16:
  354. return ml.DTypeF16
  355. case C.GGML_TYPE_I32:
  356. return ml.DTypeI32
  357. default:
  358. return ml.DTypeOther
  359. }
  360. }
  361. func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  362. return &Tensor{
  363. b: t.b,
  364. t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  365. }
  366. }
  367. func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
  368. if len(s) > 0 {
  369. return t.Concat(ctx, s[0].Stack(ctx, dim, s[1:]...), dim)
  370. }
  371. return t
  372. }
  373. func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
  374. return &Tensor{
  375. b: t.b,
  376. t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
  377. }
  378. }
  379. func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
  380. return &Tensor{
  381. b: t.b,
  382. t: C.ggml_cont(ctx.(*Context).ctx, t.t),
  383. }
  384. }
  385. func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  386. return &Tensor{
  387. b: t.b,
  388. t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  389. }
  390. }
  391. func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  392. return &Tensor{
  393. b: t.b,
  394. t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  395. }
  396. }
  397. func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  398. mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
  399. C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
  400. return &Tensor{
  401. b: t.b,
  402. t: mul,
  403. }
  404. }
  405. func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
  406. tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
  407. if b != nil {
  408. tt = tt.Add(ctx, b)
  409. }
  410. return tt
  411. }
  412. func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
  413. return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
  414. }
  415. func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
  416. if len(shape) != 4 {
  417. panic("expected 4 dimensions")
  418. }
  419. return &Tensor{
  420. b: t.b,
  421. t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
  422. }
  423. }
  424. func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
  425. if len(shape) != 4 {
  426. panic("expected 4 dimensions")
  427. }
  428. return &Tensor{
  429. b: t.b,
  430. t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
  431. }
  432. }
  433. func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  434. return &Tensor{
  435. b: t.b,
  436. t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  437. }
  438. }
  439. func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  440. return &Tensor{
  441. b: t.b,
  442. t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  443. }
  444. }
  445. func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
  446. switch len(shape) {
  447. case 1:
  448. return &Tensor{
  449. b: t.b,
  450. t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
  451. }
  452. case 2:
  453. return &Tensor{
  454. b: t.b,
  455. t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
  456. }
  457. case 3:
  458. return &Tensor{
  459. b: t.b,
  460. t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
  461. }
  462. case 4:
  463. return &Tensor{
  464. b: t.b,
  465. t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
  466. }
  467. default:
  468. panic("unsupported number of dimensions")
  469. }
  470. }
  471. func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
  472. return &Tensor{
  473. b: t.b,
  474. t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
  475. }
  476. }
  477. func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
  478. return &Tensor{
  479. b: t.b,
  480. t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
  481. }
  482. }
  483. func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
  484. return &Tensor{
  485. b: t.b,
  486. t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
  487. }
  488. }
  489. func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
  490. if len(shape) != 4 {
  491. panic("expected 4 dimensions")
  492. }
  493. return &Tensor{
  494. b: t.b,
  495. t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
  496. }
  497. }
  498. func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
  499. switch len(shape) {
  500. case 1:
  501. return &Tensor{
  502. b: t.b,
  503. t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
  504. }
  505. case 3:
  506. return &Tensor{
  507. b: t.b,
  508. t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
  509. C.int64_t(shape[0]), C.int64_t(shape[2]),
  510. C.size_t(shape[1]),
  511. C.size_t(offset)),
  512. }
  513. case 5:
  514. return &Tensor{
  515. b: t.b,
  516. t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
  517. C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
  518. C.size_t(shape[1]), C.size_t(shape[3]),
  519. C.size_t(offset)),
  520. }
  521. case 7:
  522. return &Tensor{
  523. b: t.b,
  524. t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
  525. C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
  526. C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
  527. C.size_t(offset)),
  528. }
  529. default:
  530. panic("unsupported number of dimensions")
  531. }
  532. }
  533. const (
  534. ropeTypeNorm C.int = iota
  535. )
  536. func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
  537. if ropeFactors == nil {
  538. ropeFactors = &Tensor{b: t.b}
  539. }
  540. dequant := t.t
  541. if C.ggml_is_quantized(t.t._type) {
  542. dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
  543. }
  544. return &Tensor{
  545. b: t.b,
  546. t: C.ggml_rope_ext(
  547. ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
  548. C.int(ropeDim),
  549. 131072, // YaRN n_ctx_train
  550. ropeTypeNorm, // ROPE_TYPE_NORM
  551. C.float(ropeBase),
  552. C.float(ropeScale),
  553. 0., // YaRN ext_factor
  554. 1., // YaRN attn_factor
  555. 32., // YaRN beta_fast
  556. 1., // YaRN beta_slow
  557. ),
  558. }
  559. }
  560. func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
  561. return &Tensor{
  562. b: t.b,
  563. t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
  564. }
  565. }
  566. func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
  567. return &Tensor{
  568. b: t.b,
  569. t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
  570. }
  571. }
  572. func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
  573. return &Tensor{
  574. b: t.b,
  575. t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
  576. }
  577. }
  578. func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
  579. var kqMask *C.struct_ggml_tensor
  580. if mask != nil {
  581. kqMask = mask.(*Tensor).t
  582. }
  583. query := t.Permute(ctx, 0, 2, 1, 3)
  584. key = key.Permute(ctx, 0, 2, 1, 3)
  585. if t.b.flashAttention {
  586. value = value.Permute(ctx, 0, 2, 1, 3)
  587. kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
  588. C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
  589. return &Tensor{b: t.b, t: kqv}
  590. } else {
  591. kq := key.MulmatFullPrec(ctx, query)
  592. kq = &Tensor{
  593. b: t.b,
  594. t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
  595. }
  596. kqv := value.Mulmat(ctx, kq)
  597. return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  598. }
  599. }