ggml.go 20 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833
  1. package ggml
  2. // #cgo CPPFLAGS: -I${SRCDIR}/ggml/include
  3. // #include <stdlib.h>
  4. // #include <stdint.h>
  5. // #include "ggml.h"
  6. // #include "ggml-cpu.h"
  7. // #include "ggml-backend.h"
  8. import "C"
import (
	"errors"
	"fmt"
	"io"
	"iter"
	"log/slog"
	"maps"
	"os"
	"runtime"
	"slices"
	"strconv"
	"strings"
	"unicode"
	"unsafe"

	"github.com/ollama/ollama/format"
	fs "github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/ml"
	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
	"golang.org/x/sync/errgroup"
)
  28. func devices() iter.Seq[*C.struct_ggml_backend_device] {
  29. return func(yield func(*C.struct_ggml_backend_device) bool) {
  30. ggml.OnceLoad()
  31. for i := range C.ggml_backend_dev_count() {
  32. if !yield(C.ggml_backend_dev_get(i)) {
  33. return
  34. }
  35. }
  36. }
  37. }
  38. type Backend struct {
  39. meta *fs.GGML
  40. flashAttention bool
  41. sched *C.struct_ggml_backend_sched
  42. tensors map[string]*C.struct_ggml_tensor
  43. ctxs []*C.struct_ggml_context
  44. backends []*C.struct_ggml_backend
  45. bufts []*C.struct_ggml_backend_buffer_type
  46. }
  47. func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
  48. meta, n, err := fs.Decode(r, -1)
  49. if err != nil {
  50. return nil, err
  51. }
  52. slog.Info(
  53. "",
  54. "architecture", meta.KV().Architecture(),
  55. "file_type", meta.KV().FileType(),
  56. "name", meta.KV().String("general.name"),
  57. "description", meta.KV().String("general.description"),
  58. "num_tensors", len(meta.Tensors().Items()),
  59. "num_key_values", len(meta.KV()),
  60. )
  61. type dbt struct {
  62. d *C.struct_ggml_backend_device
  63. bts []*C.struct_ggml_backend_buffer_type
  64. }
  65. var cpus, accels, gpus []*C.struct_ggml_backend_device
  66. for d := range devices() {
  67. switch C.ggml_backend_dev_type(d) {
  68. case C.GGML_BACKEND_DEVICE_TYPE_CPU:
  69. cpus = append(cpus, d)
  70. case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
  71. accels = append(accels, d)
  72. case C.GGML_BACKEND_DEVICE_TYPE_GPU:
  73. gpus = append(gpus, d)
  74. }
  75. }
  76. var cpuBufferTypes []*C.struct_ggml_backend_buffer_type
  77. for _, d := range append(accels, append(gpus, cpus...)...) {
  78. switch C.ggml_backend_dev_type(d) {
  79. case C.GGML_BACKEND_DEVICE_TYPE_CPU,
  80. C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
  81. cpuBufferTypes = append(cpuBufferTypes, C.ggml_backend_dev_buffer_type(d))
  82. }
  83. }
  84. var sum uint64
  85. var cumsum []uint64
  86. var gpuBufferTypes []dbt
  87. for _, d := range gpus {
  88. var free, total C.size_t
  89. C.ggml_backend_dev_memory(d, &free, &total)
  90. sum += uint64(free)
  91. cumsum = append(cumsum, sum)
  92. bt := C.ggml_backend_dev_buffer_type(d)
  93. gpuBufferTypes = append(gpuBufferTypes, dbt{
  94. d: d,
  95. bts: append([]*C.struct_ggml_backend_buffer_type{bt}, cpuBufferTypes...),
  96. })
  97. }
  98. splits := make([]float64, len(cumsum))
  99. for i := range splits {
  100. splits[i] = float64(cumsum[i]) / float64(sum)
  101. }
  102. input := dbt{C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU), cpuBufferTypes}
  103. slog.Info("input layer", "device", C.GoString(C.ggml_backend_dev_name(input.d)))
  104. var blocks int
  105. for key, value := range meta.KV() {
  106. if strings.HasSuffix(key, ".block_count") {
  107. blocks += int(value.(uint32))
  108. }
  109. }
  110. indexFunc := func(i int) func(float64) bool {
  111. return func(f float64) bool {
  112. return float64(i)/float64(blocks+1) < f
  113. }
  114. }
  115. layers := make([]dbt, blocks)
  116. for i := range layers {
  117. layers[i] = gpuBufferTypes[slices.IndexFunc(splits, indexFunc(i))]
  118. slog.Info("layer", "i", i, "device", C.GoString(C.ggml_backend_dev_name(layers[i].d)))
  119. }
  120. output := gpuBufferTypes[slices.IndexFunc(splits, indexFunc(blocks))]
  121. slog.Info("output layer", "device", C.GoString(C.ggml_backend_dev_name(output.d)))
  122. maxTensors := len(meta.Tensors().Items())
  123. maxTensors += 1
  124. maxTensors += blocks * 2
  125. slog.Info("max tensors", "max_tensors", maxTensors)
  126. type tensor struct {
  127. source *fs.Tensor
  128. target string
  129. }
  130. targets := make(map[string][]string)
  131. ctxs := make(map[*C.struct_ggml_backend_buffer_type]*C.struct_ggml_context)
  132. createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type) *C.struct_ggml_tensor {
  133. for _, bt := range bts {
  134. if _, ok := ctxs[bt]; !ok {
  135. ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{
  136. mem_size: C.ggml_tensor_overhead() * C.size_t(maxTensors),
  137. no_alloc: true,
  138. })
  139. }
  140. targets[t.source.Name] = append(targets[t.source.Name], t.target)
  141. name := t.source.Name
  142. if t.target != "" {
  143. name = t.target
  144. }
  145. cname := C.CString(name)
  146. defer C.free(unsafe.Pointer(cname))
  147. if tt := C.ggml_get_tensor(ctxs[bt], cname); tt != nil {
  148. return tt
  149. }
  150. tt := C.ggml_new_tensor(ctxs[bt], t.source.Kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0])))
  151. C.ggml_set_name(tt, cname)
  152. slog.Debug("created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
  153. //nolint:staticcheck // TODO: check if buffer type supports this tensor
  154. return tt
  155. }
  156. return nil
  157. }
  158. hasPart := func(s string, parts ...string) bool {
  159. split := strings.Split(s, ".")
  160. for _, part := range parts {
  161. if slices.Contains(split, part) {
  162. return true
  163. }
  164. }
  165. return false
  166. }
  167. for _, t := range meta.Tensors().Items() {
  168. switch {
  169. case hasPart(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
  170. createTensor(tensor{source: t}, input.bts)
  171. case hasPart(t.Name, "cls", "output", "output_norm"):
  172. createTensor(tensor{source: t}, output.bts)
  173. default:
  174. if i := func() int {
  175. if fields := strings.FieldsFunc(t.Name, func(r rune) bool { return !unicode.IsNumber(r) }); len(fields) > 0 {
  176. if i, err := strconv.Atoi(fields[0]); err == nil {
  177. return i
  178. }
  179. }
  180. return -1
  181. }(); i >= 0 {
  182. createTensor(tensor{source: t}, layers[i].bts)
  183. } else {
  184. for i, layer := range layers {
  185. createTensor(tensor{
  186. source: t,
  187. target: "blk." + strconv.Itoa(i) + "." + t.Name,
  188. }, layer.bts)
  189. }
  190. }
  191. }
  192. }
  193. bbs := make(map[*C.struct_ggml_context][]*C.struct_ggml_backend_buffer, len(ctxs))
  194. for bt, c := range ctxs {
  195. if C.ggml_get_first_tensor(c) == nil {
  196. continue
  197. }
  198. b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
  199. C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS)
  200. bbs[c] = append(bbs[c], b)
  201. }
  202. for bs := range maps.Values(bbs) {
  203. for _, b := range bs {
  204. slog.Info("model", "buffer", C.GoString(C.ggml_backend_buffer_name(b)), "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(b))))
  205. }
  206. }
  207. tensors := make(map[string]*C.struct_ggml_tensor)
  208. for _, c := range ctxs {
  209. for t := C.ggml_get_first_tensor(c); t != nil; t = C.ggml_get_next_tensor(c, t) {
  210. tensors[C.GoString(C.ggml_get_name(t))] = t
  211. }
  212. }
  213. sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
  214. var g errgroup.Group
  215. for _, t := range meta.Tensors().Items() {
  216. for _, target := range targets[t.Name] {
  217. g.Go(func() error {
  218. if target == "" {
  219. target = t.Name
  220. }
  221. tt, ok := tensors[target]
  222. if !ok {
  223. return fmt.Errorf("unassigned tensor: %s", t.Name)
  224. }
  225. bts := make([]byte, t.Size())
  226. n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
  227. if err != nil {
  228. return err
  229. }
  230. if n != len(bts) {
  231. return errors.New("short read")
  232. }
  233. cname := C.CString(t.Name)
  234. C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size()))
  235. C.free(unsafe.Pointer(cname))
  236. return nil
  237. })
  238. }
  239. }
  240. if g.Wait() != nil {
  241. return nil, err
  242. }
  243. var backends []*C.struct_ggml_backend
  244. var bufts []*C.struct_ggml_backend_buffer_type
  245. for _, d := range append(gpus, append(accels, cpus...)...) {
  246. b := C.ggml_backend_dev_init(d, nil)
  247. backends = append(backends, b)
  248. bt := C.ggml_backend_get_default_buffer_type(b)
  249. if d := C.ggml_backend_get_device(b); C.ggml_backend_dev_type(d) == C.GGML_BACKEND_DEVICE_TYPE_CPU && len(gpus) > 0 {
  250. if hbt := C.ggml_backend_dev_host_buffer_type(d); hbt != nil {
  251. bt = hbt
  252. }
  253. }
  254. bufts = append(bufts, bt)
  255. slog.Info("compute buffer", "backend", C.GoString(C.ggml_backend_name(b)), "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
  256. }
  257. return &Backend{
  258. flashAttention: params.FlashAttention,
  259. meta: meta,
  260. tensors: tensors,
  261. sched: C.ggml_backend_sched_new(
  262. (*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
  263. (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
  264. C.int(len(backends)),
  265. C.size_t(max(8192, len(meta.Tensors().Items())*5)),
  266. true,
  267. ),
  268. }, nil
  269. }
  270. func init() {
  271. ml.RegisterBackend("ggml", New)
  272. }
  273. func (b *Backend) Config() ml.Config {
  274. return b.meta.KV()
  275. }
  276. func (b *Backend) Get(name string) ml.Tensor {
  277. if t, ok := b.tensors[name]; ok {
  278. return &Tensor{b: b, t: t}
  279. }
  280. return nil
  281. }
  282. func (b *Backend) NewContext() ml.Context {
  283. maxTensors := max(8192, len(b.meta.Tensors().Items())*5)
  284. return &Context{
  285. b: b,
  286. maxTensors: maxTensors,
  287. ctx: C.ggml_init(C.struct_ggml_init_params{
  288. mem_size: C.size_t(maxTensors)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(maxTensors), false),
  289. no_alloc: true,
  290. }),
  291. }
  292. }
  293. func (b *Backend) CacheConfig() ml.CacheConfig {
  294. if b.flashAttention {
  295. return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
  296. } else {
  297. return ml.CacheConfig{CachePadding: 32, PermutedV: true}
  298. }
  299. }
  300. type Context struct {
  301. b *Backend
  302. ctx *C.struct_ggml_context
  303. graph *C.struct_ggml_cgraph
  304. maxTensors int
  305. }
  306. func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
  307. if c.graph == nil {
  308. c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.maxTensors), false)
  309. }
  310. for _, tensor := range tensors {
  311. C.ggml_build_forward_expand(c.graph, tensor.(*Tensor).t)
  312. }
  313. return c
  314. }
  315. func (c *Context) Compute(tensors ...ml.Tensor) {
  316. C.ggml_backend_sched_reset(c.b.sched)
  317. C.ggml_backend_sched_alloc_graph(c.b.sched, c.graph)
  318. C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph)
  319. needSync := true
  320. sync := func() {
  321. if needSync {
  322. C.ggml_backend_sched_synchronize(c.b.sched)
  323. needSync = false
  324. }
  325. }
  326. for _, t := range tensors {
  327. if C.ggml_nbytes(t.(*Tensor).t) > 0 {
  328. t.(*Tensor).sync = sync
  329. }
  330. }
  331. }
  332. func (c *Context) MaxTensors() int {
  333. return c.maxTensors
  334. }
  335. func shapeToGGML(shape []int) *C.int64_t {
  336. sh := make([]C.int64_t, len(shape))
  337. for i, s := range shape {
  338. sh[i] = C.int64_t(s)
  339. }
  340. return &sh[0]
  341. }
  342. func newTensor(ctx Context, dtype ml.DType, shape []int) ml.Tensor {
  343. if len(shape) < 1 || len(shape) > 4 {
  344. panic("unsupported number of dimensions")
  345. }
  346. for _, dim := range shape {
  347. if dim < 1 {
  348. panic("invalid shape")
  349. }
  350. }
  351. var t *C.struct_ggml_tensor
  352. switch dtype {
  353. case ml.DTypeF32:
  354. t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
  355. case ml.DTypeF16:
  356. t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
  357. case ml.DTypeI32:
  358. t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
  359. default:
  360. panic("unsupported dtype")
  361. }
  362. b := C.ggml_backend_alloc_buffer(C.ggml_backend_sched_get_backend(ctx.b.sched, 0), C.ggml_nbytes(t))
  363. C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
  364. C.ggml_set_input(t)
  365. return &Tensor{b: ctx.b, t: t}
  366. }
  367. func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
  368. return newTensor(c, dtype, shape)
  369. }
  370. func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
  371. t := newTensor(c, dtype, shape)
  372. C.ggml_set_zero(t.(*Tensor).t)
  373. return t
  374. }
  375. func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) {
  376. n := len(s)
  377. if n == 0 {
  378. var shape C.int64_t = 0
  379. t := C.ggml_new_tensor(ctx.ctx, dtype, 1, &shape)
  380. return &Tensor{b: ctx.b, t: t}, nil
  381. }
  382. for _, v := range shape {
  383. n /= v
  384. }
  385. if n != 1 {
  386. return nil, fmt.Errorf("invalid shape %v for %d elements", shape, len(s))
  387. }
  388. t := C.ggml_new_tensor(ctx.ctx, dtype, C.int(len(shape)), shapeToGGML(shape))
  389. b := C.ggml_backend_alloc_buffer(C.ggml_backend_sched_get_backend(ctx.b.sched, 0), C.ggml_nbytes(t))
  390. C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
  391. C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t))
  392. C.ggml_set_input(t)
  393. return &Tensor{b: ctx.b, t: t}, nil
  394. }
  395. func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
  396. return fromSlice(c, s, shape, C.GGML_TYPE_F32)
  397. }
  398. func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
  399. return fromSlice(c, s, shape, C.GGML_TYPE_I32)
  400. }
  401. func (c *Context) Close() {
  402. if c != nil {
  403. C.ggml_free(c.ctx)
  404. }
  405. }
  406. type Tensor struct {
  407. b *Backend
  408. t *C.struct_ggml_tensor
  409. sync func()
  410. }
  411. func (t *Tensor) LogValue() slog.Value {
  412. return slog.GroupValue(
  413. slog.String("name", C.GoString(C.ggml_get_name(t.t))),
  414. slog.String("type", C.GoString(C.ggml_type_name(t.t._type))),
  415. slog.Any("shape", t.Shape()),
  416. )
  417. }
  418. func (t *Tensor) Dim(n int) int {
  419. return int(t.t.ne[n])
  420. }
  421. func (t *Tensor) Stride(n int) int {
  422. return int(t.t.nb[n])
  423. }
  424. func (t *Tensor) Shape() []int {
  425. shape := make([]int, C.ggml_n_dims(t.t))
  426. for i := range shape {
  427. shape[i] = t.Dim(i)
  428. }
  429. return shape
  430. }
  431. func (t *Tensor) Bytes() (data []byte) {
  432. if t.sync != nil {
  433. data = make([]byte, C.ggml_nbytes(t.t))
  434. t.sync()
  435. C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
  436. }
  437. return
  438. }
  439. func (t *Tensor) Floats() (data []float32) {
  440. if t.sync != nil {
  441. data = make([]float32, C.ggml_nelements(t.t))
  442. t.sync()
  443. C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
  444. }
  445. return
  446. }
  447. func (t *Tensor) DType() ml.DType {
  448. switch t.t._type {
  449. case C.GGML_TYPE_F32:
  450. return ml.DTypeF32
  451. case C.GGML_TYPE_F16:
  452. return ml.DTypeF16
  453. case C.GGML_TYPE_I32:
  454. return ml.DTypeI32
  455. default:
  456. return ml.DTypeOther
  457. }
  458. }
  459. func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  460. return &Tensor{
  461. b: t.b,
  462. t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  463. }
  464. }
  465. func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
  466. if len(s) > 0 {
  467. return t.Concat(ctx, s[0].Stack(ctx, dim, s[1:]...), dim)
  468. }
  469. return t
  470. }
  471. func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
  472. return &Tensor{
  473. b: t.b,
  474. t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
  475. }
  476. }
  477. func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
  478. return &Tensor{
  479. b: t.b,
  480. t: C.ggml_cont(ctx.(*Context).ctx, t.t),
  481. }
  482. }
  483. func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  484. return &Tensor{
  485. b: t.b,
  486. t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  487. }
  488. }
  489. func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  490. return &Tensor{
  491. b: t.b,
  492. t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  493. }
  494. }
  495. func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  496. mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
  497. C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
  498. return &Tensor{
  499. b: t.b,
  500. t: mul,
  501. }
  502. }
  503. func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
  504. tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
  505. if b != nil {
  506. tt = tt.Add(ctx, b)
  507. }
  508. return tt
  509. }
  510. func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
  511. return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
  512. }
  513. func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
  514. if len(shape) != 4 {
  515. panic("expected 4 dimensions")
  516. }
  517. return &Tensor{
  518. b: t.b,
  519. t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
  520. }
  521. }
  522. func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
  523. if len(shape) != 4 {
  524. panic("expected 4 dimensions")
  525. }
  526. return &Tensor{
  527. b: t.b,
  528. t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
  529. }
  530. }
  531. func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  532. return &Tensor{
  533. b: t.b,
  534. t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  535. }
  536. }
  537. func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  538. return &Tensor{
  539. b: t.b,
  540. t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  541. }
  542. }
  543. func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
  544. switch len(shape) {
  545. case 1:
  546. return &Tensor{
  547. b: t.b,
  548. t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
  549. }
  550. case 2:
  551. return &Tensor{
  552. b: t.b,
  553. t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
  554. }
  555. case 3:
  556. return &Tensor{
  557. b: t.b,
  558. t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
  559. }
  560. case 4:
  561. return &Tensor{
  562. b: t.b,
  563. t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
  564. }
  565. default:
  566. panic("unsupported number of dimensions")
  567. }
  568. }
  569. func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
  570. return &Tensor{
  571. b: t.b,
  572. t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
  573. }
  574. }
  575. func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
  576. return &Tensor{
  577. b: t.b,
  578. t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
  579. }
  580. }
  581. func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
  582. return &Tensor{
  583. b: t.b,
  584. t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
  585. }
  586. }
  587. func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
  588. if len(shape) != 4 {
  589. panic("expected 4 dimensions")
  590. }
  591. return &Tensor{
  592. b: t.b,
  593. t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
  594. }
  595. }
  596. func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
  597. switch len(shape) {
  598. case 1:
  599. return &Tensor{
  600. b: t.b,
  601. t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
  602. }
  603. case 3:
  604. return &Tensor{
  605. b: t.b,
  606. t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
  607. C.int64_t(shape[0]), C.int64_t(shape[2]),
  608. C.size_t(shape[1]),
  609. C.size_t(offset)),
  610. }
  611. case 5:
  612. return &Tensor{
  613. b: t.b,
  614. t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
  615. C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
  616. C.size_t(shape[1]), C.size_t(shape[3]),
  617. C.size_t(offset)),
  618. }
  619. case 7:
  620. return &Tensor{
  621. b: t.b,
  622. t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
  623. C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
  624. C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
  625. C.size_t(offset)),
  626. }
  627. default:
  628. panic("unsupported number of dimensions")
  629. }
  630. }
  631. const (
  632. ropeTypeNorm C.int = iota
  633. )
  634. func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
  635. if ropeFactors == nil {
  636. ropeFactors = &Tensor{b: t.b}
  637. }
  638. dequant := t.t
  639. if C.ggml_is_quantized(t.t._type) {
  640. dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
  641. }
  642. return &Tensor{
  643. b: t.b,
  644. t: C.ggml_rope_ext(
  645. ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
  646. C.int(ropeDim),
  647. 131072, // YaRN n_ctx_train
  648. ropeTypeNorm, // ROPE_TYPE_NORM
  649. C.float(ropeBase),
  650. C.float(ropeScale),
  651. 0., // YaRN ext_factor
  652. 1., // YaRN attn_factor
  653. 32., // YaRN beta_fast
  654. 1., // YaRN beta_slow
  655. ),
  656. }
  657. }
  658. func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
  659. return &Tensor{
  660. b: t.b,
  661. t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
  662. }
  663. }
  664. func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
  665. return &Tensor{
  666. b: t.b,
  667. t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
  668. }
  669. }
  670. func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
  671. return &Tensor{
  672. b: t.b,
  673. t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
  674. }
  675. }
  676. func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
  677. var kqMask *C.struct_ggml_tensor
  678. if mask != nil {
  679. kqMask = mask.(*Tensor).t
  680. }
  681. query := t.Permute(ctx, 0, 2, 1, 3)
  682. key = key.Permute(ctx, 0, 2, 1, 3)
  683. if t.b.flashAttention {
  684. value = value.Permute(ctx, 0, 2, 1, 3)
  685. kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
  686. C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
  687. return &Tensor{b: t.b, t: kqv}
  688. } else {
  689. kq := key.MulmatFullPrec(ctx, query)
  690. kq = &Tensor{
  691. b: t.b,
  692. t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
  693. }
  694. kqv := value.Mulmat(ctx, kq)
  695. return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  696. }
  697. }