mlx.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. package mlx
  2. // #cgo CPPFLAGS: -I${SRCDIR}/../../../build/_deps/mlx-c-src
  3. // #cgo LDFLAGS: -L${SRCDIR}/../../../build/lib -lmlxc -lmlx
  4. // #cgo LDFLAGS: -framework Accelerate
  5. // #cgo LDFLAGS: -Wl,-rpath,${SRCDIR}/../../../build/lib
  6. // #include <stdlib.h>
  7. // #include "mlx/c/array.h"
  8. // #include "mlx/c/fast.h"
  9. // #include "mlx/c/ops.h"
  10. // #include "mlx/c/stream.h"
  11. import "C"
  12. import (
  13. "bytes"
  14. "fmt"
  15. "io"
  16. "log/slog"
  17. "os"
  18. "sync"
  19. "unsafe"
  20. fs "github.com/ollama/ollama/fs/ggml"
  21. "github.com/ollama/ollama/ml"
  22. "golang.org/x/sync/errgroup"
  23. )
// init registers this package with the ml package under the name "mlx",
// passing New as the function to register.
func init() {
	ml.RegisterBackend("mlx", New)
}
  27. func New(r *os.File) (ml.Backend, error) {
  28. meta, n, err := fs.Decode(r, -1)
  29. if err != nil {
  30. return nil, err
  31. }
  32. tensors := make(map[string]*Array, len(meta.Tensors().Items()))
  33. sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
  34. stream := C.mlx_default_cpu_stream_new()
  35. var g errgroup.Group
  36. var mu sync.Mutex
  37. for _, t := range meta.Tensors().Items() {
  38. g.Go(func() error {
  39. var b bytes.Buffer
  40. n, err := io.Copy(&b, io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())))
  41. if err != nil {
  42. return err
  43. }
  44. if n != int64(t.Size()) {
  45. return fmt.Errorf("expected %d bytes, got %d", t.Size(), n)
  46. }
  47. cbytes := C.CBytes(b.Bytes())
  48. defer C.free(cbytes)
  49. shape := make([]C.int, len(t.Shape))
  50. for i, dim := range t.Shape {
  51. shape[i] = C.int(dim)
  52. }
  53. var dtype C.mlx_dtype
  54. switch t.Kind {
  55. case 0:
  56. dtype = C.MLX_FLOAT32
  57. case 1:
  58. dtype = C.MLX_FLOAT16
  59. default:
  60. return fmt.Errorf("unsupported dtype %d", t.Kind)
  61. }
  62. mu.Lock()
  63. defer mu.Unlock()
  64. var a C.mlx_array
  65. C.mlx_transpose_all(
  66. &a,
  67. C.mlx_array_new_data(
  68. cbytes,
  69. (*C.int)(&shape[0]),
  70. C.int(len(shape)),
  71. dtype,
  72. ),
  73. stream,
  74. )
  75. tensors[t.Name] = &Array{
  76. name: t.Name,
  77. a: a,
  78. }
  79. return nil
  80. })
  81. }
  82. if err := g.Wait(); err != nil {
  83. return nil, err
  84. }
  85. return &Backend{
  86. meta: meta,
  87. tensors: tensors,
  88. }, nil
  89. }
// Backend holds the decoded model metadata together with the MLX arrays
// loaded from the model file.
type Backend struct {
	// meta is the decoded model file; Config returns its key-value section.
	meta *fs.GGML
	// tensors maps a tensor name to its loaded array; Get looks names up here.
	tensors map[string]*Array
}
  94. // Config implements ml.Backend.
  95. func (b *Backend) Config() ml.Config {
  96. return b.meta.KV()
  97. }
  98. // Get implements ml.Backend.
  99. func (b *Backend) Get(name string) ml.Tensor {
  100. if a, ok := b.tensors[name]; ok {
  101. return a
  102. }
  103. return nil
  104. }
  105. func (b *Backend) NewContext() ml.Context {
  106. return &Context{
  107. stream: C.mlx_default_cpu_stream_new(),
  108. }
  109. }
// Context carries the MLX stream that Array operations receiving this context
// pass to the MLX calls they make.
type Context struct {
	// stream is the MLX stream handle obtained in NewContext.
	stream C.mlx_stream
}
// Close implements ml.Context.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (c *Context) Close() error {
	panic("unimplemented")
}
// Compute implements ml.Context.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (c *Context) Compute(ml.Tensor) ml.Tensor {
	panic("unimplemented")
}
// Forward implements ml.Context.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (c *Context) Forward(ml.Tensor) {
	panic("unimplemented")
}
// FromFloatSlice implements ml.Context.
//
// Not implemented yet: calling it always panics with "unimplemented" rather
// than returning an error.
func (c *Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
	panic("unimplemented")
}
  129. // FromIntSlice implements ml.Context.
  130. func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
  131. cshape := make([]C.int, len(shape))
  132. for i, dim := range shape {
  133. cshape[i] = C.int(dim)
  134. }
  135. return &Array{
  136. a: C.mlx_array_new_data(
  137. unsafe.Pointer(&s[0]),
  138. (*C.int)(&cshape[0]),
  139. C.int(len(cshape)),
  140. C.MLX_INT32,
  141. ),
  142. }, nil
  143. }
// Zeros implements ml.Context.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
	panic("unimplemented")
}
// Array wraps an MLX array handle and is the tensor type this package hands
// out as ml.Tensor.
type Array struct {
	// name is the tensor's name as read from the model file; arrays produced
	// by operations in this file leave it empty.
	name string
	// a is the underlying MLX array handle.
	a C.mlx_array
}
  152. func (a *Array) LogValue() slog.Value {
  153. return slog.GroupValue(
  154. slog.String("name", a.name),
  155. slog.Any("shape", a.Shape()),
  156. )
  157. }
// Add implements ml.Tensor.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (a *Array) Add(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
	panic("unimplemented")
}
// Bytes implements ml.Tensor.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (a *Array) Bytes() []byte {
	panic("unimplemented")
}
// Concat implements ml.Tensor.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (a *Array) Concat(ctx ml.Context, a2 ml.Tensor, dim int) ml.Tensor {
	panic("unimplemented")
}
// Contiguous implements ml.Tensor.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (a *Array) Contiguous(ctx ml.Context) ml.Tensor {
	panic("unimplemented")
}
// Conv2D implements ml.Tensor.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (a *Array) Conv2D(ctx ml.Context, weight ml.Tensor, s0 int, s1 int, p0 int, p1 int, d0 int, d1 int) ml.Tensor {
	panic("unimplemented")
}
// Copy implements ml.Tensor.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (a *Array) Copy(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
	panic("unimplemented")
}
// DType implements ml.Tensor.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (a *Array) DType() ml.DType {
	panic("unimplemented")
}
  186. // Dim implements ml.Tensor.
  187. func (a *Array) Dim(n int) int64 {
  188. return int64(C.mlx_array_dim(a.a, C.int(n)))
  189. }
// Floats implements ml.Tensor.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (a *Array) Floats() []float32 {
	panic("unimplemented")
}
// GELU implements ml.Tensor.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (a *Array) GELU(ctx ml.Context) ml.Tensor {
	panic("unimplemented")
}
// Mul implements ml.Tensor.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (a *Array) Mul(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
	panic("unimplemented")
}
  202. // Mulmat implements ml.Tensor.
  203. func (a *Array) Mulmat(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
  204. slog.Info("mulmat", "a", a, "a2", a2)
  205. var r C.mlx_array
  206. C.mlx_matmul(&r, a2.(*Array).a, a.Permute(1, 0, 2, 3), ctx.(*Context).stream)
  207. return &Array{a: r}
  208. }
  209. // LayerNorm implements ml.Tensor.
  210. func (a *Array) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
  211. var r C.mlx_array
  212. C.mlx_fast_layer_norm(
  213. &r,
  214. a.a,
  215. w.(*Array).a,
  216. b.(*Array).a,
  217. C.float(eps),
  218. ctx.(*Context).stream,
  219. )
  220. return &Array{a: r}
  221. }
// Pad implements ml.Tensor.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (a *Array) Pad(ctx ml.Context, shape ...int64) ml.Tensor {
	panic("unimplemented")
}
// Permute implements ml.Tensor.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (a *Array) Permute(ctx ml.Context, shape ...int) ml.Tensor {
	panic("unimplemented")
}
  230. // RMSNorm implements ml.Tensor.
  231. func (a *Array) RMSNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
  232. var r C.mlx_array
  233. C.mlx_fast_rms_norm(
  234. &r,
  235. a.a,
  236. w.(*Array).a,
  237. C.float(eps),
  238. ctx.(*Context).stream,
  239. )
  240. return &Array{a: r}
  241. }
  242. // Reshape implements ml.Tensor.
  243. func (a *Array) Reshape(ctx ml.Context, shape ...int64) ml.Tensor {
  244. cshape := make([]C.int, len(shape))
  245. for i, dim := range shape {
  246. cshape[i] = C.int(dim)
  247. }
  248. var r C.mlx_array
  249. C.mlx_reshape(&r, a.a, (*C.int)(&cshape[0]), C.size_t(len(cshape)), ctx.(*Context).stream)
  250. return &Array{a: r}
  251. }
// Rope implements ml.Tensor.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (a *Array) Rope(ctx ml.Context, positionIDs ml.Tensor, ropeFactors ml.Tensor, dim uint32, base float32, scale float32) ml.Tensor {
	panic("unimplemented")
}
  256. // Rows implements ml.Tensor.
  257. func (a *Array) Rows(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
  258. var r C.mlx_array
  259. slog.Info("rows", "a", a, "a2", a2)
  260. C.mlx_take(&r, a.a, a2.(*Array).a, 0, ctx.(*Context).stream)
  261. return &Array{a: r}
  262. }
// SILU implements ml.Tensor.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (a *Array) SILU(ctx ml.Context) ml.Tensor {
	panic("unimplemented")
}
// Scale implements ml.Tensor.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (a *Array) Scale(ctx ml.Context, s float64) ml.Tensor {
	panic("unimplemented")
}
  271. // Shape implements ml.Tensor.
  272. func (a *Array) Shape() []int64 {
  273. shape := make([]int64, C.mlx_array_ndim(a.a))
  274. for i := range shape {
  275. shape[i] = int64(C.mlx_array_dim(a.a, C.int(i)))
  276. }
  277. return shape
  278. }
// Softmax implements ml.Tensor.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (a *Array) Softmax(ctx ml.Context) ml.Tensor {
	panic("unimplemented")
}
// Stack implements ml.Tensor.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (a *Array) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
	panic("unimplemented")
}
// Stride implements ml.Tensor.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (a *Array) Stride(n int) int64 {
	panic("unimplemented")
}
// Tanh implements ml.Tensor.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (a *Array) Tanh(ctx ml.Context) ml.Tensor {
	panic("unimplemented")
}
// Unpad implements ml.Tensor.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (a *Array) Unpad(ctx ml.Context, shape ...int64) ml.Tensor {
	panic("unimplemented")
}
// View implements ml.Tensor.
//
// Not implemented yet: calling it always panics with "unimplemented".
func (a *Array) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
	panic("unimplemented")
}