sgemm.cpp

  1. // Copyright 2024 Mozilla Foundation
  2. //
  3. // Permission is hereby granted, free of charge, to any person obtaining
  4. // a copy of this software and associated documentation files (the
  5. // "Software"), to deal in the Software without restriction, including
  6. // without limitation the rights to use, copy, modify, merge, publish,
  7. // distribute, sublicense, and/or sell copies of the Software, and to
  8. // permit persons to whom the Software is furnished to do so, subject to
  9. // the following conditions:
  10. //
  11. // The above copyright notice and this permission notice shall be
  12. // included in all copies or substantial portions of the Software.
  13. //
  14. // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  15. // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  16. // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
  17. // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
  18. // BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
  19. // ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
  20. // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  21. // SOFTWARE.
  22. //
//  _   _          ___ _      _   ___
// | |_(_)_ _ _  _| _ ) |    /_\ / __|
// |  _| | ' \ || | _ \ |__ / _ \\__ \.
//  \__|_|_||_\_, |___/____/_/ \_\___/
//            |__/
  28. //
  29. // BASIC LINEAR ALGEBRA SUBPROGRAMS
  30. //
  31. //
  32. // This file implements multithreaded CPU matrix multiplication for the
  33. // common contiguous use case C = Aᵀ * B. These kernels are designed to
  34. // have excellent performance[1] for matrices that fit in the CPU cache
  35. // without imposing any overhead such as cache filling or malloc calls.
  36. //
// This implementation does not guarantee any upper bound on rounding
// error, which grows along with k. Our goal is to maximally exploit the
// hardware for performance, and then use whatever resources remain to
// improve numerical accuracy.
  41. //
  42. // [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
  43. // Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
  44. #if defined(__GNUC__)
  45. #pragma GCC diagnostic ignored "-Wpedantic"
  46. #pragma GCC diagnostic ignored "-Wignored-attributes"
  47. #endif
  48. #include "sgemm.h"
  49. #include "ggml-impl.h"
  50. #include "ggml-cpu-impl.h"
  51. #include "ggml-quants.h"
  52. #include <atomic>
  53. #ifdef _MSC_VER
  54. #define NOINLINE __declspec(noinline)
  55. #else
  56. #define NOINLINE __attribute__((__noinline__))
  57. #endif
  58. #if defined(__ARM_NEON) || defined(__AVX512F__)
  59. #define VECTOR_REGISTERS 32
  60. #else
  61. #define VECTOR_REGISTERS 16
  62. #endif
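// Builds a 256-bit integer vector from two 128-bit halves: a goes in the high
// lane, b in the low lane.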
  63. #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
  64. namespace {
  65. inline float unhalf(ggml_fp16_t d) {
  66. return GGML_FP16_TO_FP32(d);
  67. }
  68. ////////////////////////////////////////////////////////////////////////////////////////////////////
  69. // VECTORIZED ARITHMETIC OPERATIONS
  70. #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  71. inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
  72. inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
  73. inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
  74. #endif // __SSE__
  75. #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  76. inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
  77. inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
  78. inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
  79. #endif // __AVX__
  80. #if defined(__AVX512F__)
  81. inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
  82. inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
  83. inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
  84. #endif // __AVX512F__
  85. #if defined(__ARM_NEON)
  86. inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
  87. inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
  88. inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
  89. #endif // __ARM_NEON
  90. #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
  91. inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
  92. inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
  93. inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
  94. #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  95. #if defined(__MMA__)
  96. typedef vector unsigned char vec_t;
  97. typedef __vector_quad acc_t;
  98. #endif
  99. ////////////////////////////////////////////////////////////////////////////////////////////////////
  100. // VECTORIZED FUSED MULTIPLY ADD
  101. /**
  102. * Computes a * b + c.
  103. */
  104. template <typename T, typename U>
  105. inline U madd(T a, T b, U c) {
  106. return add(mul(a, b), c);
  107. }
  108. #if defined(__FMA__)
  109. #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  110. template <>
  111. inline __m256 madd(__m256 a, __m256 b, __m256 c) {
  112. return _mm256_fmadd_ps(a, b, c);
  113. }
  114. #endif
  115. #if defined(__AVX512F__)
  116. template <>
  117. inline __m512 madd(__m512 a, __m512 b, __m512 c) {
  118. return _mm512_fmadd_ps(a, b, c);
  119. }
  120. #endif
  121. #if defined(__AVX512BF16__)
  122. template <>
  123. inline __m512 madd(__m512bh a, __m512bh b, __m512 c) {
  124. return _mm512_dpbf16_ps(c, a, b);
  125. }
  126. template <>
  127. inline __m256 madd(__m256bh a, __m256bh b, __m256 c) {
  128. return _mm256_dpbf16_ps(c, a, b);
  129. }
  130. #endif
  131. #endif
  132. #if defined(__ARM_FEATURE_FMA)
  133. template <>
  134. inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
  135. return vfmaq_f32(c, b, a);
  136. }
  137. #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
  138. template <>
  139. inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
  140. return vfmaq_f16(c, b, a);
  141. }
  142. #endif
  143. #endif
  144. ////////////////////////////////////////////////////////////////////////////////////////////////////
  145. // VECTORIZED HORIZONTAL SUM
  146. #if defined(__ARM_NEON)
  147. inline float hsum(float32x4_t x) {
  148. return vaddvq_f32(x);
  149. }
  150. #endif // __ARM_NEON
  151. #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
  152. inline float hsum(float16x8_t x) {
  153. return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),
  154. vcvt_f32_f16(vget_high_f16(x))));
  155. }
  156. #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  157. #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  158. inline float hsum(__m128 x) {
  159. #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  160. x = _mm_add_ps(x, _mm_movehl_ps(x, x));
  161. x = _mm_add_ss(x, _mm_movehdup_ps(x));
  162. #else
  163. __m128 t;
  164. t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
  165. x = _mm_add_ps(x, t);
  166. t = _mm_movehl_ps(t, x);
  167. x = _mm_add_ss(x, t);
  168. #endif
  169. return _mm_cvtss_f32(x);
  170. }
  171. #endif
  172. #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  173. inline float hsum(__m256 x) {
  174. return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),
  175. _mm256_castps256_ps128(x)));
  176. }
  177. #endif // __AVX__
  178. #if defined(__AVX512F__)
  179. inline float hsum(__m512 x) {
  180. return _mm512_reduce_add_ps(x);
  181. }
  182. #endif // __AVX512F__
  183. ////////////////////////////////////////////////////////////////////////////////////////////////////
  184. // VECTORIZED MEMORY LOADING
  185. template <typename T, typename U> T load(const U *);
  186. #if defined(__ARM_NEON)
  187. template <> inline float32x4_t load(const float *p) {
  188. return vld1q_f32(p);
  189. }
  190. #if !defined(_MSC_VER)
  191. // FIXME: this should check for __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  192. template <> inline float16x8_t load(const ggml_fp16_t *p) {
  193. return vld1q_f16((const float16_t *)p);
  194. }
  195. template <> inline float32x4_t load(const ggml_fp16_t *p) {
  196. return vcvt_f32_f16(vld1_f16((const float16_t *)p));
  197. }
  198. #endif // _MSC_VER
  199. #endif // __ARM_NEON
  200. #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  201. template <> inline __m128 load(const float *p) {
  202. return _mm_loadu_ps(p);
  203. }
  204. #endif // __SSE__
  205. #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  206. template <> inline __m256 load(const float *p) {
  207. return _mm256_loadu_ps(p);
  208. }
  209. #endif // __AVX__
  210. #if defined(__AVX2__) || defined(__AVX512F__)
  211. template <> inline __m256 load(const ggml_bf16_t *p) {
  212. return _mm256_castsi256_ps(
  213. _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)p)), 16));
  214. }
  215. #endif // __AVX2__
  216. #if defined(__F16C__)
  217. template <> inline __m256 load(const ggml_fp16_t *p) {
  218. return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
  219. }
  220. #endif // __F16C__
  221. #if defined(__AVX512F__)
  222. template <> inline __m512 load(const float *p) {
  223. return _mm512_loadu_ps(p);
  224. }
  225. template <> inline __m512 load(const ggml_fp16_t *p) {
  226. return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
  227. }
  228. template <> inline __m512 load(const ggml_bf16_t *p) {
  229. return _mm512_castsi512_ps(
  230. _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)p)), 16));
  231. }
  232. #endif // __AVX512F__
  233. #if defined(__AVX512BF16__)
  234. template <> inline __m512bh load(const ggml_bf16_t *p) {
  235. return (__m512bh)_mm512_loadu_ps((const float *)p);
  236. }
  237. template <> inline __m256bh load(const ggml_bf16_t *p) {
  238. return (__m256bh)_mm256_loadu_ps((const float *)p);
  239. }
  240. template <> inline __m512bh load(const float *p) {
  241. return _mm512_cvtne2ps_pbh(_mm512_loadu_ps(p + 16), _mm512_loadu_ps(p));
  242. }
  243. template <> inline __m256bh load(const float *p) {
  244. return _mm512_cvtneps_pbh(_mm512_loadu_ps(p));
  245. }
  246. #endif
  247. ////////////////////////////////////////////////////////////////////////////////////////////////////
  248. // CONSTANTS
  249. #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  250. static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
  251. static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
  252. #endif
  253. ////////////////////////////////////////////////////////////////////////////////////////////////////
  254. // FLOATING POINT MATRIX MULTIPLICATION
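// BLOCK_SIZE<M>(m) splits m into ceil(m/M) blocks and returns the rounded-up
// size of one block. BLOC_POS(ib, ibN, bloc_size) returns the starting offset
// of block ib when the first ibN blocks have size bloc_size and the remaining
// blocks have size bloc_size - 1.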
  255. template <int M>
  256. static inline int64_t BLOCK_SIZE(size_t m) {
  257. const int64_t NB_BLOC_M = (m + M - 1) / M;
  258. return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
  259. }
  260. static constexpr inline int64_t BLOC_POS(int64_t ib, int64_t ibN, int64_t bloc_size) {
  261. return ib < ibN ? ib * bloc_size : ibN * bloc_size + (ib - ibN) * (bloc_size - 1);
  262. }
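// tinyBLAS computes C = Aᵀ * B for unquantized element types. KN is the number
// of elements consumed per vector load (k must be a multiple of it), V is the
// vector type returned by load<>(), D is the accumulator type fed to madd()
// and hsum(), and TA/TB/TC are the element types of A, B and C. Work is split
// across params->nth threads.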
  263. template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
  264. class tinyBLAS {
  265. public:
  266. tinyBLAS(const ggml_compute_params * params, int64_t k,
  267. const TA *A, int64_t lda,
  268. const TB *B, int64_t ldb,
  269. TC *C, int64_t ldc)
  270. : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
  271. }
  272. bool matmul(int64_t m, int64_t n) {
  273. if (k % KN != 0)
  274. return false;
// Choose SIZE_N so that gemm only needs column tiles of width RN and RN-1.
  276. #if VECTOR_REGISTERS == 32
  277. if (m % 16 == 0 && (m/16 >= params->nth)) {
  278. const int64_t SIZE_N = BLOCK_SIZE<6>(n);
  279. mnpack<4, 6, 4>(m, n, SIZE_N, 12);
  280. return true;
  281. }
  282. if (m % 8 == 0 ) {
  283. const int64_t SIZE_N = BLOCK_SIZE<6>(n);
  284. mnpack<4, 6, 2>(m, n, SIZE_N, 12);
  285. return true;
  286. }
  287. if (m % 4 == 0) {
  288. const int64_t SIZE_N = BLOCK_SIZE<6>(n);
  289. mnpack<4, 6, 1>(m, n, SIZE_N, 12);
  290. return true;
  291. }
  292. #else // VECTOR_REGISTERS == 16
  293. if (m % 16 == 0 && (m/16 >= params->nth)) {
  294. const int64_t SIZE_N = BLOCK_SIZE<3>(n);
  295. mnpack<4, 3, 4>(m, n, SIZE_N, 24);
  296. return true;
  297. }
  298. if (m % 8 == 0 ) {
  299. const int64_t SIZE_N = BLOCK_SIZE<3>(n);
  300. mnpack<4, 3, 2>(m, n, SIZE_N, 24);
  301. return true;
  302. }
  303. if (m % 4 == 0) {
  304. const int64_t SIZE_N = BLOCK_SIZE<3>(n);
  305. mnpack<4, 3, 1>(m, n, SIZE_N, 24);
  306. return true;
  307. }
  308. #endif
  309. return false;
  310. }
  311. private:
  312. template <int RM, int RN, int BM>
  313. inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
  314. if (SIZE_N == RN) {
  315. return gemm<RM, RN, BM>(m, n, BN);
  316. }
  317. if constexpr (RN > 1) {
  318. return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
  319. } else {
  320. GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
GGML_ASSERT(false); // we have missed something.
  322. }
  323. }
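// Computes one RM x RN tile of C: walks k in steps of KN, keeping RM*RN vector
// accumulators plus either RM or RN operand vectors live (whichever is
// smaller), then horizontally sums each accumulator into C.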
  324. template <int RM, int RN>
  325. inline void gemm_bloc(int64_t ii, int64_t jj) {
  326. D Cv[RN][RM] = {};
  327. for (int64_t l = 0; l < k; l += KN) {
// help the compiler with operation ordering.
  329. if constexpr (RM <= RN) {
  330. V Av[RM];
  331. for (int64_t i = 0; i < RM; ++i) {
  332. Av[i] = load<V>(A + lda * (ii + i) + l);
  333. }
  334. for (int64_t j = 0; j < RN; ++j) {
  335. V Bv = load<V>(B + ldb * (jj + j) + l);
  336. for (int64_t i = 0; i < RM; ++i) {
  337. Cv[j][i] = madd(Av[i], Bv, Cv[j][i]);
  338. }
  339. }
  340. } else {
  341. V Bv[RN];
  342. for (int64_t j = 0; j < RN; ++j) {
  343. Bv[j] = load<V>(B + ldb * (jj + j) + l);
  344. }
  345. for (int64_t i = 0; i < RM; ++i) {
  346. V Av = load<V>(A + lda * (ii + i) + l);
  347. for (int64_t j = 0; j < RN; ++j) {
  348. Cv[j][i] = madd(Av, Bv[j], Cv[j][i]);
  349. }
  350. }
  351. }
  352. }
  353. for (int64_t j = 0; j < RN; ++j)
  354. for (int64_t i = 0; i < RM; ++i)
  355. C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
  356. }
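// Splits the output into jobs of BM*RM rows by one block of column tiles and
// distributes them through a shared atomic counter: each thread starts at its
// own ith job, then keeps fetching the next unprocessed job until all nb_job
// jobs are done.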
  357. template <int RM, int RN, int BM>
  358. NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
  359. static std::atomic<int64_t> current_chunk;
  360. GGML_ASSERT(m % (RM * BM) == 0);
  361. const int64_t ytiles = m / (RM * BM);
  362. const int64_t xtiles = (n + RN -1) / RN;
  363. const int64_t jj_RN = (xtiles - (xtiles * RN - n));
  364. // "round" bloc_size to "nearest" BN
  365. const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
  366. const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
  367. const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
  368. const int64_t nb_job = ytiles * NB_BN;
  369. if (params->ith == 0) {
  370. GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
// Every thread starts at its own ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
  372. std::atomic_store_explicit(&current_chunk, (int64_t)params->nth, std::memory_order_relaxed);
  373. }
  374. ggml_barrier(params->threadpool);
  375. int64_t job = params->ith;
  376. while (job < nb_job) {
  377. const int64_t ii = (job % ytiles) * RM * BM;
  378. const int64_t jb = job / ytiles;
  379. const int64_t jr0 = BLOC_POS(jb , jj_BN, SIZE_BN);
  380. const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);
  381. const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
  382. const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
  383. const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
  384. for (int64_t bi = 0; bi < BM * RM; bi += RM) {
  385. int64_t jj = jj0;
  386. for (; jj < jj1; jj += RN) {
  387. gemm_bloc<RM, RN>(ii + bi, jj);
  388. }
  389. if constexpr (RN > 1) {
  390. for (; jj < jj2; jj += RN - 1) {
  391. gemm_bloc<RM, RN-1>(ii + bi, jj);
  392. }
  393. }
  394. GGML_ASSERT(jj == jj2);
  395. }
  396. // next step.
  397. job = std::atomic_fetch_add_explicit(&current_chunk, (int64_t)1, std::memory_order_relaxed);
  398. }
  399. ggml_barrier(params->threadpool);
  400. return;
  401. }
  402. const ggml_compute_params * params;
  403. const TA *const A;
  404. const TB *const B;
  405. TC *const C;
  406. const int64_t k;
  407. const int64_t lda;
  408. const int64_t ldb;
  409. const int64_t ldc;
  410. };
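// A minimal usage sketch (not part of the original file; it assumes float
// inputs, an 8-wide AVX build, k a multiple of 8, m a multiple of 4, row-major
// A and B with leading dimension k, and a ggml_compute_params that provides
// ith/nth and a threadpool):
//
//     tinyBLAS<8, __m256, __m256, float, float, float> tb{
//         params, k, A, /*lda=*/k, B, /*ldb=*/k, C, /*ldc=*/m};
//     bool ok = tb.matmul(m, n);  // returns false if the shape is unsupported
//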
  411. //////////////////////////////////////////////////////////////////////////////////////////
  412. // QUANT ZERO MATRIX MULTIPLICATION
  413. #if defined(__ARM_FEATURE_DOTPROD)
  414. template <typename TA>
  415. class tinyBLAS_Q0_ARM {
  416. public:
  417. tinyBLAS_Q0_ARM(int64_t k,
  418. const TA *A, int64_t lda,
  419. const block_q8_0 *B, int64_t ldb,
  420. float *C, int64_t ldc,
  421. int ith, int nth)
  422. : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
  423. }
  424. void matmul(int64_t m, int64_t n) {
  425. mnpack(0, m, 0, n);
  426. }
  427. private:
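// Recursively partitions [m0,m) x [n0,n): picks the largest kernel that fits
// (up to 3x3), runs it over the interior, then recurses on the remaining
// right and bottom edges.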
  428. NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  429. int64_t mc, nc, mp, np;
  430. switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
  431. case 0x33:
  432. mc = 3;
  433. nc = 3;
  434. gemm<3, 3>(m0, m, n0, n);
  435. break;
  436. case 0x32:
  437. mc = 3;
  438. nc = 2;
  439. gemm<3, 2>(m0, m, n0, n);
  440. break;
  441. case 0x23:
  442. mc = 2;
  443. nc = 3;
  444. gemm<2, 3>(m0, m, n0, n);
  445. break;
  446. case 0x22:
  447. mc = 2;
  448. nc = 2;
  449. gemm<2, 2>(m0, m, n0, n);
  450. break;
  451. case 0x31:
  452. mc = 3;
  453. nc = 1;
  454. gemm<3, 1>(m0, m, n0, n);
  455. break;
  456. case 0x13:
  457. mc = 1;
  458. nc = 3;
  459. gemm<1, 3>(m0, m, n0, n);
  460. break;
  461. case 0x21:
  462. mc = 2;
  463. nc = 1;
  464. gemm<2, 1>(m0, m, n0, n);
  465. break;
  466. case 0x12:
  467. mc = 1;
  468. nc = 2;
  469. gemm<1, 2>(m0, m, n0, n);
  470. break;
  471. case 0x11:
  472. mc = 1;
  473. nc = 1;
  474. gemm<1, 1>(m0, m, n0, n);
  475. break;
  476. default:
  477. return;
  478. }
  479. mp = m0 + (m - m0) / mc * mc;
  480. np = n0 + (n - n0) / nc * nc;
  481. mnpack(mp, m, n0, np);
  482. mnpack(m0, m, np, n);
  483. }
  484. template <int RM, int RN>
  485. NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  486. int64_t ytiles = (m - m0) / RM;
  487. int64_t xtiles = (n - n0) / RN;
  488. int64_t tiles = xtiles * ytiles;
  489. int64_t duty = (tiles + nth - 1) / nth;
  490. int64_t start = duty * ith;
  491. int64_t end = start + duty;
  492. if (end > tiles)
  493. end = tiles;
  494. for (int64_t job = start; job < end; ++job) {
  495. int64_t ii = m0 + job / xtiles * RM;
  496. int64_t jj = n0 + job % xtiles * RN;
  497. float32x4_t Cv[RN][RM] = {};
  498. for (int64_t l = 0; l < k; ++l)
  499. for (int64_t j = 0; j < RN; ++j)
  500. for (int64_t i = 0; i < RM; ++i)
  501. Cv[j][i] = vmlaq_n_f32(Cv[j][i],
  502. vcvtq_f32_s32(vdotq_s32(
  503. vdotq_s32(vdupq_n_s32(0),
  504. load_lo(A + lda * (ii + i) + l),
  505. load_lo(B + ldb * (jj + j) + l)),
  506. load_hi(A + lda * (ii + i) + l),
  507. load_hi(B + ldb * (jj + j) + l))),
  508. unhalf(A[lda * (ii + i) + l].d) *
  509. unhalf(B[ldb * (jj + j) + l].d));
  510. for (int64_t j = 0; j < RN; ++j)
  511. for (int64_t i = 0; i < RM; ++i)
  512. C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
  513. }
  514. }
  515. inline int8x16_t load_lo(const block_q8_0 *b) {
  516. return vld1q_s8(b->qs);
  517. }
  518. inline int8x16_t load_hi(const block_q8_0 *b) {
  519. return vld1q_s8(b->qs + 16);
  520. }
  521. inline int8x16_t load_lo(const block_q4_0 *b) {
  522. return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),
  523. vdupq_n_u8(0x0f))),
  524. vdupq_n_s8(0x8));
  525. }
  526. inline int8x16_t load_hi(const block_q4_0 *b) {
  527. return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
  528. vdupq_n_s8(0x8));
  529. }
  530. const TA *const A;
  531. const block_q8_0 *const B;
  532. float *const C;
  533. const int64_t k;
  534. const int64_t lda;
  535. const int64_t ldb;
  536. const int64_t ldc;
  537. const int ith;
  538. const int nth;
  539. };
  540. #endif // __ARM_FEATURE_DOTPROD
  541. #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
  542. template <typename TA, typename TB, typename TC>
  543. class tinyBLAS_Q0_AVX {
  544. public:
  545. tinyBLAS_Q0_AVX(int64_t k,
  546. const TA *A, int64_t lda,
  547. const TB *B, int64_t ldb,
  548. TC *C, int64_t ldc,
  549. int ith, int nth)
  550. : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
  551. }
  552. void matmul(int64_t m, int64_t n) {
  553. mnpack(0, m, 0, n);
  554. }
  555. private:
  556. void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  557. int64_t mc, nc, mp, np;
  558. switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
  559. #if VECTOR_REGISTERS == 32
  560. case 0x44:
  561. mc = 4;
  562. nc = 4;
  563. #if defined(__AVX2__) && defined(__F16C__)
  564. gemm4xN<4>(m0, m, n0, n);
  565. #else
  566. gemm<4, 4>(m0, m, n0, n);
  567. #endif
  568. break;
  569. case 0x43:
  570. mc = 4;
  571. nc = 3;
  572. #if defined(__AVX2__) && defined(__F16C__)
  573. gemm4xN<3>(m0, m, n0, n);
  574. #else
  575. gemm<4, 3>(m0, m, n0, n);
  576. #endif
  577. break;
  578. case 0x34:
  579. mc = 3;
  580. nc = 4;
  581. #if defined(__AVX2__) && defined(__F16C__)
  582. gemmMx4<3>(m0, m, n0, n);
  583. #else
  584. gemm<3, 4>(m0, m, n0, n);
  585. #endif
  586. break;
  587. case 0x33:
  588. mc = 3;
  589. nc = 3;
  590. gemm<3, 3>(m0, m, n0, n);
  591. break;
  592. case 0x42:
  593. mc = 4;
  594. nc = 2;
  595. #if defined(__AVX2__) && defined(__F16C__)
  596. gemm4xN<2>(m0, m, n0, n);
  597. #else
  598. gemm<4, 2>(m0, m, n0, n);
  599. #endif
  600. break;
  601. case 0x24:
  602. mc = 2;
  603. nc = 4;
  604. #if defined(__AVX2__) && defined(__F16C__)
  605. gemmMx4<2>(m0, m, n0, n);
  606. #else
  607. gemm<2, 4>(m0, m, n0, n);
  608. #endif
  609. break;
  610. #else
  611. case 0x44:
  612. case 0x43:
  613. case 0x42:
  614. mc = 4;
  615. nc = 2;
  616. #if defined(__AVX2__) && defined(__F16C__)
  617. gemm4xN<2>(m0, m, n0, n);
  618. #else
  619. gemm<4, 2>(m0, m, n0, n);
  620. #endif
  621. break;
  622. case 0x34:
  623. case 0x24:
  624. mc = 2;
  625. nc = 4;
  626. #if defined(__AVX2__) && defined(__F16C__)
  627. gemmMx4<2>(m0, m, n0, n);
  628. #else
  629. gemm<2, 4>(m0, m, n0, n);
  630. #endif
  631. break;
  632. case 0x33:
  633. #endif
  634. case 0x32:
  635. mc = 3;
  636. nc = 2;
  637. gemm<3, 2>(m0, m, n0, n);
  638. break;
  639. case 0x23:
  640. mc = 2;
  641. nc = 3;
  642. gemm<2, 3>(m0, m, n0, n);
  643. break;
  644. case 0x41:
  645. mc = 4;
  646. nc = 1;
  647. #if defined(__AVX2__) && defined(__F16C__)
  648. gemm4xN<1>(m0, m, n0, n);
  649. #else
  650. gemm<4, 1>(m0, m, n0, n);
  651. #endif
  652. break;
  653. case 0x22:
  654. mc = 2;
  655. nc = 2;
  656. gemm<2, 2>(m0, m, n0, n);
  657. break;
  658. case 0x14:
  659. mc = 1;
  660. nc = 4;
  661. #if defined(__AVX2__) && defined(__F16C__)
  662. gemmMx4<1>(m0, m, n0, n);
  663. #else
  664. gemm<1, 4>(m0, m, n0, n);
  665. #endif
  666. break;
  667. case 0x31:
  668. mc = 3;
  669. nc = 1;
  670. gemm<3, 1>(m0, m, n0, n);
  671. break;
  672. case 0x13:
  673. mc = 1;
  674. nc = 3;
  675. gemm<1, 3>(m0, m, n0, n);
  676. break;
  677. case 0x21:
  678. mc = 2;
  679. nc = 1;
  680. gemm<2, 1>(m0, m, n0, n);
  681. break;
  682. case 0x12:
  683. mc = 1;
  684. nc = 2;
  685. gemm<1, 2>(m0, m, n0, n);
  686. break;
  687. case 0x11:
  688. mc = 1;
  689. nc = 1;
  690. gemm<1, 1>(m0, m, n0, n);
  691. break;
  692. default:
  693. return;
  694. }
  695. mp = m0 + (m - m0) / mc * mc;
  696. np = n0 + (n - n0) / nc * nc;
  697. mnpack(mp, m, n0, np);
  698. mnpack(m0, m, np, n);
  699. }
  700. #if defined(__AVX2__) && defined(__F16C__)
  701. // Templated functions for gemm of dimensions 4xN
  702. template <int RN>
  703. NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  704. int64_t ytiles = (m - m0) / 4;
  705. int64_t xtiles = (n - n0) / RN;
  706. int64_t tiles = xtiles * ytiles;
  707. int64_t duty = (tiles + nth - 1) / nth;
  708. int64_t start = duty * ith;
  709. int64_t end = start + duty;
  710. if (end > tiles)
  711. end = tiles;
  712. for (int64_t job = start; job < end; ++job) {
  713. int64_t ii = m0 + job / xtiles * 4;
  714. int64_t jj = n0 + job % xtiles * RN;
  715. __m256 Cv[RN][4] = {};
  716. for (int64_t l = 0; l < k; ++l) {
  717. uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d);
  718. // Convert delta values for four blocks to float values
  719. __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta));
  720. __m256i avec0 = load(A + lda * (ii + 0) + l);
  721. __m256i avec1 = load(A + lda * (ii + 1) + l);
  722. __m256i avec2 = load(A + lda * (ii + 2) + l);
  723. __m256i avec3 = load(A + lda * (ii + 3) + l);
  724. for (int64_t j = 0; j < RN; ++j) {
  725. __m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d));
// Multiply the four A-block deltas by B's delta and broadcast the products
// across the 256-bit lane.
__m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
dvec = _mm256_permute2f128_ps(dvec, dvec, 0);
// Compute the dot products and scale each by the matching delta product.
  730. Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
  731. updot(_mm256_sign_epi8(avec0, avec0),
  732. _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)),
  733. Cv[j][0]);
  734. Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
  735. updot(_mm256_sign_epi8(avec1, avec1),
  736. _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)),
  737. Cv[j][1]);
  738. Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
  739. updot(_mm256_sign_epi8(avec2, avec2),
  740. _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)),
  741. Cv[j][2]);
  742. Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
  743. updot(_mm256_sign_epi8(avec3, avec3),
  744. _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)),
  745. Cv[j][3]);
  746. }
  747. }
  748. for (int64_t j = 0; j < RN; ++j)
  749. for (int64_t i = 0; i < 4; ++i)
  750. C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
  751. }
  752. }
  753. // Templated functions for gemm of dimensions Mx4
  754. template <int RM>
  755. NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  756. int64_t ytiles = (m - m0) / RM;
  757. int64_t xtiles = (n - n0) / 4;
  758. int64_t tiles = xtiles * ytiles;
  759. int64_t duty = (tiles + nth - 1) / nth;
  760. int64_t start = duty * ith;
  761. int64_t end = start + duty;
  762. if (end > tiles)
  763. end = tiles;
  764. for (int64_t job = start; job < end; ++job) {
  765. int64_t ii = m0 + job / xtiles * RM;
  766. int64_t jj = n0 + job % xtiles * 4;
  767. __m256 Cv[4][RM] = {};
  768. for (int64_t l = 0; l < k; ++l) {
  769. uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d);
  770. // Convert delta values for four blocks to float values
  771. __m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta));
  772. __m256i bvec0 = load(B + ldb * (jj + 0) + l);
  773. __m256i bvec1 = load(B + ldb * (jj + 1) + l);
  774. __m256i bvec2 = load(B + ldb * (jj + 2) + l);
  775. __m256i bvec3 = load(B + ldb * (jj + 3) + l);
  776. for (int64_t i = 0; i < RM; ++i) {
  777. __m128 da = _mm_set1_ps(unhalf((A[lda * (ii + i) + l].d)));
// Multiply A's delta by the four B-block deltas and broadcast the products
// across the 256-bit lane.
__m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
dvec = _mm256_permute2f128_ps(dvec, dvec, 0);
// Compute the dot products and scale each by the matching delta product.
  782. Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
  783. updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
  784. load(A + lda * (ii + i) + l)),
  785. _mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))),
  786. Cv[0][i]);
  787. Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
  788. updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
  789. load(A + lda * (ii + i) + l)),
  790. _mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))),
  791. Cv[1][i]);
  792. Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
  793. updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
  794. load(A + lda * (ii + i) + l)),
  795. _mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))),
  796. Cv[2][i]);
  797. Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
  798. updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
  799. load(A + lda * (ii + i) + l)),
  800. _mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))),
  801. Cv[3][i]);
  802. }
  803. }
  804. for (int64_t j = 0; j < 4; ++j)
  805. for (int64_t i = 0; i < RM; ++i)
  806. C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
  807. }
  808. }
  809. #endif
  810. template <int RM, int RN>
  811. NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  812. int64_t ytiles = (m - m0) / RM;
  813. int64_t xtiles = (n - n0) / RN;
  814. int64_t tiles = xtiles * ytiles;
  815. int64_t duty = (tiles + nth - 1) / nth;
  816. int64_t start = duty * ith;
  817. int64_t end = start + duty;
  818. if (end > tiles)
  819. end = tiles;
  820. for (int64_t job = start; job < end; ++job) {
  821. int64_t ii = m0 + job / xtiles * RM;
  822. int64_t jj = n0 + job % xtiles * RN;
  823. __m256 Cv[RN][RM] = {};
  824. for (int64_t l = 0; l < k; ++l)
  825. for (int64_t j = 0; j < RN; ++j)
  826. for (int64_t i = 0; i < RM; ++i) {
  827. #if defined(__AVX2__)
  828. __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
  829. load(A + lda * (ii + i) + l)),
  830. _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
  831. load(A + lda * (ii + i) + l)));
  832. #else
  833. __m128i ali0 = load0(A + lda * (ii + i) + l);
  834. __m128i ali1 = load1(A + lda * (ii + i) + l);
  835. __m128i blj0 = load0(B + ldb * (jj + j) + l);
  836. __m128i blj1 = load1(B + ldb * (jj + j) + l);
  837. __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
  838. __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
  839. __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
  840. __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
  841. // updot
  842. const __m128i oneFill = _mm_set1_epi16(1);
  843. __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
  844. __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
  845. __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
  846. #endif
  847. Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
  848. unhalf(B[ldb * (jj + j) + l].d)),
  849. udTmp,
  850. Cv[j][i]);
  851. }
  852. for (int64_t j = 0; j < RN; ++j)
  853. for (int64_t i = 0; i < RM; ++i)
  854. C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
  855. }
  856. }
  857. inline __m256i load(const block_q8_0 *b) {
  858. return _mm256_loadu_si256((const __m256i *)b->qs);
  859. }
  860. inline __m128i load0(const block_q8_0 *b) {
  861. return _mm_loadu_si128((const __m128i *)b->qs);
  862. }
  863. inline __m128i load1(const block_q8_0 *b) {
  864. return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
  865. }
  866. inline __m256i load(const block_q4_0 *b) {
  867. return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
  868. }
  869. inline __m128i load0(const block_q4_0 *b) {
  870. const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
  871. return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
  872. }
  873. inline __m128i load1(const block_q4_0 *b) {
  874. const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
  875. return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
  876. }
  877. inline __m256i load(const block_q5_0 *b) {
  878. return _mm256_or_si256(denibble(b->qs), bittobyte(b->qh));
  879. }
  880. inline __m128i load0(const block_q5_0* b) {
  881. const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
  882. uint32_t x32;
  883. memcpy(&x32, b->qh, sizeof(uint32_t));
  884. __m128i qxl = _mm_and_si128(_mm_set1_epi8(15), x);
  885. __m128i bytesl = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
  886. _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
  887. _mm_shuffle_epi8(_mm_set1_epi32(x32),
  888. _mm_set_epi64x(0x0101010101010101, 0x0000000000000000))));
  889. bytesl = _mm_andnot_si128(bytesl, _mm_set1_epi8((char)0xF0));
  890. return _mm_or_si128(qxl, bytesl);
  891. }
  892. inline __m128i load1(const block_q5_0* b) {
  893. const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
  894. uint32_t x32;
  895. memcpy(&x32, b->qh, sizeof(uint32_t));
  896. __m128i qxh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4));
  897. __m128i bytesh = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
  898. _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
  899. _mm_shuffle_epi8(_mm_set1_epi32(x32),
  900. _mm_set_epi64x(0x0303030303030303, 0x0202020202020202))));
  901. bytesh = _mm_andnot_si128(bytesh, _mm_set1_epi8((char)0xF0));
  902. return _mm_or_si128(qxh, bytesh);
  903. }
  904. inline __m256i load(const block_iq4_nl *b) {
  905. return MM256_SET_M128I(load1(b), load0(b));
  906. }
  907. inline __m128i load0(const block_iq4_nl *b) {
  908. const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
  909. return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), x));
  910. }
  911. inline __m128i load1(const block_iq4_nl *b) {
  912. const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
  913. return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)));
  914. }
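// Dot product of 32 unsigned bytes in u with 32 signed bytes in s, summed into
// eight int32 lanes and converted to float; uses VNNI instructions when
// available, otherwise falls back to maddubs/madd.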
  915. inline __m256 updot(__m256i u, __m256i s) {
  916. __m256i res;
  917. #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
  918. res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
  919. #elif defined(__AVXVNNI__)
  920. res = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), u, s);
  921. #else
  922. res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
  923. #endif
  924. return _mm256_cvtepi32_ps(res);
  925. }
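// Unpacks 32 nibbles packed into 16 bytes: low nibbles land in the low 128-bit
// lane, high nibbles in the high lane.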
  926. static inline __m256i denibble(const uint8_t *p) {
  927. __m128i x = _mm_loadu_si128((const __m128i *)p);
  928. return _mm256_and_si256(_mm256_set1_epi8(15),
  929. _mm256_insertf128_si256(_mm256_castsi128_si256(x),
  930. _mm_srli_epi16(x, 4), 1));
  931. }
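// Expands the 32 packed high bits of a q5_0 block into bytes: a clear bit
// becomes 0xF0 (so OR-ing it into the nibble sign-extends by -16), a set bit
// becomes 0.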
  932. static inline __m256i bittobyte(const uint8_t *p) {
  933. uint32_t x32;
  934. memcpy(&x32, p, sizeof(uint32_t));
  935. __m256i bytes = _mm256_cmpeq_epi8(_mm256_set1_epi64x(-1),
  936. _mm256_or_si256(_mm256_set1_epi64x(0x7fbfdfeff7fbfdfe),
  937. _mm256_shuffle_epi8(_mm256_set1_epi32(x32),
  938. _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202,
  939. 0x0101010101010101, 0x0000000000000000))));
  940. return _mm256_andnot_si256(bytes, _mm256_set1_epi8((char)0xF0));
  941. }
  942. const TA *const A;
  943. const TB *const B;
  944. TC *const C;
  945. const int64_t k;
  946. const int64_t lda;
  947. const int64_t ldb;
  948. const int64_t ldc;
  949. const int ith;
  950. const int nth;
  951. };
  952. #endif // __AVX__
// PPC implementation
  954. #if defined(__MMA__)
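// Disassembles an MMA accumulator into four vectors and stores the resulting
// 4x4 tile of C starting at row ii, column jj.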
  955. #define SAVE_ACC(ACC, ii, jj) \
  956. __builtin_mma_disassemble_acc(vec_C, ACC); \
  957. for (int I = 0; I < 4; I++) { \
  958. for (int J = 0; J < 4; J++) { \
  959. *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); \
  960. } \
}
  962. template <typename TA, typename TB, typename TC>
  963. class tinyBLAS_PPC {
  964. public:
  965. tinyBLAS_PPC(int64_t k,
  966. const TA *A, int64_t lda,
  967. const TB *B, int64_t ldb,
  968. TC *C, int64_t ldc,
  969. int ith, int nth)
  970. : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
  971. }
  972. void matmul(int64_t m, int64_t n) {
  973. mnpack(0, m, 0, n);
  974. }
  975. private:
  976. void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
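// Packs a rows x cols block of the row-major matrix a (leading dimension lda)
// into vec, transposing 4x4 sub-tiles into the contiguous layout expected by
// the MMA rank-k updates below.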
  977. void READ_BLOCK(const float* a, int64_t lda, int rows, int cols, float* vec) {
  978. int64_t i, j;
  979. float *aoffset = NULL, *boffset = NULL;
  980. float *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
  981. float *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
  982. aoffset = const_cast<float*>(a);
  983. boffset = vec;
  984. j = (rows >> 3);
  985. if (j > 0) {
  986. do {
  987. aoffset1 = aoffset;
  988. aoffset2 = aoffset1 + lda;
  989. aoffset3 = aoffset2 + lda;
  990. aoffset4 = aoffset3 + lda;
  991. aoffset5 = aoffset4 + lda;
  992. aoffset6 = aoffset5 + lda;
  993. aoffset7 = aoffset6 + lda;
  994. aoffset8 = aoffset7 + lda;
  995. aoffset += 8 * lda;
  996. i = (cols >> 3);
  997. if (i > 0) {
  998. __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
  999. vector float c1[2], c2[2], c3[2], c4[2], c5[2], c6[2], c7[2], c8[2];
  1000. vector float t1, t2, t3, t4, t5, t6, t7, t8;
  1001. do {
  1002. C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
  1003. C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
  1004. C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
  1005. C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
  1006. C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5);
  1007. C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6);
  1008. C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7);
  1009. C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8);
  1010. __builtin_vsx_disassemble_pair(c1, &C1);
  1011. __builtin_vsx_disassemble_pair(c2, &C2);
  1012. __builtin_vsx_disassemble_pair(c3, &C3);
  1013. __builtin_vsx_disassemble_pair(c4, &C4);
  1014. __builtin_vsx_disassemble_pair(c5, &C5);
  1015. __builtin_vsx_disassemble_pair(c6, &C6);
  1016. __builtin_vsx_disassemble_pair(c7, &C7);
  1017. __builtin_vsx_disassemble_pair(c8, &C8);
  1018. t1 = vec_mergeh(c1[0], c2[0]);
  1019. t2 = vec_mergeh(c3[0], c4[0]);
  1020. t3 = vec_mergeh(c5[0], c6[0]);
  1021. t4 = vec_mergeh(c7[0], c8[0]);
  1022. t5 = vec_xxpermdi(t1, t2, 0);
  1023. t6 = vec_xxpermdi(t3, t4, 0);
  1024. t7 = vec_xxpermdi(t1, t2, 3);
  1025. t8 = vec_xxpermdi(t3, t4, 3);
  1026. vec_xst(t5, 0, boffset);
  1027. vec_xst(t6, 0, boffset+4);
  1028. vec_xst(t7, 0, boffset+8);
  1029. vec_xst(t8, 0, boffset+12);
  1030. t1 = vec_mergel(c1[0], c2[0]);
  1031. t2 = vec_mergel(c3[0], c4[0]);
  1032. t3 = vec_mergel(c5[0], c6[0]);
  1033. t4 = vec_mergel(c7[0], c8[0]);
  1034. t5 = vec_xxpermdi(t1, t2, 0);
  1035. t6 = vec_xxpermdi(t3, t4, 0);
  1036. t7 = vec_xxpermdi(t1, t2, 3);
  1037. t8 = vec_xxpermdi(t3, t4, 3);
  1038. vec_xst(t5, 0, boffset+16);
  1039. vec_xst(t6, 0, boffset+20);
  1040. vec_xst(t7, 0, boffset+24);
  1041. vec_xst(t8, 0, boffset+28);
  1042. t1 = vec_mergeh(c1[1], c2[1]);
  1043. t2 = vec_mergeh(c3[1], c4[1]);
  1044. t3 = vec_mergeh(c5[1], c6[1]);
  1045. t4 = vec_mergeh(c7[1], c8[1]);
  1046. t5 = vec_xxpermdi(t1, t2, 0);
  1047. t6 = vec_xxpermdi(t3, t4, 0);
  1048. t7 = vec_xxpermdi(t1, t2, 3);
  1049. t8 = vec_xxpermdi(t3, t4, 3);
  1050. vec_xst(t5, 0, boffset+32);
  1051. vec_xst(t6, 0, boffset+36);
  1052. vec_xst(t7, 0, boffset+40);
  1053. vec_xst(t8, 0, boffset+44);
  1054. t1 = vec_mergel(c1[1], c2[1]);
  1055. t2 = vec_mergel(c3[1], c4[1]);
  1056. t3 = vec_mergel(c5[1], c6[1]);
  1057. t4 = vec_mergel(c7[1], c8[1]);
  1058. t5 = vec_xxpermdi(t1, t2, 0);
  1059. t6 = vec_xxpermdi(t3, t4, 0);
  1060. t7 = vec_xxpermdi(t1, t2, 3);
  1061. t8 = vec_xxpermdi(t3, t4, 3);
  1062. vec_xst(t5, 0, boffset+48);
  1063. vec_xst(t6, 0, boffset+52);
  1064. vec_xst(t7, 0, boffset+56);
  1065. vec_xst(t8, 0, boffset+60);
  1066. aoffset1 += 8*lda;
  1067. aoffset2 += 8*lda;
  1068. aoffset3 += 8*lda;
  1069. aoffset4 += 8*lda;
  1070. boffset += 64;
  1071. i--;
  1072. } while(i > 0);
  1073. }
  1074. if (cols & 4) {
  1075. vector float c1, c2, c3, c4, c5, c6, c7, c8;
  1076. vector float t1, t2, t3, t4, t5, t6, t7, t8;
  1077. c1 = vec_xl(0, aoffset1);
  1078. c2 = vec_xl(0, aoffset2);
  1079. c3 = vec_xl(0, aoffset3);
  1080. c4 = vec_xl(0, aoffset4);
  1081. c5 = vec_xl(0, aoffset5);
  1082. c6 = vec_xl(0, aoffset6);
  1083. c7 = vec_xl(0, aoffset7);
  1084. c8 = vec_xl(0, aoffset8);
  1085. t1 = vec_mergeh(c1, c2);
  1086. t2 = vec_mergeh(c3, c4);
  1087. t3 = vec_mergeh(c5, c6);
  1088. t4 = vec_mergeh(c7, c8);
  1089. t5 = vec_xxpermdi(t1, t2, 0);
  1090. t6 = vec_xxpermdi(t3, t4, 0);
  1091. t7 = vec_xxpermdi(t1, t2, 3);
  1092. t8 = vec_xxpermdi(t3, t4, 3);
  1093. vec_xst(t5, 0, boffset);
  1094. vec_xst(t6, 0, boffset+4);
  1095. vec_xst(t7, 0, boffset+8);
  1096. vec_xst(t8, 0, boffset+12);
  1097. t1 = vec_mergel(c1, c2);
  1098. t2 = vec_mergel(c3, c4);
  1099. t3 = vec_mergel(c5, c6);
  1100. t4 = vec_mergel(c7, c8);
  1101. t5 = vec_xxpermdi(t1, t2, 0);
  1102. t6 = vec_xxpermdi(t3, t4, 0);
  1103. t7 = vec_xxpermdi(t1, t2, 3);
  1104. t8 = vec_xxpermdi(t3, t4, 3);
  1105. vec_xst(t5, 0, boffset+16);
  1106. vec_xst(t6, 0, boffset+20);
  1107. vec_xst(t7, 0, boffset+24);
  1108. vec_xst(t8, 0, boffset+28);
  1109. }
  1110. j--;
  1111. } while(j > 0);
  1112. }
  1113. if (rows & 4) {
  1114. aoffset1 = aoffset;
  1115. aoffset2 = aoffset1 + lda;
  1116. aoffset3 = aoffset2 + lda;
  1117. aoffset4 = aoffset3 + lda;
  1118. aoffset += 4 * lda;
  1119. i = (cols >> 3);
  1120. if (i > 0) {
  1121. __vector_pair C1, C2, C3, C4;
  1122. vector float c1[2], c2[2], c3[2], c4[2];
  1123. vector float t1, t2, t3, t4, t5, t6, t7, t8;
  1124. do {
  1125. C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
  1126. C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
  1127. C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
  1128. C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
  1129. __builtin_vsx_disassemble_pair(c1, &C1);
  1130. __builtin_vsx_disassemble_pair(c2, &C2);
  1131. __builtin_vsx_disassemble_pair(c3, &C3);
  1132. __builtin_vsx_disassemble_pair(c4, &C4);
  1133. t1 = vec_mergeh(c1[0], c2[0]);
  1134. t2 = vec_mergeh(c3[0], c4[0]);
  1135. t3 = vec_mergel(c1[0], c2[0]);
  1136. t4 = vec_mergel(c3[0], c4[0]);
  1137. t5 = vec_xxpermdi(t1, t2, 0);
  1138. t6 = vec_xxpermdi(t1, t2, 3);
  1139. t7 = vec_xxpermdi(t3, t4, 0);
  1140. t8 = vec_xxpermdi(t3, t4, 3);
  1141. vec_xst(t5, 0, boffset);
  1142. vec_xst(t6, 0, boffset+4);
  1143. vec_xst(t7, 0, boffset+8);
  1144. vec_xst(t8, 0, boffset+12);
  1145. t1 = vec_mergeh(c1[1], c2[1]);
  1146. t2 = vec_mergeh(c3[1], c4[1]);
  1147. t3 = vec_mergel(c1[1], c2[1]);
  1148. t4 = vec_mergel(c3[1], c4[1]);
  1149. t5 = vec_xxpermdi(t1, t2, 0);
  1150. t6 = vec_xxpermdi(t1, t2, 3);
  1151. t7 = vec_xxpermdi(t3, t4, 0);
  1152. t8 = vec_xxpermdi(t3, t4, 3);
  1153. vec_xst(t5, 0, boffset+16);
  1154. vec_xst(t6, 0, boffset+20);
  1155. vec_xst(t7, 0, boffset+24);
  1156. vec_xst(t8, 0, boffset+28);
  1157. aoffset1 += 8*lda;
  1158. aoffset2 += 8*lda;
  1159. aoffset3 += 8*lda;
  1160. aoffset4 += 8*lda;
  1161. boffset += 32;
  1162. i--;
  1163. } while(i > 0);
  1164. }
  1165. if (cols & 4) {
  1166. vector float c1, c2, c3, c4;
  1167. vector float t1, t2, t3, t4;
  1168. c1 = vec_xl(0, aoffset1);
  1169. c2 = vec_xl(0, aoffset2);
  1170. c3 = vec_xl(0, aoffset3);
  1171. c4 = vec_xl(0, aoffset4);
  1172. t1 = vec_mergeh(c1, c2);
  1173. t2 = vec_mergeh(c3, c4);
  1174. t3 = vec_xxpermdi(t1, t2, 0);
  1175. t4 = vec_xxpermdi(t1, t2, 3);
  1176. vec_xst(t3, 0, boffset);
  1177. vec_xst(t4, 0, boffset+4);
  1178. t1 = vec_mergel(c1, c2);
  1179. t2 = vec_mergel(c3, c4);
  1180. t3 = vec_xxpermdi(t1, t2, 0);
  1181. t4 = vec_xxpermdi(t1, t2, 3);
  1182. vec_xst(t3, 0, boffset+8);
  1183. vec_xst(t4, 0, boffset+12);
  1184. }
  1185. }
  1186. if (rows & 3) {
  1187. aoffset1 = aoffset;
  1188. aoffset2 = aoffset1 + lda;
  1189. aoffset3 = aoffset2 + lda;
  1190. if (cols & 4) {
  1191. vector float c1, c2, c3, c4 = {0};
  1192. vector float t1, t2, t3, t4;
  1193. c1 = vec_xl(0, aoffset1);
  1194. c2 = vec_xl(0, aoffset2);
  1195. c3 = vec_xl(0, aoffset3);
  1196. t1 = vec_mergeh(c1, c2);
  1197. t2 = vec_mergeh(c3, c4);
  1198. t3 = vec_xxpermdi(t1, t2, 0);
  1199. t4 = vec_xxpermdi(t1, t2, 3);
  1200. vec_xst(t3, 0, boffset);
  1201. vec_xst(t4, 0, boffset+4);
  1202. t1 = vec_mergel(c1, c2);
  1203. t2 = vec_mergel(c3, c4);
  1204. t3 = vec_xxpermdi(t1, t2, 0);
  1205. t4 = vec_xxpermdi(t1, t2, 3);
  1206. vec_xst(t3, 0, boffset+8);
  1207. vec_xst(t4, 0, boffset+12);
  1208. }
  1209. }
  1210. }
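// The KERNEL_MxN routines below accumulate an MxN tile of C using MMA rank-k
// updates over packed blocks of A and B.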
  1211. void KERNEL_4x4(int64_t ii, int64_t jj) {
  1212. vec_t vec_A[4], vec_B[4], vec_C[4];
  1213. acc_t acc_0;
  1214. __builtin_mma_xxsetaccz(&acc_0);
  1215. for (int l = 0; l < k; l+=4) {
  1216. READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
  1217. READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
  1218. __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
  1219. __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
  1220. __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
  1221. __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
  1222. }
  1223. SAVE_ACC(&acc_0, ii, jj);
  1224. }
  1225. void KERNEL_4x8(int64_t ii, int64_t jj) {
  1226. vec_t vec_A[4], vec_B[8], vec_C[4];
  1227. acc_t acc_0, acc_1;
  1228. __builtin_mma_xxsetaccz(&acc_0);
  1229. __builtin_mma_xxsetaccz(&acc_1);
  1230. for (int64_t l = 0; l < k; l+=4) {
  1231. READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
  1232. READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
  1233. __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
  1234. __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
  1235. __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
  1236. __builtin_mma_xvf32gerpp(&acc_1, vec_A[1], (vec_t)vec_B[3]);
  1237. __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], (vec_t)vec_B[4]);
  1238. __builtin_mma_xvf32gerpp(&acc_1, vec_A[2], (vec_t)vec_B[5]);
  1239. __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
  1240. __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
  1241. }
  1242. SAVE_ACC(&acc_0, ii, jj);
  1243. SAVE_ACC(&acc_1, ii, jj+4);
  1244. }
  1245. void KERNEL_8x4(int64_t ii, int64_t jj) {
  1246. vec_t vec_A[8], vec_B[4], vec_C[4];
  1247. acc_t acc_0, acc_1;
  1248. __builtin_mma_xxsetaccz(&acc_0);
  1249. __builtin_mma_xxsetaccz(&acc_1);
  1250. for (int64_t l = 0; l < k; l+=4) {
  1251. READ_BLOCK(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
  1252. READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
  1253. __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
  1254. __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
  1255. __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
  1256. __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[3], vec_B[1]);
  1257. __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[4], vec_B[2]);
  1258. __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[5], vec_B[2]);
  1259. __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
  1260. __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
  1261. }
  1262. SAVE_ACC(&acc_0, ii, jj);
  1263. SAVE_ACC(&acc_1, ii+4, jj);
  1264. }
  1265. void KERNEL_8x8(int64_t ii, int64_t jj) {
  1266. vec_t vec_A[16], vec_B[16], vec_C[4];
  1267. acc_t acc_0, acc_1, acc_2, acc_3;
  1268. __builtin_mma_xxsetaccz(&acc_0);
  1269. __builtin_mma_xxsetaccz(&acc_1);
  1270. __builtin_mma_xxsetaccz(&acc_2);
  1271. __builtin_mma_xxsetaccz(&acc_3);
  1272. for (int l = 0; l < k; l+=8) {
  1273. READ_BLOCK(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
  1274. READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
  1275. for(int x = 0; x < 16; x+=2) {
  1276. __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
  1277. __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
  1278. __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]);
  1279. __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]);
  1280. }
  1281. }
  1282. SAVE_ACC(&acc_0, ii, jj);
  1283. SAVE_ACC(&acc_1, ii, jj+4);
  1284. SAVE_ACC(&acc_2, ii+4, jj);
  1285. SAVE_ACC(&acc_3, ii+4, jj+4);
  1286. }
    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t mc, nc, mp, np;
        int m_rem = MIN(m - m0, 16);
        int n_rem = MIN(n - n0, 16);
        if (m_rem >= 16 && n_rem >= 8) {
            mc = 8;
            nc = 8;
            gemm<8,8>(m0, m, n0, n);
        } else if (m_rem >= 8 && n_rem >= 16) {
            mc = 8;
            nc = 8;
            gemm<8,8>(m0, m, n0, n);
        } else if (m_rem >= 8 && n_rem >= 8) {
            mc = 8;
            nc = 8;
            gemm<8,8>(m0, m, n0, n);
        } else if (m_rem >= 4 && n_rem >= 8) {
            mc = 4;
            nc = 8;
            gemm<4,8>(m0, m, n0, n);
        } else if (m_rem >= 8 && n_rem >= 4) {
            mc = 8;
            nc = 4;
            gemm<8,4>(m0, m, n0, n);
        } else if (m_rem >= 4 && n_rem >= 4) {
            mc = 4;
            nc = 4;
            gemm<4,4>(m0, m, n0, n);
        } else if ((m_rem < 4) && (n_rem > 4)) {
            nc = 4;
            switch (m_rem) {
            case 1:
                mc = 1;
                gemm_small(m0, m, n0, n, mc, nc);
                break;
            case 2:
                mc = 2;
                gemm_small(m0, m, n0, n, mc, nc);
                break;
            case 3:
                mc = 3;
                gemm_small(m0, m, n0, n, mc, nc);
                break;
            default:
                return;
            }
        } else if ((m_rem > 4) && (n_rem < 4)) {
            mc = 4;
            switch (n_rem) {
            case 1:
                nc = 1;
                gemm_small(m0, m, n0, n, mc, nc);
                break;
            case 2:
                nc = 2;
                gemm_small(m0, m, n0, n, mc, nc);
                break;
            case 3:
                nc = 3;
                gemm_small(m0, m, n0, n, mc, nc);
                break;
            default:
                return;
            }
        } else {
            switch ((m_rem << 4) | n_rem) {
            case 0x43:
                mc = 4;
                nc = 3;
                gemm_small(m0, m, n0, n, mc, nc);
                break;
            case 0x42:
                mc = 4;
                nc = 2;
                gemm_small(m0, m, n0, n, mc, nc);
                break;
            case 0x41:
                mc = 4;
                nc = 1;
                gemm_small(m0, m, n0, n, mc, nc);
                break;
            case 0x34:
                mc = 3;
                nc = 4;
                gemm_small(m0, m, n0, n, mc, nc);
                break;
            case 0x33:
                mc = 3;
                nc = 3;
                gemm_small(m0, m, n0, n, mc, nc);
                break;
            case 0x32:
                mc = 3;
                nc = 2;
                gemm_small(m0, m, n0, n, mc, nc);
                break;
            case 0x31:
                mc = 3;
                nc = 1;
                gemm_small(m0, m, n0, n, mc, nc);
                break;
            case 0x24:
                mc = 2;
                nc = 4;
                gemm_small(m0, m, n0, n, mc, nc);
                break;
            case 0x23:
                mc = 2;
                nc = 3;
                gemm_small(m0, m, n0, n, mc, nc);
                break;
            case 0x22:
                mc = 2;
                nc = 2;
                gemm_small(m0, m, n0, n, mc, nc);
                break;
            case 0x21:
                mc = 2;
                nc = 1;
                gemm_small(m0, m, n0, n, mc, nc);
                break;
            case 0x14:
                mc = 1;
                nc = 4;
                gemm_small(m0, m, n0, n, mc, nc);
                break;
            case 0x13:
                mc = 1;
                nc = 3;
                gemm_small(m0, m, n0, n, mc, nc);
                break;
            case 0x12:
                mc = 1;
                nc = 2;
                gemm_small(m0, m, n0, n, mc, nc);
                break;
            case 0x11:
                mc = 1;
                nc = 1;
                gemm_small(m0, m, n0, n, mc, nc);
                break;
            default:
                return;
            }
        }
        mp = m0 + (m - m0) / mc * mc;
        np = n0 + (n - n0) / nc * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }
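    // Tail kernel for leftover tiles where at least one dimension is smaller
    // than 4: a single MMA accumulator computes a 4x4 product and only the
    // RM x RN corner of the disassembled result is stored back to C.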
    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            vec_t vec_C[4];
            acc_t acc_0;
            __builtin_mma_xxsetaccz(&acc_0);
            vec_t vec_A[4], vec_B[4];
            for (int l = 0; l < k; l+=4) {
                if (RN >= 4 && RM == 1) {
                    float* a = const_cast<float*>(A+(ii)*lda+l);
                    READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
                    vec_A[0] = (vec_t)vec_xl(0,a);
                    vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
                    vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
                    vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
                } else {
                    READ_BLOCK(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
                    READ_BLOCK(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
                }
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
            }
            __builtin_mma_disassemble_acc(vec_C, &acc_0);
            for (int I = 0; I < RM; I++) {
                for (int J = 0; J < RN; J++) {
                    *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
                }
            }
        }
    }
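    // Generic tile driver: splits the (m - m0) x (n - n0) region into RM x RN
    // tiles, divides the tiles evenly across nth threads, and invokes the
    // member kernel matching the requested tile shape on each assigned tile.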
    template <int RM, int RN>
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (RM == 4 && RN == 4) {
            kernel = &tinyBLAS_PPC::KERNEL_4x4;
        } else if (RM == 4 && RN == 8) {
            kernel = &tinyBLAS_PPC::KERNEL_4x8;
        } else if (RM == 8 && RN == 4) {
            kernel = &tinyBLAS_PPC::KERNEL_8x4;
        } else if (RM == 8 && RN == 8) {
            kernel = &tinyBLAS_PPC::KERNEL_8x8;
        }
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            (this->*kernel)(ii, jj);
        }
    }
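    // Operand pointers, strides, reduction depth, and thread assignment
    // supplied when this tinyBLAS_PPC instance is constructed.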
    const TA *const A;
    const TB *const B;
    TC *C;
    TA *At;
    TB *Bt;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
    const int ith;
    const int nth;
};
#endif
} // namespace
/**
 * Performs optimized matrix multiplication on CPU.
 *
 * This subroutine may compute C = Aᵀ * B with column major ordering.
 * Despite its name, this isn't a generalized implementation. Work is
 * only performed when a handwritten kernel is available for the given
 * type combination and instruction set. Otherwise the caller should
 * fall back to a general matmul routine.
 *
 * For example, for single-precision GEMM you can say
 *
 *     llamafile_sgemm(params, m, n, k,
 *                     A, lda, B, ldb, C, ldc,
 *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
 *
 * @param params is the ggml compute parameters, which provide the thread
 *               id `ith` (must be less than `nth`) and the thread count
 *               `nth` (must be greater than zero)
 * @param m is rows in `A` and `C`
 * @param n is cols in `B` and `C`
 * @param k is cols in `A` and rows in `B`
 * @param A is first input matrix (always transposed)
 * @param lda is row stride of `A`
 * @param B is second input matrix (never transposed)
 * @param ldb is row stride of `B`
 * @param C is input/output array of output matrices
 * @param ldc is row stride of `C`
 * @param Atype is GGML data type of `A`
 * @param Btype is GGML data type of `B`
 * @param Ctype is GGML data type of `C`
 * @return true if this function was able to service the matmul request
 */
bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
                     const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
                     int64_t ldc, int Atype, int Btype, int Ctype) {

    assert(m >= 0);
    assert(n >= 0);
    assert(k >= 0);
    assert(lda >= k);
    assert(ldb >= k);
    assert(ldc >= m);
    assert(params->nth > 0);
    assert(params->ith < params->nth);

    // only enable sgemm for prompt processing
    if (n < 2)
        return false;

    if (Ctype != GGML_TYPE_F32)
        return false;
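    // Dispatch on the type of A; each case instantiates the widest tinyBLAS
    // specialization the current build's instruction set supports, or returns
    // false so the caller can fall back to the generic matmul path.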
    switch (Atype) {

    case GGML_TYPE_F32: {
        if (Btype != GGML_TYPE_F32)
            return false;
#if defined(__AVX512F__)
        tinyBLAS<16, __m512, __m512, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc};
        return tb.matmul(m, n);
#elif defined(__AVX__) || defined(__AVX2__)
        tinyBLAS<8, __m256, __m256, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc};
        return tb.matmul(m, n);
#elif defined(__ARM_NEON)
        if (n < 4)
            return false;
        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc};
        return tb.matmul(m, n);
#elif defined(__MMA__)
        if (k % 8)
            return false;
        tinyBLAS_PPC<float, float, float> tb{
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#else
        return false;
#endif
    }
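    // bf16 weights are handled only when B is also bf16; the vector width is
    // chosen by the best available x86 extension (AVX512-BF16, AVX-512, AVX2).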
    case GGML_TYPE_BF16: {
#if defined(__AVX512BF16__)
        if (Btype == GGML_TYPE_BF16) {
            tinyBLAS<32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
                (const ggml_bf16_t *)A, lda,
                (const ggml_bf16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif defined(__AVX512F__)
        if (Btype == GGML_TYPE_BF16) {
            tinyBLAS<16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
                (const ggml_bf16_t *)A, lda,
                (const ggml_bf16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif defined(__AVX2__)
        if (Btype == GGML_TYPE_BF16) {
            tinyBLAS<8, __m256, __m256, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
                (const ggml_bf16_t *)A, lda,
                (const ggml_bf16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#endif
        return false;
    }
    case GGML_TYPE_F16: {
#if defined(__AVX512F__)
        if (Btype == GGML_TYPE_F16) {
            tinyBLAS<16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
                (const ggml_fp16_t *)A, lda,
                (const ggml_fp16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
        if (Btype == GGML_TYPE_F16) {
            tinyBLAS<8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
                (const ggml_fp16_t *)A, lda,
                (const ggml_fp16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
        if (n < 8)
            return false;
        if (Btype == GGML_TYPE_F16) {
            tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
                k, (const ggml_fp16_t *)A, lda,
                (const ggml_fp16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
        if (Btype == GGML_TYPE_F32) {
            tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ params,
                k, (const ggml_fp16_t *)A, lda,
                (const float *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#endif
        return false;
    }
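    // Quantized weight paths: A may be Q8_0, Q4_0, Q5_0, or IQ4_NL, but B
    // must always be quantized as Q8_0.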
    case GGML_TYPE_Q8_0: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
        tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
            k, (const block_q8_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#elif defined(__ARM_FEATURE_DOTPROD)
        tinyBLAS_Q0_ARM<block_q8_0> tb{
            k, (const block_q8_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_Q4_0: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
        tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
            k, (const block_q4_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#elif defined(__ARM_FEATURE_DOTPROD)
        tinyBLAS_Q0_ARM<block_q4_0> tb{
            k, (const block_q4_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#else
        return false;
#endif
    }
    case GGML_TYPE_Q5_0: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
        tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float> tb{
            k, (const block_q5_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_IQ4_NL: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
        tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float> tb{
            k, (const block_iq4_nl *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#else
        return false;
#endif
    }
    default:
        return false;
    }

    (void)params;
    (void)m;
    (void)n;
    (void)k;
    (void)A;
    (void)lda;
    (void)B;
    (void)ldb;
    (void)C;
    (void)ldc;
    (void)Atype;
    (void)Btype;
    (void)Ctype;
}