ggml.go 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894
  1. package ggml
  2. // #cgo CPPFLAGS: -I${SRCDIR}/ggml/include
  3. // #include <stdlib.h>
  4. // #include <stdint.h>
  5. // #include "ggml.h"
  6. // #include "ggml-cpu.h"
  7. // #include "ggml-backend.h"
  8. import "C"
  9. import (
  10. "errors"
  11. "fmt"
  12. "io"
  13. "iter"
  14. "log/slog"
  15. "maps"
  16. "os"
  17. "slices"
  18. "strconv"
  19. "strings"
  20. "unicode"
  21. "unsafe"
  22. "github.com/ollama/ollama/format"
  23. fs "github.com/ollama/ollama/fs/ggml"
  24. "github.com/ollama/ollama/ml"
  25. ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
  26. "golang.org/x/sync/errgroup"
  27. )
  28. func devices() iter.Seq[*C.struct_ggml_backend_device] {
  29. return func(yield func(*C.struct_ggml_backend_device) bool) {
  30. ggml.OnceLoad()
  31. for i := range C.ggml_backend_dev_count() {
  32. if !yield(C.ggml_backend_dev_get(i)) {
  33. return
  34. }
  35. }
  36. }
  37. }
  38. type Backend struct {
  39. meta *fs.GGML
  40. sched *C.struct_ggml_backend_sched
  41. tensors map[string]*C.struct_ggml_tensor
  42. input *C.struct_ggml_backend
  43. output *C.struct_ggml_backend
  44. layers map[int]*C.struct_ggml_backend
  45. flashAttention bool
  46. }
  47. func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
  48. meta, n, err := fs.Decode(r, -1)
  49. if err != nil {
  50. return nil, err
  51. }
  52. slog.Info(
  53. "",
  54. "architecture", meta.KV().Architecture(),
  55. "file_type", meta.KV().FileType(),
  56. "name", meta.KV().String("general.name"),
  57. "description", meta.KV().String("general.description"),
  58. "num_tensors", len(meta.Tensors().Items()),
  59. "num_key_values", len(meta.KV()),
  60. )
  61. type dbt struct {
  62. d *C.struct_ggml_backend_device
  63. bts []*C.struct_ggml_backend_buffer_type
  64. }
  65. var cpus, accels, gpus []*C.struct_ggml_backend_device
  66. for d := range devices() {
  67. switch C.ggml_backend_dev_type(d) {
  68. case C.GGML_BACKEND_DEVICE_TYPE_CPU:
  69. cpus = append(cpus, d)
  70. case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
  71. accels = append(accels, d)
  72. case C.GGML_BACKEND_DEVICE_TYPE_GPU:
  73. gpus = append(gpus, d)
  74. }
  75. }
  76. var cpuBufferTypes []*C.struct_ggml_backend_buffer_type
  77. for _, d := range append(accels, append(gpus, cpus...)...) {
  78. switch C.ggml_backend_dev_type(d) {
  79. case C.GGML_BACKEND_DEVICE_TYPE_CPU,
  80. C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
  81. cpuBufferTypes = append(cpuBufferTypes, C.ggml_backend_dev_buffer_type(d))
  82. }
  83. }
  84. var sum uint64
  85. var cumsum []uint64
  86. var gpuBufferTypes []dbt
  87. for _, d := range gpus {
  88. var free, total C.size_t
  89. C.ggml_backend_dev_memory(d, &free, &total)
  90. sum += uint64(free)
  91. cumsum = append(cumsum, sum)
  92. bt := C.ggml_backend_dev_buffer_type(d)
  93. gpuBufferTypes = append(gpuBufferTypes, dbt{
  94. d: d,
  95. bts: append([]*C.struct_ggml_backend_buffer_type{bt}, cpuBufferTypes...),
  96. })
  97. }
  98. splits := make([]float64, len(cumsum))
  99. for i := range splits {
  100. splits[i] = float64(cumsum[i]) / float64(sum)
  101. }
  102. input := dbt{C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU), cpuBufferTypes}
  103. var blocks int
  104. for key, value := range meta.KV() {
  105. if strings.HasSuffix(key, ".block_count") {
  106. blocks += int(value.(uint32))
  107. }
  108. }
  109. indexFunc := func(i int) func(float64) bool {
  110. return func(f float64) bool {
  111. return float64(i)/float64(blocks+1) < f
  112. }
  113. }
  114. layers := make([]dbt, blocks)
  115. for i := range layers {
  116. layers[i] = gpuBufferTypes[slices.IndexFunc(splits, indexFunc(i))]
  117. }
  118. output := gpuBufferTypes[slices.IndexFunc(splits, indexFunc(blocks))]
  119. maxTensors := len(meta.Tensors().Items())
  120. maxTensors += 1
  121. maxTensors += blocks * 2
  122. type tensor struct {
  123. source *fs.Tensor
  124. target string
  125. }
  126. targets := make(map[string][]string)
  127. ctxs := make(map[*C.struct_ggml_backend_buffer_type]*C.struct_ggml_context)
  128. createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type) *C.struct_ggml_tensor {
  129. for _, bt := range bts {
  130. if _, ok := ctxs[bt]; !ok {
  131. ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{
  132. mem_size: C.ggml_tensor_overhead() * C.size_t(maxTensors),
  133. no_alloc: true,
  134. })
  135. }
  136. targets[t.source.Name] = append(targets[t.source.Name], t.target)
  137. name := t.source.Name
  138. if t.target != "" {
  139. name = t.target
  140. }
  141. cname := C.CString(name)
  142. defer C.free(unsafe.Pointer(cname))
  143. if tt := C.ggml_get_tensor(ctxs[bt], cname); tt != nil {
  144. return tt
  145. }
  146. tt := C.ggml_new_tensor(ctxs[bt], t.source.Kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0])))
  147. C.ggml_set_name(tt, cname)
  148. slog.Debug("created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
  149. //nolint:staticcheck // TODO: check if buffer type supports this tensor
  150. return tt
  151. }
  152. return nil
  153. }
  154. hasPart := func(s string, parts ...string) bool {
  155. split := strings.Split(s, ".")
  156. for _, part := range parts {
  157. if slices.Contains(split, part) {
  158. return true
  159. }
  160. }
  161. return false
  162. }
  163. for _, t := range meta.Tensors().Items() {
  164. switch {
  165. case hasPart(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
  166. createTensor(tensor{source: t}, input.bts)
  167. case hasPart(t.Name, "cls", "output", "output_norm"):
  168. createTensor(tensor{source: t}, output.bts)
  169. default:
  170. if i := func() int {
  171. if fields := strings.FieldsFunc(t.Name, func(r rune) bool { return !unicode.IsNumber(r) }); len(fields) > 0 {
  172. if i, err := strconv.Atoi(fields[0]); err == nil {
  173. return i
  174. }
  175. }
  176. return -1
  177. }(); i >= 0 {
  178. createTensor(tensor{source: t}, layers[i].bts)
  179. } else {
  180. for i, layer := range layers {
  181. createTensor(tensor{
  182. source: t,
  183. target: "blk." + strconv.Itoa(i) + "." + t.Name,
  184. }, layer.bts)
  185. }
  186. }
  187. }
  188. }
  189. bbs := make(map[*C.struct_ggml_context][]*C.struct_ggml_backend_buffer, len(ctxs))
  190. for bt, c := range ctxs {
  191. if C.ggml_get_first_tensor(c) == nil {
  192. continue
  193. }
  194. b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
  195. C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS)
  196. bbs[c] = append(bbs[c], b)
  197. }
  198. for bs := range maps.Values(bbs) {
  199. for _, b := range bs {
  200. slog.Info("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(b)), "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(b))))
  201. }
  202. }
  203. tensors := make(map[string]*C.struct_ggml_tensor)
  204. for _, c := range ctxs {
  205. for t := C.ggml_get_first_tensor(c); t != nil; t = C.ggml_get_next_tensor(c, t) {
  206. tensors[C.GoString(C.ggml_get_name(t))] = t
  207. }
  208. }
  209. sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
  210. var g errgroup.Group
  211. for _, t := range meta.Tensors().Items() {
  212. for _, target := range targets[t.Name] {
  213. g.Go(func() error {
  214. if target == "" {
  215. target = t.Name
  216. }
  217. tt, ok := tensors[target]
  218. if !ok {
  219. return fmt.Errorf("unassigned tensor: %s", t.Name)
  220. }
  221. bts := make([]byte, t.Size())
  222. n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
  223. if err != nil {
  224. return err
  225. }
  226. if n != len(bts) {
  227. return errors.New("short read")
  228. }
  229. cname := C.CString(t.Name)
  230. C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size()))
  231. C.free(unsafe.Pointer(cname))
  232. return nil
  233. })
  234. }
  235. }
  236. if g.Wait() != nil {
  237. return nil, err
  238. }
  239. deviceBackends := make(map[*C.struct_ggml_backend_device]*C.struct_ggml_backend)
  240. var backends []*C.struct_ggml_backend
  241. var bufts []*C.struct_ggml_backend_buffer_type
  242. for _, d := range append(gpus, append(accels, cpus...)...) {
  243. b := C.ggml_backend_dev_init(d, nil)
  244. backends = append(backends, b)
  245. deviceBackends[d] = b
  246. bt := C.ggml_backend_get_default_buffer_type(b)
  247. if d := C.ggml_backend_get_device(b); C.ggml_backend_dev_type(d) == C.GGML_BACKEND_DEVICE_TYPE_CPU && len(gpus) > 0 {
  248. if hbt := C.ggml_backend_dev_host_buffer_type(d); hbt != nil {
  249. bt = hbt
  250. }
  251. }
  252. bufts = append(bufts, bt)
  253. slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(b)), "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
  254. }
  255. return &Backend{
  256. flashAttention: params.FlashAttention,
  257. meta: meta,
  258. tensors: tensors,
  259. sched: C.ggml_backend_sched_new(
  260. (*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
  261. (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
  262. C.int(len(backends)),
  263. C.size_t(max(8192, len(meta.Tensors().Items())*5)),
  264. true,
  265. ),
  266. input: deviceBackends[input.d],
  267. output: deviceBackends[output.d],
  268. layers: func() map[int]*C.struct_ggml_backend {
  269. m := make(map[int]*C.struct_ggml_backend)
  270. for i, layer := range layers {
  271. m[i] = deviceBackends[layer.d]
  272. }
  273. return m
  274. }(),
  275. }, nil
  276. }
  277. func init() {
  278. ml.RegisterBackend("ggml", New)
  279. }
  280. func (b *Backend) Config() ml.Config {
  281. return b.meta.KV()
  282. }
  283. func (b *Backend) Get(name string) ml.Tensor {
  284. if t, ok := b.tensors[name]; ok {
  285. return &Tensor{b: b, t: t}
  286. }
  287. return nil
  288. }
  289. func (b *Backend) NewContext() ml.Context {
  290. return b.NewContextSize(max(8192, len(b.meta.Tensors().Items())*5))
  291. }
  292. func (b *Backend) NewContextSize(n int) ml.Context {
  293. return &Context{
  294. b: b,
  295. ctx: C.ggml_init(C.struct_ggml_init_params{
  296. mem_size: C.size_t(n)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(n), false),
  297. no_alloc: true,
  298. }),
  299. backend: C.ggml_backend_sched_get_backend(b.sched, 0),
  300. maxGraphNodes: n,
  301. input: b.input,
  302. output: b.output,
  303. layers: b.layers,
  304. }
  305. }
  306. func (b *Backend) CacheConfig() ml.CacheConfig {
  307. if b.flashAttention {
  308. return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
  309. } else {
  310. return ml.CacheConfig{CachePadding: 32, PermutedV: true}
  311. }
  312. }
  313. type Context struct {
  314. b *Backend
  315. ctx *C.struct_ggml_context
  316. graph *C.struct_ggml_cgraph
  317. // backend is the backend used for new tensors
  318. backend *C.struct_ggml_backend
  319. // input is the backend used for inputs
  320. input *C.struct_ggml_backend
  321. // output is the backend used for outputs
  322. output *C.struct_ggml_backend
  323. // output is the backend used for repeating layers
  324. layers map[int]*C.struct_ggml_backend
  325. maxGraphNodes int
  326. }
  327. func (c *Context) Input() ml.Context {
  328. if c.input != nil {
  329. return &Context{
  330. b: c.b,
  331. ctx: c.ctx,
  332. backend: c.input,
  333. maxGraphNodes: c.maxGraphNodes,
  334. }
  335. }
  336. return c
  337. }
  338. func (c *Context) Output() ml.Context {
  339. if c.output != nil {
  340. return &Context{
  341. b: c.b,
  342. ctx: c.ctx,
  343. backend: c.output,
  344. maxGraphNodes: c.maxGraphNodes,
  345. }
  346. }
  347. return c
  348. }
  349. func (c *Context) Layer(i int) ml.Context {
  350. if backend, ok := c.layers[i]; ok {
  351. return &Context{
  352. b: c.b,
  353. ctx: c.ctx,
  354. backend: backend,
  355. maxGraphNodes: c.maxGraphNodes,
  356. }
  357. }
  358. return c
  359. }
  360. func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
  361. if c.graph == nil {
  362. c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.maxGraphNodes), false)
  363. }
  364. for _, tensor := range tensors {
  365. C.ggml_build_forward_expand(c.graph, tensor.(*Tensor).t)
  366. }
  367. return c
  368. }
  369. func (c *Context) Compute(tensors ...ml.Tensor) {
  370. C.ggml_backend_sched_reset(c.b.sched)
  371. C.ggml_backend_sched_alloc_graph(c.b.sched, c.graph)
  372. C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph)
  373. needSync := true
  374. sync := func() {
  375. if needSync {
  376. C.ggml_backend_sched_synchronize(c.b.sched)
  377. needSync = false
  378. }
  379. }
  380. for _, t := range tensors {
  381. if C.ggml_nbytes(t.(*Tensor).t) > 0 {
  382. t.(*Tensor).sync = sync
  383. }
  384. }
  385. }
  386. func (c *Context) MaxGraphNodes() int {
  387. return c.maxGraphNodes
  388. }
  389. func shapeToGGML(shape []int) *C.int64_t {
  390. sh := make([]C.int64_t, len(shape))
  391. for i, s := range shape {
  392. sh[i] = C.int64_t(s)
  393. }
  394. return &sh[0]
  395. }
  396. func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
  397. if len(shape) < 1 || len(shape) > 4 {
  398. panic("unsupported number of dimensions")
  399. }
  400. for _, dim := range shape {
  401. if dim < 1 {
  402. panic("invalid shape")
  403. }
  404. }
  405. var t *C.struct_ggml_tensor
  406. switch dtype {
  407. case ml.DTypeF32:
  408. t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
  409. case ml.DTypeF16:
  410. t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
  411. case ml.DTypeI32:
  412. t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
  413. default:
  414. panic("unsupported dtype")
  415. }
  416. b := C.ggml_backend_alloc_buffer(c.backend, C.ggml_nbytes(t))
  417. C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
  418. return &Tensor{b: c.b, t: t}
  419. }
  420. func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
  421. return c.newTensor(dtype, shape)
  422. }
  423. func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
  424. t := c.newTensor(dtype, shape)
  425. C.ggml_set_zero(t.(*Tensor).t)
  426. return t
  427. }
  428. func checkShape[S ~[]E, E any](s S, shape ...int) error {
  429. n := len(s)
  430. for _, v := range shape {
  431. n /= v
  432. }
  433. if n != 1 {
  434. return fmt.Errorf("invalid shape: %v", shape)
  435. }
  436. return nil
  437. }
  438. func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
  439. if err := checkShape(s, shape...); err != nil {
  440. return nil, err
  441. }
  442. t := c.newTensor(ml.DTypeF32, shape)
  443. C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
  444. return t, nil
  445. }
  446. func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
  447. if err := checkShape(s, shape...); err != nil {
  448. return nil, err
  449. }
  450. t := c.newTensor(ml.DTypeI32, shape)
  451. C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
  452. return t, nil
  453. }
  454. func (c Context) Close() {
  455. if c.ctx != nil {
  456. C.ggml_free(c.ctx)
  457. }
  458. }
  459. type Tensor struct {
  460. b *Backend
  461. t *C.struct_ggml_tensor
  462. sync func()
  463. }
  464. func (t *Tensor) LogValue() slog.Value {
  465. return slog.GroupValue(
  466. slog.String("name", C.GoString(C.ggml_get_name(t.t))),
  467. slog.String("type", C.GoString(C.ggml_type_name(t.t._type))),
  468. slog.Any("shape", t.Shape()),
  469. )
  470. }
  471. func (t *Tensor) Dim(n int) int {
  472. return int(t.t.ne[n])
  473. }
  474. func (t *Tensor) Stride(n int) int {
  475. return int(t.t.nb[n])
  476. }
  477. func (t *Tensor) Shape() []int {
  478. shape := make([]int, C.ggml_n_dims(t.t))
  479. for i := range shape {
  480. shape[i] = t.Dim(i)
  481. }
  482. return shape
  483. }
  484. func (t *Tensor) Bytes() (data []byte) {
  485. if t.sync != nil {
  486. data = make([]byte, C.ggml_nbytes(t.t))
  487. t.sync()
  488. C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
  489. }
  490. return
  491. }
  492. func (t *Tensor) Floats() (data []float32) {
  493. if t.sync != nil {
  494. data = make([]float32, C.ggml_nelements(t.t))
  495. t.sync()
  496. C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
  497. }
  498. return
  499. }
  500. func (t *Tensor) DType() ml.DType {
  501. switch t.t._type {
  502. case C.GGML_TYPE_F32:
  503. return ml.DTypeF32
  504. case C.GGML_TYPE_F16:
  505. return ml.DTypeF16
  506. case C.GGML_TYPE_I32:
  507. return ml.DTypeI32
  508. default:
  509. return ml.DTypeOther
  510. }
  511. }
  512. func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  513. return &Tensor{
  514. b: t.b,
  515. t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  516. }
  517. }
  518. func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
  519. if len(s) > 0 {
  520. return t.Concat(ctx, s[0].Stack(ctx, dim, s[1:]...), dim)
  521. }
  522. return t
  523. }
  524. func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
  525. return &Tensor{
  526. b: t.b,
  527. t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
  528. }
  529. }
  530. func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
  531. return &Tensor{
  532. b: t.b,
  533. t: C.ggml_cont(ctx.(*Context).ctx, t.t),
  534. }
  535. }
  536. func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  537. return &Tensor{
  538. b: t.b,
  539. t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  540. }
  541. }
  542. func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  543. return &Tensor{
  544. b: t.b,
  545. t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  546. }
  547. }
  548. func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  549. mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
  550. C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
  551. return &Tensor{
  552. b: t.b,
  553. t: mul,
  554. }
  555. }
  556. func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
  557. tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
  558. if b != nil {
  559. tt = tt.Add(ctx, b)
  560. }
  561. return tt
  562. }
  563. func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
  564. return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
  565. }
  566. func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
  567. if len(shape) != 4 {
  568. panic("expected 4 dimensions")
  569. }
  570. return &Tensor{
  571. b: t.b,
  572. t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
  573. }
  574. }
  575. func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
  576. if len(shape) != 4 {
  577. panic("expected 4 dimensions")
  578. }
  579. return &Tensor{
  580. b: t.b,
  581. t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
  582. }
  583. }
  584. func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  585. return &Tensor{
  586. b: t.b,
  587. t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  588. }
  589. }
  590. func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  591. return &Tensor{
  592. b: t.b,
  593. t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  594. }
  595. }
  596. func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
  597. switch len(shape) {
  598. case 1:
  599. return &Tensor{
  600. b: t.b,
  601. t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
  602. }
  603. case 2:
  604. return &Tensor{
  605. b: t.b,
  606. t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
  607. }
  608. case 3:
  609. return &Tensor{
  610. b: t.b,
  611. t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
  612. }
  613. case 4:
  614. return &Tensor{
  615. b: t.b,
  616. t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
  617. }
  618. default:
  619. panic("unsupported number of dimensions")
  620. }
  621. }
  622. func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
  623. return &Tensor{
  624. b: t.b,
  625. t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
  626. }
  627. }
  628. func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
  629. return &Tensor{
  630. b: t.b,
  631. t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
  632. }
  633. }
  634. func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
  635. return &Tensor{
  636. b: t.b,
  637. t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
  638. }
  639. }
  640. func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
  641. if len(shape) != 4 {
  642. panic("expected 4 dimensions")
  643. }
  644. return &Tensor{
  645. b: t.b,
  646. t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
  647. }
  648. }
  649. func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
  650. switch len(shape) {
  651. case 1:
  652. return &Tensor{
  653. b: t.b,
  654. t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
  655. }
  656. case 3:
  657. return &Tensor{
  658. b: t.b,
  659. t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
  660. C.int64_t(shape[0]), C.int64_t(shape[2]),
  661. C.size_t(shape[1]),
  662. C.size_t(offset)),
  663. }
  664. case 5:
  665. return &Tensor{
  666. b: t.b,
  667. t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
  668. C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
  669. C.size_t(shape[1]), C.size_t(shape[3]),
  670. C.size_t(offset)),
  671. }
  672. case 7:
  673. return &Tensor{
  674. b: t.b,
  675. t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
  676. C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
  677. C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
  678. C.size_t(offset)),
  679. }
  680. default:
  681. panic("unsupported number of dimensions")
  682. }
  683. }
  684. const (
  685. ropeTypeNorm C.int = iota
  686. )
  687. func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
  688. if ropeFactors == nil {
  689. ropeFactors = &Tensor{b: t.b}
  690. }
  691. dequant := t.t
  692. if C.ggml_is_quantized(t.t._type) {
  693. dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
  694. }
  695. return &Tensor{
  696. b: t.b,
  697. t: C.ggml_rope_ext(
  698. ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
  699. C.int(ropeDim),
  700. 131072, // YaRN n_ctx_train
  701. ropeTypeNorm, // ROPE_TYPE_NORM
  702. C.float(ropeBase),
  703. C.float(ropeScale),
  704. 0., // YaRN ext_factor
  705. 1., // YaRN attn_factor
  706. 32., // YaRN beta_fast
  707. 1., // YaRN beta_slow
  708. ),
  709. }
  710. }
  711. func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
  712. return &Tensor{
  713. b: t.b,
  714. t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
  715. }
  716. }
  717. func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
  718. return &Tensor{
  719. b: t.b,
  720. t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
  721. }
  722. }
  723. func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
  724. return &Tensor{
  725. b: t.b,
  726. t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
  727. }
  728. }
  729. func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
  730. var kqMask *C.struct_ggml_tensor
  731. if mask != nil {
  732. kqMask = mask.(*Tensor).t
  733. }
  734. query := t.Permute(ctx, 0, 2, 1, 3)
  735. key = key.Permute(ctx, 0, 2, 1, 3)
  736. if t.b.flashAttention {
  737. value = value.Permute(ctx, 0, 2, 1, 3)
  738. kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
  739. C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
  740. return &Tensor{b: t.b, t: kqv}
  741. } else {
  742. kq := key.MulmatFullPrec(ctx, query)
  743. kq = &Tensor{
  744. b: t.b,
  745. t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
  746. }
  747. kqv := value.Mulmat(ctx, kq)
  748. return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  749. }
  750. }