ggml-metal-impl.h 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. /**
  2. * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  3. *
  4. * MIT License
  5. *
  6. * Copyright (c) 2023-2024 The ggml authors
  7. *
  8. * Permission is hereby granted, free of charge, to any person obtaining a copy
  9. * of this software and associated documentation files (the "Software"), to deal
  10. * in the Software without restriction, including without limitation the rights
  11. * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  12. * copies of the Software, and to permit persons to whom the Software is
  13. * furnished to do so, subject to the following conditions:
  14. *
  15. * The above copyright notice and this permission notice shall be included in all
  16. * copies or substantial portions of the Software.
  17. *
  18. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  19. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  20. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  21. * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  22. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  23. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  24. * SOFTWARE.
  25. */
  26. #ifndef GGML_METAL_IMPL
  27. #define GGML_METAL_IMPL
// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage;
//   however, be careful of integer overflow when using them in the kernel implementation
//
// - strides (e.g. nb00) use uint64_t
  34. typedef struct {
  35. int32_t ne00;
  36. int32_t ne01;
  37. int32_t ne02;
  38. int32_t ne03;
  39. uint64_t nb00;
  40. uint64_t nb01;
  41. uint64_t nb02;
  42. uint64_t nb03;
  43. int32_t ne10;
  44. int32_t ne11;
  45. int32_t ne12;
  46. int32_t ne13;
  47. uint64_t nb10;
  48. uint64_t nb11;
  49. uint64_t nb12;
  50. uint64_t nb13;
  51. int32_t ne0;
  52. int32_t ne1;
  53. int32_t ne2;
  54. int32_t ne3;
  55. uint64_t nb0;
  56. uint64_t nb1;
  57. uint64_t nb2;
  58. uint64_t nb3;
  59. int32_t dim;
  60. } ggml_metal_kargs_concat;
  61. typedef struct {
  62. int32_t ne00;
  63. int32_t ne01;
  64. int32_t ne02;
  65. int32_t ne03;
  66. uint64_t nb00;
  67. uint64_t nb01;
  68. uint64_t nb02;
  69. uint64_t nb03;
  70. int32_t ne10;
  71. int32_t ne11;
  72. int32_t ne12;
  73. int32_t ne13;
  74. uint64_t nb10;
  75. uint64_t nb11;
  76. uint64_t nb12;
  77. uint64_t nb13;
  78. int32_t ne0;
  79. int32_t ne1;
  80. int32_t ne2;
  81. int32_t ne3;
  82. uint64_t nb0;
  83. uint64_t nb1;
  84. uint64_t nb2;
  85. uint64_t nb3;
  86. uint64_t offs;
  87. } ggml_metal_kargs_bin;
  88. typedef struct {
  89. int32_t ne00;
  90. int32_t ne01;
  91. int32_t ne02;
  92. int32_t ne03;
  93. uint64_t nb00;
  94. uint64_t nb01;
  95. uint64_t nb02;
  96. uint64_t nb03;
  97. int32_t ne0;
  98. int32_t ne1;
  99. int32_t ne2;
  100. int32_t ne3;
  101. uint64_t nb0;
  102. uint64_t nb1;
  103. uint64_t nb2;
  104. uint64_t nb3;
  105. } ggml_metal_kargs_repeat;
  106. typedef struct {
  107. int64_t ne00;
  108. int64_t ne01;
  109. int64_t ne02;
  110. int64_t ne03;
  111. uint64_t nb00;
  112. uint64_t nb01;
  113. uint64_t nb02;
  114. uint64_t nb03;
  115. int64_t ne0;
  116. int64_t ne1;
  117. int64_t ne2;
  118. int64_t ne3;
  119. uint64_t nb0;
  120. uint64_t nb1;
  121. uint64_t nb2;
  122. uint64_t nb3;
  123. } ggml_metal_kargs_cpy;
  124. typedef struct {
  125. int32_t ne00;
  126. int32_t ne01;
  127. int32_t ne02;
  128. int32_t ne03;
  129. uint64_t nb00;
  130. uint64_t nb01;
  131. uint64_t nb02;
  132. uint64_t nb03;
  133. int32_t ne0;
  134. int32_t ne1;
  135. int32_t ne2;
  136. int32_t ne3;
  137. uint64_t nb0;
  138. uint64_t nb1;
  139. uint64_t nb2;
  140. uint64_t nb3;
  141. int32_t n_past;
  142. int32_t n_dims;
  143. int32_t n_ctx_orig;
  144. float freq_base;
  145. float freq_scale;
  146. float ext_factor;
  147. float attn_factor;
  148. float beta_fast;
  149. float beta_slow;
  150. } ggml_metal_kargs_rope;
  151. typedef struct {
  152. int32_t ne01;
  153. int32_t ne02;
  154. int32_t ne03;
  155. uint64_t nb01;
  156. uint64_t nb02;
  157. uint64_t nb03;
  158. int32_t ne11;
  159. int32_t ne_12_2; // assume K and V are same shape
  160. int32_t ne_12_3;
  161. uint64_t nb_12_1;
  162. uint64_t nb_12_2;
  163. uint64_t nb_12_3;
  164. uint64_t nb31;
  165. int32_t ne1;
  166. int32_t ne2;
  167. float scale;
  168. float max_bias;
  169. float m0;
  170. float m1;
  171. uint16_t n_head_log2;
  172. float logit_softcap;
  173. } ggml_metal_kargs_flash_attn_ext;
  174. typedef struct {
  175. int32_t ne00;
  176. int32_t ne02;
  177. uint64_t nb01;
  178. uint64_t nb02;
  179. uint64_t nb03;
  180. int32_t ne12;
  181. uint64_t nb10;
  182. uint64_t nb11;
  183. uint64_t nb12;
  184. uint64_t nb13;
  185. int32_t ne0;
  186. int32_t ne1;
  187. int16_t r2;
  188. int16_t r3;
  189. } ggml_metal_kargs_mul_mm;
  190. typedef struct {
  191. int32_t ne00;
  192. int32_t ne01;
  193. int32_t ne02;
  194. uint64_t nb00;
  195. uint64_t nb01;
  196. uint64_t nb02;
  197. uint64_t nb03;
  198. int32_t ne10;
  199. int32_t ne11;
  200. int32_t ne12;
  201. uint64_t nb10;
  202. uint64_t nb11;
  203. uint64_t nb12;
  204. uint64_t nb13;
  205. int32_t ne0;
  206. int32_t ne1;
  207. int16_t r2;
  208. int16_t r3;
  209. } ggml_metal_kargs_mul_mv;
  210. typedef struct {
  211. int32_t ne00;
  212. int32_t ne01;
  213. int32_t ne02;
  214. uint64_t nb00;
  215. uint64_t nb01;
  216. uint64_t nb02;
  217. uint64_t nb03;
  218. int32_t ne10;
  219. int32_t ne11;
  220. int32_t ne12;
  221. uint64_t nb10;
  222. uint64_t nb11;
  223. uint64_t nb12;
  224. uint64_t nb13;
  225. int32_t ne0;
  226. int32_t ne1;
  227. int16_t r2;
  228. int16_t r3;
  229. int16_t nsg;
  230. int16_t nxpsg;
  231. int16_t r1ptg;
  232. } ggml_metal_kargs_mul_mv_ext;
  233. typedef struct {
  234. int32_t nei0;
  235. int32_t nei1;
  236. uint64_t nbi1;
  237. int32_t ne00;
  238. int32_t ne02;
  239. uint64_t nb01;
  240. uint64_t nb02;
  241. int32_t ne11;
  242. int32_t ne12;
  243. int32_t ne13;
  244. uint64_t nb10;
  245. uint64_t nb11;
  246. uint64_t nb12;
  247. int32_t ne0;
  248. int32_t ne1;
  249. } ggml_metal_kargs_mul_mm_id;
  250. typedef struct {
  251. int32_t nei0;
  252. int32_t nei1;
  253. uint64_t nbi1;
  254. int32_t ne00;
  255. int32_t ne01;
  256. int32_t ne02;
  257. uint64_t nb00;
  258. uint64_t nb01;
  259. uint64_t nb02;
  260. int32_t ne10;
  261. int32_t ne11;
  262. int32_t ne12;
  263. int32_t ne13;
  264. uint64_t nb10;
  265. uint64_t nb11;
  266. uint64_t nb12;
  267. int32_t ne0;
  268. int32_t ne1;
  269. uint64_t nb1;
  270. } ggml_metal_kargs_mul_mv_id;
  271. typedef struct {
  272. int32_t ne00;
  273. int32_t ne00_4;
  274. uint64_t nb01;
  275. float eps;
  276. } ggml_metal_kargs_norm;
  277. typedef struct {
  278. int32_t ne00;
  279. int32_t ne00_4;
  280. uint64_t nb01;
  281. float eps;
  282. } ggml_metal_kargs_rms_norm;
  283. #endif // GGML_METAL_IMPL