ggml-metal-impl.h 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. /**
  2. * llama.cpp - commit ba1cb19cdd0d92e012e0f6e009e0620f854b6afd - do not edit this file
  3. *
  4. * MIT License
  5. *
  6. * Copyright (c) 2023-2024 The ggml authors
  7. *
  8. * Permission is hereby granted, free of charge, to any person obtaining a copy
  9. * of this software and associated documentation files (the "Software"), to deal
  10. * in the Software without restriction, including without limitation the rights
  11. * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  12. * copies of the Software, and to permit persons to whom the Software is
  13. * furnished to do so, subject to the following conditions:
  14. *
  15. * The above copyright notice and this permission notice shall be included in all
  16. * copies or substantial portions of the Software.
  17. *
  18. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  19. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  20. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  21. * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  22. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  23. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  24. * SOFTWARE.
  25. */
  26. #ifndef GGML_METAL_IMPL
  27. #define GGML_METAL_IMPL
// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
//   however, be careful of int overflows when using those in the kernel implementation
//
// - strides (e.g. nb00) use uint64_t
  34. typedef struct {
  35. int32_t ne00;
  36. int32_t ne01;
  37. int32_t ne02;
  38. int32_t ne03;
  39. uint64_t nb00;
  40. uint64_t nb01;
  41. uint64_t nb02;
  42. uint64_t nb03;
  43. int32_t ne10;
  44. int32_t ne11;
  45. int32_t ne12;
  46. int32_t ne13;
  47. uint64_t nb10;
  48. uint64_t nb11;
  49. uint64_t nb12;
  50. uint64_t nb13;
  51. int32_t ne0;
  52. int32_t ne1;
  53. int32_t ne2;
  54. int32_t ne3;
  55. uint64_t nb0;
  56. uint64_t nb1;
  57. uint64_t nb2;
  58. uint64_t nb3;
  59. int32_t dim;
  60. } ggml_metal_kargs_concat;
  61. typedef struct {
  62. int32_t ne00;
  63. int32_t ne01;
  64. int32_t ne02;
  65. int32_t ne03;
  66. uint64_t nb00;
  67. uint64_t nb01;
  68. uint64_t nb02;
  69. uint64_t nb03;
  70. int32_t ne10;
  71. int32_t ne11;
  72. int32_t ne12;
  73. int32_t ne13;
  74. uint64_t nb10;
  75. uint64_t nb11;
  76. uint64_t nb12;
  77. uint64_t nb13;
  78. int32_t ne0;
  79. int32_t ne1;
  80. int32_t ne2;
  81. int32_t ne3;
  82. uint64_t nb0;
  83. uint64_t nb1;
  84. uint64_t nb2;
  85. uint64_t nb3;
  86. uint64_t offs;
  87. } ggml_metal_kargs_bin;
  88. typedef struct {
  89. int32_t ne00;
  90. int32_t ne01;
  91. int32_t ne02;
  92. int32_t ne03;
  93. uint64_t nb00;
  94. uint64_t nb01;
  95. uint64_t nb02;
  96. uint64_t nb03;
  97. int32_t ne0;
  98. int32_t ne1;
  99. int32_t ne2;
  100. int32_t ne3;
  101. uint64_t nb0;
  102. uint64_t nb1;
  103. uint64_t nb2;
  104. uint64_t nb3;
  105. } ggml_metal_kargs_repeat;
  106. typedef struct {
  107. int64_t ne00;
  108. int64_t ne01;
  109. int64_t ne02;
  110. int64_t ne03;
  111. uint64_t nb00;
  112. uint64_t nb01;
  113. uint64_t nb02;
  114. uint64_t nb03;
  115. int64_t ne0;
  116. int64_t ne1;
  117. int64_t ne2;
  118. int64_t ne3;
  119. uint64_t nb0;
  120. uint64_t nb1;
  121. uint64_t nb2;
  122. uint64_t nb3;
  123. } ggml_metal_kargs_cpy;
  124. typedef struct {
  125. int64_t ne10;
  126. int64_t ne11;
  127. int64_t ne12;
  128. uint64_t nb10;
  129. uint64_t nb11;
  130. uint64_t nb12;
  131. uint64_t nb13;
  132. uint64_t nb1;
  133. uint64_t nb2;
  134. uint64_t nb3;
  135. uint64_t offs;
  136. bool inplace;
  137. } ggml_metal_kargs_set;
  138. typedef struct {
  139. int32_t ne00;
  140. int32_t ne01;
  141. int32_t ne02;
  142. int32_t ne03;
  143. uint64_t nb00;
  144. uint64_t nb01;
  145. uint64_t nb02;
  146. uint64_t nb03;
  147. int32_t ne0;
  148. int32_t ne1;
  149. int32_t ne2;
  150. int32_t ne3;
  151. uint64_t nb0;
  152. uint64_t nb1;
  153. uint64_t nb2;
  154. uint64_t nb3;
  155. int32_t n_past;
  156. int32_t n_dims;
  157. int32_t n_ctx_orig;
  158. float freq_base;
  159. float freq_scale;
  160. float ext_factor;
  161. float attn_factor;
  162. float beta_fast;
  163. float beta_slow;
  164. } ggml_metal_kargs_rope;
  165. typedef struct {
  166. int32_t ne01;
  167. int32_t ne02;
  168. int32_t ne03;
  169. uint64_t nb01;
  170. uint64_t nb02;
  171. uint64_t nb03;
  172. int32_t ne11;
  173. int32_t ne_12_2; // assume K and V are same shape
  174. int32_t ne_12_3;
  175. uint64_t nb_12_1;
  176. uint64_t nb_12_2;
  177. uint64_t nb_12_3;
  178. uint64_t nb31;
  179. int32_t ne1;
  180. int32_t ne2;
  181. float scale;
  182. float max_bias;
  183. float m0;
  184. float m1;
  185. uint16_t n_head_log2;
  186. float logit_softcap;
  187. } ggml_metal_kargs_flash_attn_ext;
  188. typedef struct {
  189. int32_t ne00;
  190. int32_t ne02;
  191. uint64_t nb01;
  192. uint64_t nb02;
  193. uint64_t nb03;
  194. int32_t ne12;
  195. uint64_t nb10;
  196. uint64_t nb11;
  197. uint64_t nb12;
  198. uint64_t nb13;
  199. int32_t ne0;
  200. int32_t ne1;
  201. int16_t r2;
  202. int16_t r3;
  203. } ggml_metal_kargs_mul_mm;
  204. typedef struct {
  205. int32_t ne00;
  206. int32_t ne01;
  207. int32_t ne02;
  208. uint64_t nb00;
  209. uint64_t nb01;
  210. uint64_t nb02;
  211. uint64_t nb03;
  212. int32_t ne10;
  213. int32_t ne11;
  214. int32_t ne12;
  215. uint64_t nb10;
  216. uint64_t nb11;
  217. uint64_t nb12;
  218. uint64_t nb13;
  219. int32_t ne0;
  220. int32_t ne1;
  221. int16_t r2;
  222. int16_t r3;
  223. } ggml_metal_kargs_mul_mv;
  224. typedef struct {
  225. int32_t ne00;
  226. int32_t ne01;
  227. int32_t ne02;
  228. uint64_t nb00;
  229. uint64_t nb01;
  230. uint64_t nb02;
  231. uint64_t nb03;
  232. int32_t ne10;
  233. int32_t ne11;
  234. int32_t ne12;
  235. uint64_t nb10;
  236. uint64_t nb11;
  237. uint64_t nb12;
  238. uint64_t nb13;
  239. int32_t ne0;
  240. int32_t ne1;
  241. int16_t r2;
  242. int16_t r3;
  243. int16_t nsg;
  244. int16_t nxpsg;
  245. int16_t r1ptg;
  246. } ggml_metal_kargs_mul_mv_ext;
  247. typedef struct {
  248. int32_t nei0;
  249. int32_t nei1;
  250. uint64_t nbi1;
  251. int32_t ne00;
  252. int32_t ne02;
  253. uint64_t nb01;
  254. uint64_t nb02;
  255. int32_t ne11;
  256. int32_t ne12;
  257. int32_t ne13;
  258. uint64_t nb10;
  259. uint64_t nb11;
  260. uint64_t nb12;
  261. int32_t ne0;
  262. int32_t ne1;
  263. } ggml_metal_kargs_mul_mm_id;
  264. typedef struct {
  265. int32_t nei0;
  266. int32_t nei1;
  267. uint64_t nbi1;
  268. int32_t ne00;
  269. int32_t ne01;
  270. int32_t ne02;
  271. uint64_t nb00;
  272. uint64_t nb01;
  273. uint64_t nb02;
  274. int32_t ne10;
  275. int32_t ne11;
  276. int32_t ne12;
  277. int32_t ne13;
  278. uint64_t nb10;
  279. uint64_t nb11;
  280. uint64_t nb12;
  281. int32_t ne0;
  282. int32_t ne1;
  283. uint64_t nb1;
  284. } ggml_metal_kargs_mul_mv_id;
  285. typedef struct {
  286. int32_t ne00;
  287. int32_t ne00_4;
  288. uint64_t nb01;
  289. float eps;
  290. } ggml_metal_kargs_norm;
  291. typedef struct {
  292. int32_t ne00;
  293. int32_t ne00_4;
  294. uint64_t nb01;
  295. float eps;
  296. } ggml_metal_kargs_rms_norm;
  297. #endif // GGML_METAL_IMPL