ggml.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762
  1. package ggml
  2. /*
  3. #cgo CPPFLAGS: -I${SRCDIR}/ggml/include
  4. #include <stdlib.h>
  5. #include <stdint.h>
  6. #include "ggml.h"
  7. #include "ggml-cpu.h"
  8. #include "ggml-backend.h"
  9. static struct ggml_backend_feature * getBackendFeatures(void *fp, ggml_backend_reg_t reg) {return ((ggml_backend_get_features_t)(fp))(reg);}
  10. static struct ggml_backend_feature * getNextBackendFeatures(struct ggml_backend_feature * feature) { return &feature[1];}
  11. typedef enum {COMP_UNKNOWN,COMP_GCC,COMP_CLANG} COMPILER;
  12. COMPILER inline get_compiler() {
  13. #if defined(__clang__)
  14. return COMP_CLANG;
  15. #elif defined(__GNUC__)
  16. return COMP_GCC;
  17. #else
  18. return UNKNOWN_COMPILER;
  19. #endif
  20. }
  21. */
  22. import "C"
  23. import (
  24. "fmt"
  25. "io"
  26. "log/slog"
  27. "os"
  28. "sync"
  29. "unsafe"
  30. "github.com/ollama/ollama/format"
  31. fs "github.com/ollama/ollama/fs/ggml"
  32. "github.com/ollama/ollama/ml"
  33. "golang.org/x/sync/errgroup"
  34. ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
  35. )
  36. type device struct {
  37. d *C.struct_ggml_backend_device
  38. }
  39. func (d device) LogValue() slog.Value {
  40. var free, total uint64
  41. C.ggml_backend_dev_memory(d.d, (*C.size_t)(&free), (*C.size_t)(&total))
  42. kind := "unknown"
  43. switch C.ggml_backend_dev_type(d.d) {
  44. case C.GGML_BACKEND_DEVICE_TYPE_CPU:
  45. kind = "cpu"
  46. case C.GGML_BACKEND_DEVICE_TYPE_GPU:
  47. kind = "gpu"
  48. case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
  49. kind = "accel"
  50. }
  51. return slog.GroupValue(
  52. slog.String("name", C.GoString(C.ggml_backend_dev_name(d.d))),
  53. slog.String("description", C.GoString(C.ggml_backend_dev_description(d.d))),
  54. slog.String("kind", kind),
  55. slog.String("free", format.HumanBytes2(free)),
  56. slog.String("total", format.HumanBytes2(total)),
  57. )
  58. }
  59. var devices = sync.OnceValue(func() []device {
  60. ggml.OnceLoad()
  61. s := make([]device, C.ggml_backend_dev_count())
  62. for i := range s {
  63. s[i] = device{C.ggml_backend_dev_get(C.size_t(i))}
  64. }
  65. return s
  66. })
  67. type Backend struct {
  68. flashAttention bool
  69. meta *fs.GGML
  70. cpus, gpus []Context
  71. tensors map[string]*Context
  72. sched *C.struct_ggml_backend_sched
  73. }
  74. func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
  75. meta, n, err := fs.Decode(r, -1)
  76. if err != nil {
  77. return nil, err
  78. }
  79. slog.Info(
  80. "",
  81. "architecture", meta.KV().Architecture(),
  82. "file_type", meta.KV().FileType(),
  83. "name", meta.KV().String("general.name"),
  84. "description", meta.KV().String("general.description"),
  85. "num_tensors", len(meta.Tensors().Items()),
  86. "num_key_values", len(meta.KV()),
  87. )
  88. var cpus, gpus []Context
  89. for _, d := range devices() {
  90. switch C.ggml_backend_dev_type(d.d) {
  91. case C.GGML_BACKEND_DEVICE_TYPE_CPU,
  92. C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
  93. slog.Info("cpu", "device", d)
  94. cpus = append(cpus, Context{
  95. ctx: C.ggml_init(C.struct_ggml_init_params{
  96. mem_size: C.size_t(int(C.ggml_tensor_overhead()) * (len(meta.Tensors().Items()) + 1 + int(meta.KV().BlockCount())*2)),
  97. no_alloc: true,
  98. }),
  99. backend: C.ggml_backend_dev_init(d.d, nil),
  100. })
  101. case C.GGML_BACKEND_DEVICE_TYPE_GPU:
  102. slog.Info("gpu", "device", d)
  103. gpus = append(gpus, Context{
  104. ctx: C.ggml_init(C.struct_ggml_init_params{
  105. mem_size: C.size_t(int(C.ggml_tensor_overhead()) * (len(meta.Tensors().Items()) + 1 + int(meta.KV().BlockCount())*2)),
  106. no_alloc: true,
  107. }),
  108. backend: C.ggml_backend_dev_init(d.d, nil),
  109. })
  110. }
  111. }
  112. ctxFunc := func(s []Context) (*Context, error) {
  113. for _, e := range s {
  114. return &e, nil
  115. }
  116. return nil, fmt.Errorf("no devices available")
  117. }
  118. tensors := make(map[*fs.Tensor]*Context, len(meta.Tensors().Items()))
  119. for _, t := range meta.Tensors().Items() {
  120. c, err := ctxFunc(append(gpus, cpus...))
  121. if err != nil {
  122. return nil, err
  123. }
  124. func() {
  125. tt := C.ggml_new_tensor(c.ctx, t.Kind, C.int(len(t.Shape)), (*C.int64_t)(unsafe.Pointer(&t.Shape[0])))
  126. cname := C.CString(t.Name)
  127. defer C.free(unsafe.Pointer(cname))
  128. C.ggml_set_name(tt, cname)
  129. tensors[t] = c
  130. }()
  131. }
  132. for _, b := range append(gpus, cpus...) {
  133. C.ggml_backend_alloc_ctx_tensors(b.ctx, b.backend)
  134. }
  135. sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
  136. var g errgroup.Group
  137. for t, c := range tensors {
  138. g.Go(func() error {
  139. bts := make([]byte, t.Size())
  140. n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
  141. if err != nil {
  142. return err
  143. }
  144. if n != int(t.Size()) {
  145. return fmt.Errorf("expected %d bytes, got %d", t.Size(), n)
  146. }
  147. cname := C.CString(t.Name)
  148. defer C.free(unsafe.Pointer(cname))
  149. C.ggml_backend_tensor_set(C.ggml_get_tensor(c.ctx, cname), unsafe.Pointer(&bts[0]), 0, C.size_t(n))
  150. return nil
  151. })
  152. }
  153. if err := g.Wait(); err != nil {
  154. return nil, err
  155. }
  156. backends := make([]*C.struct_ggml_backend, len(gpus)+len(cpus))
  157. bufts := make([]*C.struct_ggml_backend_buffer_type, len(gpus)+len(cpus))
  158. for i, c := range append(gpus, cpus...) {
  159. backends[i] = c.backend
  160. bufts[i] = C.ggml_backend_get_default_buffer_type(c.backend)
  161. }
  162. return &Backend{
  163. flashAttention: params.FlashAttention,
  164. meta: meta,
  165. cpus: cpus,
  166. gpus: gpus,
  167. sched: C.ggml_backend_sched_new(
  168. (*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
  169. (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
  170. C.int(len(backends)),
  171. C.size_t(max(8192, len(meta.Tensors().Items())*5)),
  172. true,
  173. ),
  174. }, nil
  175. }
  176. func init() {
  177. ml.RegisterBackend("ggml", New)
  178. }
  179. func (b *Backend) Config() ml.Config {
  180. return b.meta.KV()
  181. }
  182. func (b *Backend) Get(name string) ml.Tensor {
  183. cname := C.CString(name)
  184. defer C.free(unsafe.Pointer(cname))
  185. for _, c := range append(b.gpus, b.cpus...) {
  186. if t := C.ggml_get_tensor(c.ctx, cname); t != nil {
  187. return &Tensor{b: b, t: t}
  188. }
  189. }
  190. return nil
  191. }
  192. func (b *Backend) NewContext() ml.Context {
  193. nodes := max(8192, len(b.meta.Tensors().Items())*5)
  194. c := C.ggml_init(C.struct_ggml_init_params{
  195. mem_buffer: nil,
  196. mem_size: C.size_t(nodes)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(nodes), false),
  197. no_alloc: true,
  198. })
  199. backends := make([]*C.struct_ggml_backend, len(b.gpus)+len(b.cpus))
  200. for i, c := range append(b.gpus, b.cpus...) {
  201. backends[i] = c.backend
  202. }
  203. return &Context{
  204. b: b,
  205. ctx: c,
  206. backend: backends[0],
  207. nodes: nodes,
  208. }
  209. }
  210. func (b *Backend) CacheConfig() ml.CacheConfig {
  211. if b.flashAttention {
  212. return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
  213. } else {
  214. return ml.CacheConfig{CachePadding: 32, PermutedV: true}
  215. }
  216. }
  217. type Context struct {
  218. b *Backend
  219. ctx *C.struct_ggml_context
  220. backend *C.struct_ggml_backend
  221. graph *C.struct_ggml_cgraph
  222. nodes int
  223. }
  224. func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
  225. if c.graph == nil {
  226. c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.nodes), false)
  227. }
  228. for _, tensor := range tensors {
  229. C.ggml_build_forward_expand(c.graph, tensor.(*Tensor).t)
  230. }
  231. return c
  232. }
  233. func (c *Context) Compute(tensors ...ml.Tensor) {
  234. C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph)
  235. C.ggml_backend_sched_reset(c.b.sched)
  236. needSync := true
  237. sync := func() {
  238. if needSync {
  239. C.ggml_backend_sched_synchronize(c.b.sched)
  240. needSync = false
  241. }
  242. }
  243. for _, t := range tensors {
  244. if C.ggml_nbytes(t.(*Tensor).t) > 0 {
  245. t.(*Tensor).sync = sync
  246. }
  247. }
  248. }
  249. func (c *Context) MaxTensors() int {
  250. return c.nodes
  251. }
  252. func shapeToGGML(shape []int) *C.int64_t {
  253. sh := make([]C.int64_t, len(shape))
  254. for i, s := range shape {
  255. sh[i] = (C.int64_t)(s)
  256. }
  257. return &sh[0]
  258. }
  259. func newTensor(ctx Context, dtype ml.DType, zero bool, shape []int) ml.Tensor {
  260. if len(shape) < 1 || len(shape) > 4 {
  261. panic("unsupported number of dimensions")
  262. }
  263. for _, dim := range shape {
  264. if dim < 1 {
  265. panic("invalid shape")
  266. }
  267. }
  268. var t *C.struct_ggml_tensor
  269. switch dtype {
  270. case ml.DTypeF32:
  271. t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
  272. case ml.DTypeF16:
  273. t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
  274. case ml.DTypeI32:
  275. t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
  276. default:
  277. panic("unsupported dtype")
  278. }
  279. b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
  280. C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
  281. if zero {
  282. C.ggml_set_zero(t)
  283. }
  284. return &Tensor{b: ctx.b, t: t}
  285. }
  286. func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
  287. return newTensor(c, dtype, false, shape)
  288. }
  289. func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
  290. return newTensor(c, dtype, true, shape)
  291. }
  292. func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) {
  293. n := len(s)
  294. if n == 0 {
  295. var shape C.int64_t = 0
  296. t := C.ggml_new_tensor(ctx.ctx, dtype, 1, &shape)
  297. return &Tensor{b: ctx.b, t: t}, nil
  298. }
  299. for _, v := range shape {
  300. n /= v
  301. }
  302. if n != 1 {
  303. return nil, fmt.Errorf("invalid shape %v for %d elements", shape, len(s))
  304. }
  305. t := C.ggml_new_tensor(ctx.ctx, dtype, C.int(len(shape)), shapeToGGML(shape))
  306. b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
  307. C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
  308. C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t))
  309. return &Tensor{b: ctx.b, t: t}, nil
  310. }
  311. func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
  312. return fromSlice(c, s, shape, C.GGML_TYPE_F32)
  313. }
  314. func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
  315. return fromSlice(c, s, shape, C.GGML_TYPE_I32)
  316. }
  317. func (c *Context) Close() {
  318. if c != nil {
  319. C.ggml_free(c.ctx)
  320. }
  321. }
  322. type Tensor struct {
  323. b *Backend
  324. t *C.struct_ggml_tensor
  325. sync func()
  326. }
  327. func (t *Tensor) LogValue() slog.Value {
  328. return slog.GroupValue(
  329. slog.String("name", C.GoString(C.ggml_get_name(t.t))),
  330. slog.String("type", C.GoString(C.ggml_type_name(t.t._type))),
  331. slog.Any("shape", t.Shape()),
  332. )
  333. }
  334. func (t *Tensor) Dim(n int) int {
  335. return int(t.t.ne[n])
  336. }
  337. func (t *Tensor) Stride(n int) int {
  338. return int(t.t.nb[n])
  339. }
  340. func (t *Tensor) Shape() []int {
  341. shape := make([]int, C.ggml_n_dims(t.t))
  342. for i := range shape {
  343. shape[i] = t.Dim(i)
  344. }
  345. return shape
  346. }
  347. func (t *Tensor) Bytes() (data []byte) {
  348. if t.sync != nil {
  349. data = make([]byte, C.ggml_nbytes(t.t))
  350. t.sync()
  351. C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
  352. }
  353. return
  354. }
  355. func (t *Tensor) Floats() (data []float32) {
  356. if t.sync != nil {
  357. data = make([]float32, C.ggml_nelements(t.t))
  358. t.sync()
  359. C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
  360. }
  361. return
  362. }
  363. func (t *Tensor) DType() ml.DType {
  364. switch t.t._type {
  365. case C.GGML_TYPE_F32:
  366. return ml.DTypeF32
  367. case C.GGML_TYPE_F16:
  368. return ml.DTypeF16
  369. case C.GGML_TYPE_I32:
  370. return ml.DTypeI32
  371. default:
  372. return ml.DTypeOther
  373. }
  374. }
  375. func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  376. return &Tensor{
  377. b: t.b,
  378. t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  379. }
  380. }
  381. func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
  382. if len(s) > 0 {
  383. return t.Concat(ctx, s[0].Stack(ctx, dim, s[1:]...), dim)
  384. }
  385. return t
  386. }
  387. func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
  388. return &Tensor{
  389. b: t.b,
  390. t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
  391. }
  392. }
  393. func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
  394. return &Tensor{
  395. b: t.b,
  396. t: C.ggml_cont(ctx.(*Context).ctx, t.t),
  397. }
  398. }
  399. func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  400. return &Tensor{
  401. b: t.b,
  402. t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  403. }
  404. }
  405. func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  406. return &Tensor{
  407. b: t.b,
  408. t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  409. }
  410. }
  411. func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  412. mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
  413. C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
  414. return &Tensor{
  415. b: t.b,
  416. t: mul,
  417. }
  418. }
  419. func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
  420. tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
  421. if b != nil {
  422. tt = tt.Add(ctx, b)
  423. }
  424. return tt
  425. }
  426. func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
  427. return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
  428. }
  429. func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
  430. if len(shape) != 4 {
  431. panic("expected 4 dimensions")
  432. }
  433. return &Tensor{
  434. b: t.b,
  435. t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
  436. }
  437. }
  438. func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
  439. if len(shape) != 4 {
  440. panic("expected 4 dimensions")
  441. }
  442. return &Tensor{
  443. b: t.b,
  444. t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
  445. }
  446. }
  447. func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  448. return &Tensor{
  449. b: t.b,
  450. t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  451. }
  452. }
  453. func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  454. return &Tensor{
  455. b: t.b,
  456. t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  457. }
  458. }
  459. func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
  460. switch len(shape) {
  461. case 1:
  462. return &Tensor{
  463. b: t.b,
  464. t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
  465. }
  466. case 2:
  467. return &Tensor{
  468. b: t.b,
  469. t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
  470. }
  471. case 3:
  472. return &Tensor{
  473. b: t.b,
  474. t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
  475. }
  476. case 4:
  477. return &Tensor{
  478. b: t.b,
  479. t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
  480. }
  481. default:
  482. panic("unsupported number of dimensions")
  483. }
  484. }
  485. func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
  486. return &Tensor{
  487. b: t.b,
  488. t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
  489. }
  490. }
  491. func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
  492. return &Tensor{
  493. b: t.b,
  494. t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
  495. }
  496. }
  497. func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
  498. return &Tensor{
  499. b: t.b,
  500. t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
  501. }
  502. }
  503. func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
  504. if len(shape) != 4 {
  505. panic("expected 4 dimensions")
  506. }
  507. return &Tensor{
  508. b: t.b,
  509. t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
  510. }
  511. }
  512. func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
  513. switch len(shape) {
  514. case 1:
  515. return &Tensor{
  516. b: t.b,
  517. t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
  518. }
  519. case 3:
  520. return &Tensor{
  521. b: t.b,
  522. t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
  523. C.int64_t(shape[0]), C.int64_t(shape[2]),
  524. C.size_t(shape[1]),
  525. C.size_t(offset)),
  526. }
  527. case 5:
  528. return &Tensor{
  529. b: t.b,
  530. t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
  531. C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
  532. C.size_t(shape[1]), C.size_t(shape[3]),
  533. C.size_t(offset)),
  534. }
  535. case 7:
  536. return &Tensor{
  537. b: t.b,
  538. t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
  539. C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
  540. C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
  541. C.size_t(offset)),
  542. }
  543. default:
  544. panic("unsupported number of dimensions")
  545. }
  546. }
  547. const (
  548. ropeTypeNorm C.int = iota
  549. )
  550. func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
  551. if ropeFactors == nil {
  552. ropeFactors = &Tensor{b: t.b}
  553. }
  554. dequant := t.t
  555. if C.ggml_is_quantized(t.t._type) {
  556. dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
  557. }
  558. return &Tensor{
  559. b: t.b,
  560. t: C.ggml_rope_ext(
  561. ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
  562. C.int(ropeDim),
  563. 131072, // YaRN n_ctx_train
  564. ropeTypeNorm, // ROPE_TYPE_NORM
  565. C.float(ropeBase),
  566. C.float(ropeScale),
  567. 0., // YaRN ext_factor
  568. 1., // YaRN attn_factor
  569. 32., // YaRN beta_fast
  570. 1., // YaRN beta_slow
  571. ),
  572. }
  573. }
  574. func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
  575. return &Tensor{
  576. b: t.b,
  577. t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
  578. }
  579. }
  580. func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
  581. return &Tensor{
  582. b: t.b,
  583. t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
  584. }
  585. }
  586. func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
  587. return &Tensor{
  588. b: t.b,
  589. t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
  590. }
  591. }
  592. func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
  593. var kqMask *C.struct_ggml_tensor
  594. if mask != nil {
  595. kqMask = mask.(*Tensor).t
  596. }
  597. query := t.Permute(ctx, 0, 2, 1, 3)
  598. key = key.Permute(ctx, 0, 2, 1, 3)
  599. if t.b.flashAttention {
  600. value = value.Permute(ctx, 0, 2, 1, 3)
  601. kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
  602. C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
  603. return &Tensor{b: t.b, t: kqv}
  604. } else {
  605. kq := key.MulmatFullPrec(ctx, query)
  606. kq = &Tensor{
  607. b: t.b,
  608. t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
  609. }
  610. kqv := value.Mulmat(ctx, kq)
  611. return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  612. }
  613. }
  614. func (b *Backend) SystemInfo() string {
  615. var compiler string
  616. switch C.get_compiler() {
  617. case C.COMP_UNKNOWN:
  618. compiler = "cgo(unknown_compiler)"
  619. case C.COMP_GCC:
  620. compiler = "cgo(gcc)"
  621. case C.COMP_CLANG:
  622. compiler = "cgo(clang)"
  623. }
  624. var s string
  625. for i := range C.ggml_backend_reg_count() {
  626. reg := C.ggml_backend_reg_get(i)
  627. fName := C.CString("ggml_backend_get_features")
  628. defer C.free(unsafe.Pointer(fName))
  629. get_features_fn := C.ggml_backend_reg_get_proc_address(reg, fName)
  630. if get_features_fn != nil {
  631. s += C.GoString(C.ggml_backend_reg_name(reg))
  632. s += " : "
  633. for features := C.getBackendFeatures(get_features_fn, reg); features.name != nil; features = C.getNextBackendFeatures(features) {
  634. s += C.GoString(features.name)
  635. s += " = "
  636. s += C.GoString(features.value)
  637. s += " | "
  638. }
  639. }
  640. }
  641. return s + compiler
  642. }