ggml.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
  1. package ggml
  2. /*
  3. #cgo CPPFLAGS: -I${SRCDIR}/ggml/include
  4. #include <stdlib.h>
  5. #include <stdint.h>
  6. #include "ggml.h"
  7. #include "ggml-cpu.h"
  8. #include "ggml-backend.h"
  9. static struct ggml_backend_feature * getBackendFeatures(void *fp, ggml_backend_reg_t reg) {return ((ggml_backend_get_features_t)(fp))(reg);}
  10. static struct ggml_backend_feature * getNextBackendFeatures(struct ggml_backend_feature * feature) { return &feature[1];}
  11. typedef enum {COMP_UNKNOWN,COMP_GCC,COMP_CLANG} COMPILER;
  12. COMPILER inline get_compiler() {
  13. #if defined(__clang__)
  14. return COMP_CLANG;
  15. #elif defined(__GNUC__)
  16. return COMP_GCC;
  17. #else
  18. return UNKNOWN_COMPILER;
  19. #endif
  20. }
  21. */
  22. import "C"
  import (
  	"fmt"
  	"io"
  	"log/slog"
  	"os"
  	"runtime"
  	"sync"
  	"unsafe"

  	"github.com/ollama/ollama/format"
  	fs "github.com/ollama/ollama/fs/ggml"
  	"github.com/ollama/ollama/ml"
  	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
  	"golang.org/x/sync/errgroup"
  )
  36. type device struct {
  37. d *C.struct_ggml_backend_device
  38. }
  39. func (d device) LogValue() slog.Value {
  40. var free, total uint64
  41. C.ggml_backend_dev_memory(d.d, (*C.size_t)(&free), (*C.size_t)(&total))
  42. kind := "unknown"
  43. switch C.ggml_backend_dev_type(d.d) {
  44. case C.GGML_BACKEND_DEVICE_TYPE_CPU:
  45. kind = "cpu"
  46. case C.GGML_BACKEND_DEVICE_TYPE_GPU:
  47. kind = "gpu"
  48. case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
  49. kind = "accel"
  50. }
  51. return slog.GroupValue(
  52. slog.String("name", C.GoString(C.ggml_backend_dev_name(d.d))),
  53. slog.String("description", C.GoString(C.ggml_backend_dev_description(d.d))),
  54. slog.String("kind", kind),
  55. slog.String("free", format.HumanBytes2(free)),
  56. slog.String("total", format.HumanBytes2(total)),
  57. )
  58. }
  59. var devices = sync.OnceValue(func() []device {
  60. ggml.OnceLoad()
  61. s := make([]device, C.ggml_backend_dev_count())
  62. for i := range s {
  63. s[i] = device{C.ggml_backend_dev_get(C.size_t(i))}
  64. }
  65. return s
  66. })
  67. type Backend struct {
  68. meta *fs.GGML
  69. cpus, gpus []Context
  70. tensors map[string]*Context
  71. sched *C.struct_ggml_backend_sched
  72. }
  73. func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
  74. meta, n, err := fs.Decode(r, -1)
  75. if err != nil {
  76. return nil, err
  77. }
  78. slog.Info(
  79. "",
  80. "architecture", meta.KV().Architecture(),
  81. "file_type", meta.KV().FileType(),
  82. "name", meta.KV().String("general.name"),
  83. "description", meta.KV().String("general.description"),
  84. "num_tensors", len(meta.Tensors().Items()),
  85. "num_key_values", len(meta.KV()),
  86. )
  87. var cpus, gpus []Context
  88. for _, d := range devices() {
  89. switch C.ggml_backend_dev_type(d.d) {
  90. case C.GGML_BACKEND_DEVICE_TYPE_CPU,
  91. C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
  92. slog.Info("cpu", "device", d)
  93. cpus = append(cpus, Context{
  94. ctx: C.ggml_init(C.struct_ggml_init_params{
  95. mem_size: C.size_t(int(C.ggml_tensor_overhead()) * (len(meta.Tensors().Items()) + 1 + int(meta.KV().BlockCount())*2)),
  96. no_alloc: true,
  97. }),
  98. backend: C.ggml_backend_dev_init(d.d, nil),
  99. })
  100. case C.GGML_BACKEND_DEVICE_TYPE_GPU:
  101. slog.Info("gpu", "device", d)
  102. gpus = append(gpus, Context{
  103. ctx: C.ggml_init(C.struct_ggml_init_params{
  104. mem_size: C.size_t(int(C.ggml_tensor_overhead()) * (len(meta.Tensors().Items()) + 1 + int(meta.KV().BlockCount())*2)),
  105. no_alloc: true,
  106. }),
  107. backend: C.ggml_backend_dev_init(d.d, nil),
  108. })
  109. }
  110. }
  111. ctxFunc := func(s []Context) (*Context, error) {
  112. for _, e := range s {
  113. return &e, nil
  114. }
  115. return nil, fmt.Errorf("no devices available")
  116. }
  117. tensors := make(map[*fs.Tensor]*Context, len(meta.Tensors().Items()))
  118. for _, t := range meta.Tensors().Items() {
  119. c, err := ctxFunc(append(gpus, cpus...))
  120. if err != nil {
  121. return nil, err
  122. }
  123. func() {
  124. tt := C.ggml_new_tensor(c.ctx, t.Kind, C.int(len(t.Shape)), (*C.int64_t)(unsafe.Pointer(&t.Shape[0])))
  125. cname := C.CString(t.Name)
  126. defer C.free(unsafe.Pointer(cname))
  127. C.ggml_set_name(tt, cname)
  128. tensors[t] = c
  129. }()
  130. }
  131. for _, b := range append(gpus, cpus...) {
  132. C.ggml_backend_alloc_ctx_tensors(b.ctx, b.backend)
  133. }
  134. sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
  135. var g errgroup.Group
  136. for t, c := range tensors {
  137. g.Go(func() error {
  138. bts := make([]byte, t.Size())
  139. n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
  140. if err != nil {
  141. return err
  142. }
  143. if n != int(t.Size()) {
  144. return fmt.Errorf("expected %d bytes, got %d", t.Size(), n)
  145. }
  146. cname := C.CString(t.Name)
  147. defer C.free(unsafe.Pointer(cname))
  148. C.ggml_backend_tensor_set(C.ggml_get_tensor(c.ctx, cname), unsafe.Pointer(&bts[0]), 0, C.size_t(n))
  149. return nil
  150. })
  151. }
  152. if err := g.Wait(); err != nil {
  153. return nil, err
  154. }
  155. backends := make([]*C.struct_ggml_backend, len(gpus)+len(cpus))
  156. bufts := make([]*C.struct_ggml_backend_buffer_type, len(gpus)+len(cpus))
  157. for i, c := range append(gpus, cpus...) {
  158. backends[i] = c.backend
  159. bufts[i] = C.ggml_backend_get_default_buffer_type(c.backend)
  160. }
  161. return &Backend{
  162. meta: meta,
  163. cpus: cpus,
  164. gpus: gpus,
  165. sched: C.ggml_backend_sched_new(
  166. (*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
  167. (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
  168. C.int(len(backends)),
  169. C.size_t(max(8192, len(meta.Tensors().Items())*5)),
  170. true,
  171. ),
  172. }, nil
  173. }
  174. func init() {
  175. ml.RegisterBackend("ggml", New)
  176. }
  177. func (b *Backend) Config() ml.Config {
  178. return b.meta.KV()
  179. }
  180. func (b *Backend) Get(name string) ml.Tensor {
  181. cname := C.CString(name)
  182. defer C.free(unsafe.Pointer(cname))
  183. for _, c := range append(b.gpus, b.cpus...) {
  184. if t := C.ggml_get_tensor(c.ctx, cname); t != nil {
  185. return &Tensor{t: t}
  186. }
  187. }
  188. return nil
  189. }
  190. func (b *Backend) NewContext() ml.Context {
  191. nodes := max(8192, len(b.meta.Tensors().Items())*5)
  192. c := C.ggml_init(C.struct_ggml_init_params{
  193. mem_buffer: nil,
  194. mem_size: C.size_t(nodes)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(nodes), false),
  195. no_alloc: true,
  196. })
  197. backends := make([]*C.struct_ggml_backend, len(b.gpus)+len(b.cpus))
  198. for i, c := range append(b.gpus, b.cpus...) {
  199. backends[i] = c.backend
  200. }
  201. return &Context{
  202. b: b,
  203. ctx: c,
  204. backend: backends[0],
  205. nodes: nodes,
  206. }
  207. }
  208. type Context struct {
  209. b *Backend
  210. ctx *C.struct_ggml_context
  211. backend *C.struct_ggml_backend
  212. graph *C.struct_ggml_cgraph
  213. nodes int
  214. }
  215. func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
  216. if c.graph == nil {
  217. c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.nodes), false)
  218. }
  219. for _, tensor := range tensors {
  220. C.ggml_build_forward_expand(c.graph, tensor.(*Tensor).t)
  221. }
  222. return c
  223. }
  224. func (c *Context) Compute(tensors ...ml.Tensor) {
  225. C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph)
  226. C.ggml_backend_sched_reset(c.b.sched)
  227. needSync := true
  228. sync := func() {
  229. if needSync {
  230. C.ggml_backend_sched_synchronize(c.b.sched)
  231. needSync = false
  232. }
  233. }
  234. for _, t := range tensors {
  235. if C.ggml_nbytes(t.(*Tensor).t) > 0 {
  236. t.(*Tensor).sync = sync
  237. }
  238. }
  239. }
  240. func (c *Context) MaxTensors() int {
  241. return c.nodes
  242. }
  243. func shapeToGGML(shape []int) *C.int64_t {
  244. sh := make([]C.int64_t, len(shape))
  245. for i, s := range shape {
  246. sh[i] = (C.int64_t)(s)
  247. }
  248. return &sh[0]
  249. }
  250. func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
  251. if len(shape) < 1 || len(shape) > 4 {
  252. panic("unsupported number of dimensions")
  253. }
  254. for _, dim := range shape {
  255. if dim < 1 {
  256. panic("invalid shape")
  257. }
  258. }
  259. var t *C.struct_ggml_tensor
  260. switch dtype {
  261. case ml.DTypeF32:
  262. t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
  263. case ml.DTypeF16:
  264. t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
  265. case ml.DTypeI32:
  266. t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
  267. default:
  268. panic("unsupported dtype")
  269. }
  270. b := C.ggml_backend_alloc_buffer(c.backend, C.ggml_nbytes(t))
  271. C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
  272. C.ggml_set_zero(t)
  273. return &Tensor{t: t}
  274. }
  275. func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) {
  276. n := len(s)
  277. if n == 0 {
  278. var shape C.int64_t = 0
  279. t := C.ggml_new_tensor(ctx.ctx, dtype, 1, &shape)
  280. return &Tensor{t: t}, nil
  281. }
  282. for _, v := range shape {
  283. n /= v
  284. }
  285. if n != 1 {
  286. return nil, fmt.Errorf("invalid shape %v for %d elements", shape, len(s))
  287. }
  288. t := C.ggml_new_tensor(ctx.ctx, dtype, C.int(len(shape)), shapeToGGML(shape))
  289. b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
  290. C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
  291. C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t))
  292. return &Tensor{t: t}, nil
  293. }
  294. func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
  295. return fromSlice(c, s, shape, C.GGML_TYPE_F32)
  296. }
  297. func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
  298. return fromSlice(c, s, shape, C.GGML_TYPE_I32)
  299. }
  300. func (c *Context) Close() {
  301. if c != nil {
  302. C.ggml_free(c.ctx)
  303. }
  304. }
  305. type Tensor struct {
  306. t *C.struct_ggml_tensor
  307. sync func()
  308. }
  309. func (t *Tensor) LogValue() slog.Value {
  310. return slog.GroupValue(
  311. slog.String("name", C.GoString(C.ggml_get_name(t.t))),
  312. slog.String("type", C.GoString(C.ggml_type_name(t.t._type))),
  313. slog.Any("shape", t.Shape()),
  314. )
  315. }
  316. func (t *Tensor) Dim(n int) int {
  317. return int(t.t.ne[n])
  318. }
  319. func (t *Tensor) Stride(n int) int {
  320. return int(t.t.nb[n])
  321. }
  322. func (t *Tensor) Shape() []int {
  323. shape := make([]int, C.ggml_n_dims(t.t))
  324. for i := range shape {
  325. shape[i] = t.Dim(i)
  326. }
  327. return shape
  328. }
  329. func (t *Tensor) Bytes() (data []byte) {
  330. if t.sync != nil {
  331. data = make([]byte, C.ggml_nbytes(t.t))
  332. t.sync()
  333. C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
  334. }
  335. return
  336. }
  337. func (t *Tensor) Floats() (data []float32) {
  338. if t.sync != nil {
  339. data = make([]float32, C.ggml_nelements(t.t))
  340. t.sync()
  341. C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
  342. }
  343. return
  344. }
  345. func (t *Tensor) DType() ml.DType {
  346. switch t.t._type {
  347. case C.GGML_TYPE_F32:
  348. return ml.DTypeF32
  349. case C.GGML_TYPE_F16:
  350. return ml.DTypeF16
  351. case C.GGML_TYPE_I32:
  352. return ml.DTypeI32
  353. default:
  354. return ml.DTypeOther
  355. }
  356. }
  357. func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  358. return &Tensor{
  359. t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  360. }
  361. }
  362. func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
  363. if len(s) > 0 {
  364. return t.Concat(ctx, s[0].Stack(ctx, dim, s[1:]...), dim)
  365. }
  366. return t
  367. }
  368. func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
  369. return &Tensor{
  370. t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
  371. }
  372. }
  373. func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
  374. return &Tensor{
  375. t: C.ggml_cont(ctx.(*Context).ctx, t.t),
  376. }
  377. }
  378. func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  379. return &Tensor{
  380. t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  381. }
  382. }
  383. func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  384. return &Tensor{
  385. t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  386. }
  387. }
  388. func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  389. mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
  390. C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
  391. return &Tensor{
  392. t: mul,
  393. }
  394. }
  395. func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
  396. tt := (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
  397. if b != nil {
  398. tt = tt.Add(ctx, b)
  399. }
  400. return tt
  401. }
  402. func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
  403. return (&Tensor{t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
  404. }
  405. func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
  406. if len(shape) != 4 {
  407. panic("expected 4 dimensions")
  408. }
  409. return &Tensor{
  410. t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
  411. }
  412. }
  413. func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
  414. if len(shape) != 4 {
  415. panic("expected 4 dimensions")
  416. }
  417. return &Tensor{
  418. t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
  419. }
  420. }
  421. func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  422. return &Tensor{
  423. t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  424. }
  425. }
  426. func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
  427. return &Tensor{
  428. t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
  429. }
  430. }
  431. func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
  432. switch len(shape) {
  433. case 1:
  434. return &Tensor{
  435. t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
  436. }
  437. case 2:
  438. return &Tensor{
  439. t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
  440. }
  441. case 3:
  442. return &Tensor{
  443. t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
  444. }
  445. case 4:
  446. return &Tensor{
  447. t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
  448. }
  449. default:
  450. panic("unsupported number of dimensions")
  451. }
  452. }
  453. func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
  454. return &Tensor{
  455. t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
  456. }
  457. }
  458. func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
  459. return &Tensor{
  460. t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
  461. }
  462. }
  463. func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
  464. return &Tensor{
  465. t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
  466. }
  467. }
  468. func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
  469. if len(shape) != 4 {
  470. panic("expected 4 dimensions")
  471. }
  472. return &Tensor{
  473. t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
  474. }
  475. }
  476. func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
  477. switch len(shape) {
  478. case 1:
  479. return &Tensor{
  480. t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
  481. }
  482. case 3:
  483. return &Tensor{
  484. t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
  485. C.int64_t(shape[0]), C.int64_t(shape[2]),
  486. C.size_t(shape[1]),
  487. C.size_t(offset)),
  488. }
  489. case 5:
  490. return &Tensor{
  491. t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
  492. C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
  493. C.size_t(shape[1]), C.size_t(shape[3]),
  494. C.size_t(offset)),
  495. }
  496. case 7:
  497. return &Tensor{
  498. t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
  499. C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
  500. C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
  501. C.size_t(offset)),
  502. }
  503. default:
  504. panic("unsupported number of dimensions")
  505. }
  506. }
  507. const (
  508. ropeTypeNorm C.int = iota
  509. )
  510. func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
  511. if ropeFactors == nil {
  512. ropeFactors = &Tensor{}
  513. }
  514. dequant := t.t
  515. if C.ggml_is_quantized(t.t._type) {
  516. dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
  517. }
  518. return &Tensor{
  519. t: C.ggml_rope_ext(
  520. ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
  521. C.int(ropeDim),
  522. 131072, // YaRN n_ctx_train
  523. ropeTypeNorm, // ROPE_TYPE_NORM
  524. C.float(ropeBase),
  525. C.float(ropeScale),
  526. 0., // YaRN ext_factor
  527. 1., // YaRN attn_factor
  528. 32., // YaRN beta_fast
  529. 1., // YaRN beta_slow
  530. ),
  531. }
  532. }
  533. func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
  534. return &Tensor{
  535. t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
  536. }
  537. }
  538. func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
  539. return &Tensor{
  540. t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
  541. }
  542. }
  543. func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
  544. return &Tensor{
  545. t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
  546. }
  547. }
  548. func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
  549. var kqMask *C.struct_ggml_tensor
  550. if mask != nil {
  551. kqMask = mask.(*Tensor).t
  552. }
  553. kq := key.MulmatFullPrec(ctx, t)
  554. kq = &Tensor{
  555. t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
  556. }
  557. kqv := value.Mulmat(ctx, kq)
  558. return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  559. }
  560. func (b *Backend) SystemInfo() string {
  561. var compiler string
  562. switch C.get_compiler() {
  563. case C.COMP_UNKNOWN:
  564. compiler = "cgo(unknown_compiler)"
  565. case C.COMP_GCC:
  566. compiler = "cgo(gcc)"
  567. case C.COMP_CLANG:
  568. compiler = "cgo(clang)"
  569. }
  570. var s string
  571. for i := range C.ggml_backend_reg_count() {
  572. reg := C.ggml_backend_reg_get(i)
  573. fName := C.CString("ggml_backend_get_features")
  574. defer C.free(unsafe.Pointer(fName))
  575. get_features_fn := C.ggml_backend_reg_get_proc_address(reg, fName)
  576. if get_features_fn != nil {
  577. s += C.GoString(C.ggml_backend_reg_name(reg))
  578. s += " : "
  579. for features := C.getBackendFeatures(get_features_fn, reg); features.name != nil; features = C.getNextBackendFeatures(features) {
  580. s += C.GoString(features.name)
  581. s += " = "
  582. s += C.GoString(features.value)
  583. s += " | "
  584. }
  585. }
  586. }
  587. return s + compiler
  588. }