ggml.go 21 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903
  1. package ggml
  2. // #cgo CPPFLAGS: -I${SRCDIR}/ggml/include
  3. // #include <stdlib.h>
  4. // #include <stdint.h>
  5. // #include "ggml.h"
  6. // #include "ggml-cpu.h"
  7. // #include "ggml-backend.h"
  8. import "C"
import (
	"errors"
	"fmt"
	"io"
	"iter"
	"log/slog"
	"maps"
	"os"
	"runtime"
	"slices"
	"strconv"
	"strings"
	"unicode"
	"unsafe"

	"github.com/ollama/ollama/format"
	fs "github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/ml"
	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
	"golang.org/x/sync/errgroup"
)
  28. func devices() iter.Seq[*C.struct_ggml_backend_device] {
  29. return func(yield func(*C.struct_ggml_backend_device) bool) {
  30. ggml.OnceLoad()
  31. for i := range C.ggml_backend_dev_count() {
  32. if !yield(C.ggml_backend_dev_get(i)) {
  33. return
  34. }
  35. }
  36. }
  37. }
  38. type Backend struct {
  39. meta *fs.GGML
  40. sched *C.struct_ggml_backend_sched
  41. tensors map[string]*C.struct_ggml_tensor
  42. input *C.struct_ggml_backend
  43. output *C.struct_ggml_backend
  44. layers map[int]*C.struct_ggml_backend
  45. flashAttention bool
  46. }
  47. func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
  48. meta, n, err := fs.Decode(r, -1)
  49. if err != nil {
  50. return nil, err
  51. }
  52. slog.Info(
  53. "",
  54. "architecture", meta.KV().Architecture(),
  55. "file_type", meta.KV().FileType(),
  56. "name", meta.KV().String("general.name"),
  57. "description", meta.KV().String("general.description"),
  58. "num_tensors", len(meta.Tensors().Items()),
  59. "num_key_values", len(meta.KV()),
  60. )
  61. type deviceBufferType struct {
  62. d *C.struct_ggml_backend_device
  63. bts []*C.struct_ggml_backend_buffer_type
  64. }
  65. var cpus, accels, gpus []*C.struct_ggml_backend_device
  66. for d := range devices() {
  67. switch C.ggml_backend_dev_type(d) {
  68. case C.GGML_BACKEND_DEVICE_TYPE_CPU:
  69. cpus = append(cpus, d)
  70. case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
  71. accels = append(accels, d)
  72. case C.GGML_BACKEND_DEVICE_TYPE_GPU:
  73. gpus = append(gpus, d)
  74. }
  75. }
  76. var cpuBufferTypes []*C.struct_ggml_backend_buffer_type
  77. for _, d := range append(accels, append(gpus, cpus...)...) {
  78. switch C.ggml_backend_dev_type(d) {
  79. case C.GGML_BACKEND_DEVICE_TYPE_CPU,
  80. C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
  81. cpuBufferTypes = append(cpuBufferTypes, C.ggml_backend_dev_buffer_type(d))
  82. }
  83. }
  84. var sum uint64
  85. var cumsum []uint64
  86. var gpuDeviceBufferTypes []deviceBufferType
  87. for _, d := range gpus {
  88. var free, total C.size_t
  89. C.ggml_backend_dev_memory(d, &free, &total)
  90. sum += uint64(free)
  91. cumsum = append(cumsum, sum)
  92. bt := C.ggml_backend_dev_buffer_type(d)
  93. gpuDeviceBufferTypes = append(gpuDeviceBufferTypes, deviceBufferType{
  94. d: d,
  95. bts: append([]*C.struct_ggml_backend_buffer_type{bt}, cpuBufferTypes...),
  96. })
  97. }
  98. splits := make([]float64, len(cumsum))
  99. for i := range splits {
  100. splits[i] = float64(cumsum[i]) / float64(sum)
  101. }
  102. cpuDeviceBufferTypes := deviceBufferType{C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU), cpuBufferTypes}
  103. input := cpuDeviceBufferTypes
  104. var blocks int
  105. for key, value := range meta.KV() {
  106. if strings.HasSuffix(key, ".block_count") {
  107. blocks += int(value.(uint32))
  108. }
  109. }
  110. assignLayer := func(i int) (temp deviceBufferType) {
  111. if i >= params.NumGPULayers {
  112. return cpuDeviceBufferTypes
  113. }
  114. return gpuDeviceBufferTypes[slices.IndexFunc(splits, func(f float64) bool {
  115. return float64(i)/float64(blocks+1) < f
  116. })]
  117. }
  118. layers := make([]deviceBufferType, blocks)
  119. for i := range layers {
  120. layers[i] = assignLayer(i)
  121. }
  122. output := assignLayer(blocks)
  123. maxTensors := len(meta.Tensors().Items())
  124. maxTensors += 1
  125. maxTensors += blocks * 2
  126. type tensor struct {
  127. source *fs.Tensor
  128. target string
  129. }
  130. targets := make(map[string][]string)
  131. ctxs := make(map[*C.struct_ggml_backend_buffer_type]*C.struct_ggml_context)
  132. createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type) *C.struct_ggml_tensor {
  133. for _, bt := range bts {
  134. if _, ok := ctxs[bt]; !ok {
  135. ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{
  136. mem_size: C.ggml_tensor_overhead() * C.size_t(maxTensors),
  137. no_alloc: true,
  138. })
  139. }
  140. targets[t.source.Name] = append(targets[t.source.Name], t.target)
  141. name := t.source.Name
  142. if t.target != "" {
  143. name = t.target
  144. }
  145. cname := C.CString(name)
  146. defer C.free(unsafe.Pointer(cname))
  147. if tt := C.ggml_get_tensor(ctxs[bt], cname); tt != nil {
  148. return tt
  149. }
  150. tt := C.ggml_new_tensor(ctxs[bt], t.source.Kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0])))
  151. C.ggml_set_name(tt, cname)
  152. slog.Debug("created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
  153. //nolint:staticcheck // TODO: check if buffer type supports this tensor
  154. return tt
  155. }
  156. return nil
  157. }
  158. hasPart := func(s string, parts ...string) bool {
  159. split := strings.Split(s, ".")
  160. for _, part := range parts {
  161. if slices.Contains(split, part) {
  162. return true
  163. }
  164. }
  165. return false
  166. }
  167. for _, t := range meta.Tensors().Items() {
  168. switch {
  169. case hasPart(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
  170. createTensor(tensor{source: t}, input.bts)
  171. case hasPart(t.Name, "cls", "output", "output_norm"):
  172. createTensor(tensor{source: t}, output.bts)
  173. default:
  174. if i := func() int {
  175. if fields := strings.FieldsFunc(t.Name, func(r rune) bool { return !unicode.IsNumber(r) }); len(fields) > 0 {
  176. if i, err := strconv.Atoi(fields[0]); err == nil {
  177. return i
  178. }
  179. }
  180. return -1
  181. }(); i >= 0 {
  182. createTensor(tensor{source: t}, layers[i].bts)
  183. } else {
  184. for i, layer := range layers {
  185. createTensor(tensor{
  186. source: t,
  187. target: "blk." + strconv.Itoa(i) + "." + t.Name,
  188. }, layer.bts)
  189. }
  190. }
  191. }
  192. }
  193. bbs := make(map[*C.struct_ggml_context][]*C.struct_ggml_backend_buffer, len(ctxs))
  194. for bt, c := range ctxs {
  195. if C.ggml_get_first_tensor(c) == nil {
  196. continue
  197. }
  198. b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
  199. C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS)
  200. bbs[c] = append(bbs[c], b)
  201. }
  202. for bs := range maps.Values(bbs) {
  203. for _, b := range bs {
  204. slog.Info("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(b)), "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(b))))
  205. }
  206. }
  207. tensors := make(map[string]*C.struct_ggml_tensor)
  208. for _, c := range ctxs {
  209. for t := C.ggml_get_first_tensor(c); t != nil; t = C.ggml_get_next_tensor(c, t) {
  210. tensors[C.GoString(C.ggml_get_name(t))] = t
  211. }
  212. }
  213. sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
  214. var g errgroup.Group
  215. for _, t := range meta.Tensors().Items() {
  216. for _, target := range targets[t.Name] {
  217. g.Go(func() error {
  218. if target == "" {
  219. target = t.Name
  220. }
  221. tt, ok := tensors[target]
  222. if !ok {
  223. return fmt.Errorf("unassigned tensor: %s", t.Name)
  224. }
  225. bts := make([]byte, t.Size())
  226. n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
  227. if err != nil {
  228. return err
  229. }
  230. if n != len(bts) {
  231. return errors.New("short read")
  232. }
  233. cname := C.CString(t.Name)
  234. C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size()))
  235. C.free(unsafe.Pointer(cname))
  236. return nil
  237. })
  238. }
  239. }
  240. if g.Wait() != nil {
  241. return nil, err
  242. }
  243. deviceBackends := make(map[*C.struct_ggml_backend_device]*C.struct_ggml_backend)
  244. var backends []*C.struct_ggml_backend
  245. var bufts []*C.struct_ggml_backend_buffer_type
  246. for _, d := range append(gpus, append(accels, cpus...)...) {
  247. b := C.ggml_backend_dev_init(d, nil)
  248. backends = append(backends, b)
  249. deviceBackends[d] = b
  250. bt := C.ggml_backend_get_default_buffer_type(b)
  251. if d := C.ggml_backend_get_device(b); C.ggml_backend_dev_type(d) == C.GGML_BACKEND_DEVICE_TYPE_CPU && len(gpus) > 0 {
  252. if hbt := C.ggml_backend_dev_host_buffer_type(d); hbt != nil {
  253. bt = hbt
  254. }
  255. }
  256. bufts = append(bufts, bt)
  257. slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(b)), "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
  258. if C.ggml_backend_is_cpu(b) {
  259. C.ggml_backend_cpu_set_n_threads(b, C.int(params.NumThreads))
  260. }
  261. }
  262. return &Backend{
  263. flashAttention: params.FlashAttention,
  264. meta: meta,
  265. tensors: tensors,
  266. sched: C.ggml_backend_sched_new(
  267. (*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
  268. (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
  269. C.int(len(backends)),
  270. C.size_t(max(8192, len(meta.Tensors().Items())*5)),
  271. true,
  272. ),
  273. input: deviceBackends[input.d],
  274. output: deviceBackends[output.d],
  275. layers: func() map[int]*C.struct_ggml_backend {
  276. m := make(map[int]*C.struct_ggml_backend)
  277. for i, layer := range layers {
  278. m[i] = deviceBackends[layer.d]
  279. }
  280. return m
  281. }(),
  282. }, nil
  283. }
  284. func init() {
  285. ml.RegisterBackend("ggml", New)
  286. }
  287. func (b *Backend) Config() ml.Config {
  288. return b.meta.KV()
  289. }
  290. func (b *Backend) Get(name string) ml.Tensor {
  291. if t, ok := b.tensors[name]; ok {
  292. return &Tensor{b: b, t: t}
  293. }
  294. return nil
  295. }
  296. func (b *Backend) NewContext() ml.Context {
  297. return b.NewContextSize(max(8192, len(b.meta.Tensors().Items())*5))
  298. }
  299. func (b *Backend) NewContextSize(n int) ml.Context {
  300. return &Context{
  301. b: b,
  302. ctx: C.ggml_init(C.struct_ggml_init_params{
  303. mem_size: C.size_t(n)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(n), false),
  304. no_alloc: true,
  305. }),
  306. backend: C.ggml_backend_sched_get_backend(b.sched, 0),
  307. maxGraphNodes: n,
  308. input: b.input,
  309. output: b.output,
  310. layers: b.layers,
  311. }
  312. }
  313. func (b *Backend) CacheConfig() ml.CacheConfig {
  314. if b.flashAttention {
  315. return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
  316. } else {
  317. return ml.CacheConfig{CachePadding: 32, PermutedV: true}
  318. }
  319. }
  320. type Context struct {
  321. b *Backend
  322. ctx *C.struct_ggml_context
  323. graph *C.struct_ggml_cgraph
  324. // backend is the backend used for new tensors
  325. backend *C.struct_ggml_backend
  326. // input is the backend used for inputs
  327. input *C.struct_ggml_backend
  328. // output is the backend used for outputs
  329. output *C.struct_ggml_backend
  330. // output is the backend used for repeating layers
  331. layers map[int]*C.struct_ggml_backend
  332. maxGraphNodes int
  333. }
  334. func (c *Context) Input() ml.Context {
  335. if c.input != nil {
  336. return &Context{
  337. b: c.b,
  338. ctx: c.ctx,
  339. backend: c.input,
  340. maxGraphNodes: c.maxGraphNodes,
  341. }
  342. }
  343. return c
  344. }
  345. func (c *Context) Output() ml.Context {
  346. if c.output != nil {
  347. return &Context{
  348. b: c.b,
  349. ctx: c.ctx,
  350. backend: c.output,
  351. maxGraphNodes: c.maxGraphNodes,
  352. }
  353. }
  354. return c
  355. }
  356. func (c *Context) Layer(i int) ml.Context {
  357. if backend, ok := c.layers[i]; ok {
  358. return &Context{
  359. b: c.b,
  360. ctx: c.ctx,
  361. backend: backend,
  362. maxGraphNodes: c.maxGraphNodes,
  363. }
  364. }
  365. return c
  366. }
  367. func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
  368. if c.graph == nil {
  369. c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.maxGraphNodes), false)
  370. }
  371. for _, tensor := range tensors {
  372. C.ggml_build_forward_expand(c.graph, tensor.(*Tensor).t)
  373. }
  374. return c
  375. }
  376. func (c *Context) Compute(tensors ...ml.Tensor) {
  377. C.ggml_backend_sched_reset(c.b.sched)
  378. C.ggml_backend_sched_alloc_graph(c.b.sched, c.graph)
  379. C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph)
  380. needSync := true
  381. sync := func() {
  382. if needSync {
  383. C.ggml_backend_sched_synchronize(c.b.sched)
  384. needSync = false
  385. }
  386. }
  387. for _, t := range tensors {
  388. if C.ggml_nbytes(t.(*Tensor).t) > 0 {
  389. t.(*Tensor).sync = sync
  390. }
  391. }
  392. }
  393. func (c *Context) MaxGraphNodes() int {
  394. return c.maxGraphNodes
  395. }
  396. func shapeToGGML(shape []int) *C.int64_t {
  397. sh := make([]C.int64_t, len(shape))
  398. for i, s := range shape {
  399. sh[i] = C.int64_t(s)
  400. }
  401. return &sh[0]
  402. }
  403. func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
  404. if len(shape) < 1 || len(shape) > 4 {
  405. panic("unsupported number of dimensions")
  406. }
  407. for _, dim := range shape {
  408. if dim < 1 {
  409. panic("invalid shape")
  410. }
  411. }
  412. var t *C.struct_ggml_tensor
  413. switch dtype {
  414. case ml.DTypeF32:
  415. t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
  416. case ml.DTypeF16:
  417. t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
  418. case ml.DTypeI32:
  419. t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
  420. default:
  421. panic("unsupported dtype")
  422. }
  423. b := C.ggml_backend_alloc_buffer(c.backend, C.ggml_nbytes(t))
  424. C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
  425. return &Tensor{b: c.b, t: t}
  426. }
  427. func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
  428. return c.newTensor(dtype, shape)
  429. }
  430. func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
  431. t := c.newTensor(dtype, shape)
  432. C.ggml_set_zero(t.(*Tensor).t)
  433. return t
  434. }
  435. func checkShape[S ~[]E, E any](s S, shape ...int) error {
  436. n := len(s)
  437. for _, v := range shape {
  438. n /= v
  439. }
  440. if n != 1 {
  441. return fmt.Errorf("invalid shape: %v", shape)
  442. }
  443. return nil
  444. }
  445. func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
  446. if err := checkShape(s, shape...); err != nil {
  447. return nil, err
  448. }
  449. t := c.newTensor(ml.DTypeF32, shape)
  450. C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
  451. return t, nil
  452. }
  453. func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
  454. if err := checkShape(s, shape...); err != nil {
  455. return nil, err
  456. }
  457. t := c.newTensor(ml.DTypeI32, shape)
  458. C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
  459. return t, nil
  460. }
  461. func (c Context) Close() {
  462. if c.ctx != nil {
  463. C.ggml_free(c.ctx)
  464. }
  465. }
  466. type Tensor struct {
  467. b *Backend
  468. t *C.struct_ggml_tensor
  469. sync func()
  470. }
  471. func (t *Tensor) LogValue() slog.Value {
  472. return slog.GroupValue(
  473. slog.String("name", C.GoString(C.ggml_get_name(t.t))),
  474. slog.String("type", C.GoString(C.ggml_type_name(t.t._type))),
  475. slog.Any("shape", t.Shape()),
  476. )
  477. }
  478. func (t *Tensor) Dim(n int) int {
  479. return int(t.t.ne[n])
  480. }
  481. func (t *Tensor) Stride(n int) int {
  482. return int(t.t.nb[n])
  483. }
  484. func (t *Tensor) Shape() []int {
  485. shape := make([]int, C.ggml_n_dims(t.t))
  486. for i := range shape {
  487. shape[i] = t.Dim(i)
  488. }
  489. return shape
  490. }
  491. func (t *Tensor) Bytes() (data []byte) {
  492. if t.sync != nil {
  493. data = make([]byte, C.ggml_nbytes(t.t))
  494. t.sync()
  495. C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
  496. }
  497. return
  498. }
  499. func (t *Tensor) Floats() (data []float32) {
  500. if t.sync != nil {
  501. data = make([]float32, C.ggml_nelements(t.t))
  502. t.sync()
  503. C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
  504. }
  505. return
  506. }
  507. func (t *Tensor) DType() ml.DType {
  508. switch t.t._type {
  509. case C.GGML_TYPE_F32:
  510. return ml.DTypeF32
  511. case C.GGML_TYPE_F16:
  512. return ml.DTypeF16
  513. case C.GGML_TYPE_I32:
  514. return ml.DTypeI32
  515. default:
  516. return ml.DTypeOther
  517. }
  518. }
  519. func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  520. return &Tensor{
  521. b: t.b,
  522. t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  523. }
  524. }
  525. func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
  526. if len(s) > 0 {
  527. return t.Concat(ctx, s[0].Stack(ctx, dim, s[1:]...), dim)
  528. }
  529. return t
  530. }
  531. func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
  532. return &Tensor{
  533. b: t.b,
  534. t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
  535. }
  536. }
  537. func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
  538. return &Tensor{
  539. b: t.b,
  540. t: C.ggml_cont(ctx.(*Context).ctx, t.t),
  541. }
  542. }
  543. func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  544. return &Tensor{
  545. b: t.b,
  546. t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  547. }
  548. }
  549. func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  550. return &Tensor{
  551. b: t.b,
  552. t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  553. }
  554. }
  555. func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  556. mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
  557. C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
  558. return &Tensor{
  559. b: t.b,
  560. t: mul,
  561. }
  562. }
  563. func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
  564. tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
  565. if b != nil {
  566. tt = tt.Add(ctx, b)
  567. }
  568. return tt
  569. }
  570. func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
  571. return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
  572. }
  573. func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
  574. if len(shape) != 4 {
  575. panic("expected 4 dimensions")
  576. }
  577. return &Tensor{
  578. b: t.b,
  579. t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
  580. }
  581. }
  582. func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
  583. if len(shape) != 4 {
  584. panic("expected 4 dimensions")
  585. }
  586. return &Tensor{
  587. b: t.b,
  588. t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
  589. }
  590. }
  591. func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  592. return &Tensor{
  593. b: t.b,
  594. t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  595. }
  596. }
  597. func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  598. return &Tensor{
  599. b: t.b,
  600. t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  601. }
  602. }
  603. func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
  604. switch len(shape) {
  605. case 1:
  606. return &Tensor{
  607. b: t.b,
  608. t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
  609. }
  610. case 2:
  611. return &Tensor{
  612. b: t.b,
  613. t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
  614. }
  615. case 3:
  616. return &Tensor{
  617. b: t.b,
  618. t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
  619. }
  620. case 4:
  621. return &Tensor{
  622. b: t.b,
  623. t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
  624. }
  625. default:
  626. panic("unsupported number of dimensions")
  627. }
  628. }
  629. func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
  630. return &Tensor{
  631. b: t.b,
  632. t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
  633. }
  634. }
  635. func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
  636. return &Tensor{
  637. b: t.b,
  638. t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
  639. }
  640. }
  641. func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
  642. return &Tensor{
  643. b: t.b,
  644. t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
  645. }
  646. }
  647. func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
  648. if len(shape) != 4 {
  649. panic("expected 4 dimensions")
  650. }
  651. return &Tensor{
  652. b: t.b,
  653. t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
  654. }
  655. }
  656. func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
  657. switch len(shape) {
  658. case 1:
  659. return &Tensor{
  660. b: t.b,
  661. t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
  662. }
  663. case 3:
  664. return &Tensor{
  665. b: t.b,
  666. t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
  667. C.int64_t(shape[0]), C.int64_t(shape[2]),
  668. C.size_t(shape[1]),
  669. C.size_t(offset)),
  670. }
  671. case 5:
  672. return &Tensor{
  673. b: t.b,
  674. t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
  675. C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
  676. C.size_t(shape[1]), C.size_t(shape[3]),
  677. C.size_t(offset)),
  678. }
  679. case 7:
  680. return &Tensor{
  681. b: t.b,
  682. t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
  683. C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
  684. C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
  685. C.size_t(offset)),
  686. }
  687. default:
  688. panic("unsupported number of dimensions")
  689. }
  690. }
  691. const (
  692. ropeTypeNorm C.int = iota
  693. )
  694. func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
  695. if ropeFactors == nil {
  696. ropeFactors = &Tensor{b: t.b}
  697. }
  698. dequant := t.t
  699. if C.ggml_is_quantized(t.t._type) {
  700. dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
  701. }
  702. return &Tensor{
  703. b: t.b,
  704. t: C.ggml_rope_ext(
  705. ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
  706. C.int(ropeDim),
  707. 131072, // YaRN n_ctx_train
  708. ropeTypeNorm, // ROPE_TYPE_NORM
  709. C.float(ropeBase),
  710. C.float(ropeScale),
  711. 0., // YaRN ext_factor
  712. 1., // YaRN attn_factor
  713. 32., // YaRN beta_fast
  714. 1., // YaRN beta_slow
  715. ),
  716. }
  717. }
  718. func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
  719. return &Tensor{
  720. b: t.b,
  721. t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
  722. }
  723. }
  724. func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
  725. return &Tensor{
  726. b: t.b,
  727. t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
  728. }
  729. }
  730. func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
  731. return &Tensor{
  732. b: t.b,
  733. t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
  734. }
  735. }
  736. func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
  737. var kqMask *C.struct_ggml_tensor
  738. if mask != nil {
  739. kqMask = mask.(*Tensor).t
  740. }
  741. query := t.Permute(ctx, 0, 2, 1, 3)
  742. key = key.Permute(ctx, 0, 2, 1, 3)
  743. if t.b.flashAttention {
  744. value = value.Permute(ctx, 0, 2, 1, 3)
  745. kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
  746. C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
  747. return &Tensor{b: t.b, t: kqv}
  748. } else {
  749. kq := key.MulmatFullPrec(ctx, query)
  750. kq = &Tensor{
  751. b: t.b,
  752. t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
  753. }
  754. kqv := value.Mulmat(ctx, kq)
  755. return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  756. }
  757. }