sgemm.cpp

  1. // Copyright 2024 Mozilla Foundation
  2. //
  3. // Permission is hereby granted, free of charge, to any person obtaining
  4. // a copy of this software and associated documentation files (the
  5. // "Software"), to deal in the Software without restriction, including
  6. // without limitation the rights to use, copy, modify, merge, publish,
  7. // distribute, sublicense, and/or sell copies of the Software, and to
  8. // permit persons to whom the Software is furnished to do so, subject to
  9. // the following conditions:
  10. //
  11. // The above copyright notice and this permission notice shall be
  12. // included in all copies or substantial portions of the Software.
  13. //
  14. // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  15. // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  16. // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
  17. // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
  18. // BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
  19. // ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
  20. // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  21. // SOFTWARE.
  22. //
  23. //                   _   _          ___ _      _   ___
  24. //                  | |_(_)_ _ _  _| _ ) |    /_\ / __|
  25. //                  |  _| | ' \ || | _ \ |__ / _ \\__ \.
  26. //                   \__|_|_||_\_, |___/____/_/ \_\___/
  27. //                             |__/
  28. //
  29. // BASIC LINEAR ALGEBRA SUBPROGRAMS
  30. //
  31. //
  32. // This file implements multithreaded CPU matrix multiplication for the
  33. // common contiguous use case C = Aᵀ * B. These kernels are designed to
  34. // have excellent performance[1] for matrices that fit in the CPU cache
  35. // without imposing any overhead such as cache filling or malloc calls.
  36. //
  37. // This implementation does not guarantee any upper bound on rounding
  38. // error, which grows with k. Our goal is to maximally exploit the
  39. // hardware for performance, and then use whatever resources remain to
  40. // improve numerical accuracy.
  41. //
  42. // [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
  43. // Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
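//
// For orientation: up to rounding and the order of accumulation, the
// tinyBLAS kernels below compute the same result as the naive reference
// loop sketched here (the quantized variants do the same arithmetic on
// dequantized blocks). The name `sgemm_ref` is illustrative only and does
// not exist in this file:
//
//     static void sgemm_ref(int64_t m, int64_t n, int64_t k,
//                           const float *A, int64_t lda,
//                           const float *B, int64_t ldb,
//                           float *C, int64_t ldc) {
//         for (int64_t j = 0; j < n; ++j)          // columns of C (and B)
//             for (int64_t i = 0; i < m; ++i) {    // rows of C (and of Aᵀ)
//                 float sum = 0;
//                 for (int64_t l = 0; l < k; ++l)
//                     sum += A[lda * i + l] * B[ldb * j + l];
//                 C[ldc * j + i] = sum;            // C is column major
//             }
//     }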
  44. #if defined(__GNUC__)
  45. #pragma GCC diagnostic ignored "-Wpedantic"
  46. #pragma GCC diagnostic ignored "-Wignored-attributes"
  47. #endif
  48. #include "sgemm.h"
  49. #include "ggml-impl.h"
  50. #include "ggml-cpu-impl.h"
  51. #include "ggml-quants.h"
  52. #ifdef _MSC_VER
  53. #define NOINLINE __declspec(noinline)
  54. #else
  55. #define NOINLINE __attribute__((__noinline__))
  56. #endif
  57. #if defined(__ARM_NEON) || defined(__AVX512F__)
  58. #define VECTOR_REGISTERS 32
  59. #else
  60. #define VECTOR_REGISTERS 16
  61. #endif
  62. #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
  63. namespace {
  64. inline float unhalf(ggml_fp16_t d) {
  65. return GGML_FP16_TO_FP32(d);
  66. }
  67. ////////////////////////////////////////////////////////////////////////////////////////////////////
  68. // VECTORIZED ARITHMETIC OPERATIONS
  69. #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  70. inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
  71. inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
  72. inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
  73. #endif // __SSE__
  74. #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  75. inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
  76. inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
  77. inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
  78. #endif // __AVX__
  79. #if defined(__AVX512F__)
  80. inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
  81. inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
  82. inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
  83. #endif // __AVX512F__
  84. #if defined(__ARM_NEON)
  85. inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
  86. inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
  87. inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
  88. #endif // __ARM_NEON
  89. #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
  90. inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
  91. inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
  92. inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
  93. #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  94. #if defined(__MMA__)
  95. typedef vector unsigned char vec_t;
  96. typedef __vector_quad acc_t;
  97. #endif
  98. ////////////////////////////////////////////////////////////////////////////////////////////////////
  99. // VECTORIZED FUSED MULTIPLY ADD
  100. /**
  101. * Computes a * b + c.
  102. */
  103. template <typename T, typename U>
  104. inline U madd(T a, T b, U c) {
  105. return add(mul(a, b), c);
  106. }
  107. #if defined(__FMA__)
  108. #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  109. template <>
  110. inline __m256 madd(__m256 a, __m256 b, __m256 c) {
  111. return _mm256_fmadd_ps(a, b, c);
  112. }
  113. #endif
  114. #if defined(__AVX512F__)
  115. template <>
  116. inline __m512 madd(__m512 a, __m512 b, __m512 c) {
  117. return _mm512_fmadd_ps(a, b, c);
  118. }
  119. #endif
  120. #endif
  121. #if defined(__ARM_FEATURE_FMA)
  122. template <>
  123. inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
  124. return vfmaq_f32(c, b, a);
  125. }
  126. #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
  127. template <>
  128. inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
  129. return vfmaq_f16(c, b, a);
  130. }
  131. #endif
  132. #endif
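// When no fused multiply-add instruction is available, the generic madd()
// above falls back to mul() followed by add(), which rounds twice per
// element; the specializations preserve the single rounding of the hardware
// FMA.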
  133. ////////////////////////////////////////////////////////////////////////////////////////////////////
  134. // VECTORIZED HORIZONTAL SUM
  135. #if defined(__ARM_NEON)
  136. inline float hsum(float32x4_t x) {
  137. return vaddvq_f32(x);
  138. }
  139. #endif // __ARM_NEON
  140. #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
  141. inline float hsum(float16x8_t x) {
  142. return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),
  143. vcvt_f32_f16(vget_high_f16(x))));
  144. }
  145. #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  146. #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  147. inline float hsum(__m128 x) {
  148. #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  149. x = _mm_add_ps(x, _mm_movehl_ps(x, x));
  150. x = _mm_add_ss(x, _mm_movehdup_ps(x));
  151. #else
  152. __m128 t;
  153. t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
  154. x = _mm_add_ps(x, t);
  155. t = _mm_movehl_ps(t, x);
  156. x = _mm_add_ss(x, t);
  157. #endif
  158. return _mm_cvtss_f32(x);
  159. }
  160. #endif
  161. #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  162. inline float hsum(__m256 x) {
  163. return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),
  164. _mm256_castps256_ps128(x)));
  165. }
  166. #endif // __AVX__
  167. #if defined(__AVX512F__)
  168. inline float hsum(__m512 x) {
  169. return _mm512_reduce_add_ps(x);
  170. }
  171. #endif // __AVX512F__
  172. ////////////////////////////////////////////////////////////////////////////////////////////////////
  173. // VECTORIZED MEMORY LOADING
  174. template <typename T, typename U> T load(const U *);
  175. #if defined(__ARM_NEON)
  176. template <> inline float32x4_t load(const float *p) {
  177. return vld1q_f32(p);
  178. }
  179. #if !defined(_MSC_VER)
  180. template <> inline float16x8_t load(const ggml_fp16_t *p) {
  181. return vld1q_f16((const float16_t *)p);
  182. }
  183. template <> inline float32x4_t load(const ggml_fp16_t *p) {
  184. return vcvt_f32_f16(vld1_f16((const float16_t *)p));
  185. }
  186. #endif // _MSC_VER
  187. #endif // __ARM_NEON
  188. #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  189. template <> inline __m128 load(const float *p) {
  190. return _mm_loadu_ps(p);
  191. }
  192. #endif // __SSE__
  193. #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  194. template <> inline __m256 load(const float *p) {
  195. return _mm256_loadu_ps(p);
  196. }
  197. #endif // __AVX__
  198. #if defined(__F16C__)
  199. template <> inline __m256 load(const ggml_fp16_t *p) {
  200. return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
  201. }
  202. #endif // __F16C__
  203. #if defined(__AVX512F__)
  204. template <> inline __m512 load(const float *p) {
  205. return _mm512_loadu_ps(p);
  206. }
  207. template <> inline __m512 load(const ggml_fp16_t *p) {
  208. return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
  209. }
  210. #endif // __AVX512F__
  211. ////////////////////////////////////////////////////////////////////////////////////////////////////
  212. // CONSTANTS
  213. #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  214. static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
  215. static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
  216. #endif
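// kvalues_iq4nl is the 16-entry IQ4_NL codebook; the load0()/load1()
// overloads for block_iq4_nl below use it as a byte-shuffle lookup table,
// mapping each 4-bit index directly to its int8 codebook value.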
  217. ////////////////////////////////////////////////////////////////////////////////////////////////////
  218. // FLOATING POINT MATRIX MULTIPLICATION
  219. template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
  220. class tinyBLAS {
  221. public:
  222. tinyBLAS(int64_t k,
  223. const TA *A, int64_t lda,
  224. const TB *B, int64_t ldb,
  225. TC *C, int64_t ldc,
  226. int ith, int nth)
  227. : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
  228. }
  229. void matmul(int64_t m, int64_t n) {
  230. mnpack(0, m, 0, n);
  231. }
  232. private:
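// mnpack() packs the number of rows and columns still to be processed into a
// single byte (rows in the high nibble, columns in the low nibble) and
// switches on it to pick the largest register-blocked kernel that fits; for
// example five rows and five columns remaining gives (5 << 4) | 5 == 0x55,
// which dispatches to gemm<5, 5>. The two recursive mnpack() calls at the
// bottom then mop up the bottom rows and right-hand columns that the chosen
// tile size did not cover.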
  233. NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  234. int64_t mc, nc, mp, np;
  235. switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
  236. #if VECTOR_REGISTERS == 32
  237. case 0x55:
  238. mc = 5;
  239. nc = 5;
  240. gemm<5, 5>(m0, m, n0, n);
  241. break;
  242. case 0x45:
  243. mc = 4;
  244. nc = 5;
  245. gemm<4, 5>(m0, m, n0, n);
  246. break;
  247. case 0x54:
  248. mc = 5;
  249. nc = 4;
  250. gemm<5, 4>(m0, m, n0, n);
  251. break;
  252. case 0x44:
  253. mc = 4;
  254. nc = 4;
  255. gemm<4, 4>(m0, m, n0, n);
  256. break;
  257. case 0x53:
  258. mc = 5;
  259. nc = 3;
  260. gemm<5, 3>(m0, m, n0, n);
  261. break;
  262. case 0x35:
  263. mc = 3;
  264. nc = 5;
  265. gemm<3, 5>(m0, m, n0, n);
  266. break;
  267. case 0x43:
  268. mc = 4;
  269. nc = 3;
  270. gemm<4, 3>(m0, m, n0, n);
  271. break;
  272. #else
  273. case 0x55:
  274. case 0x54:
  275. case 0x53:
  276. case 0x45:
  277. case 0x44:
  278. case 0x43:
  279. mc = 4;
  280. nc = 3;
  281. gemm<4, 3>(m0, m, n0, n);
  282. break;
  283. case 0x35:
  284. #endif
  285. case 0x34:
  286. mc = 3;
  287. nc = 4;
  288. gemm<3, 4>(m0, m, n0, n);
  289. break;
  290. case 0x52:
  291. mc = 5;
  292. nc = 2;
  293. gemm<5, 2>(m0, m, n0, n);
  294. break;
  295. case 0x33:
  296. mc = 3;
  297. nc = 3;
  298. gemm<3, 3>(m0, m, n0, n);
  299. break;
  300. case 0x25:
  301. mc = 2;
  302. nc = 5;
  303. gemm<2, 5>(m0, m, n0, n);
  304. break;
  305. case 0x42:
  306. mc = 4;
  307. nc = 2;
  308. gemm<4, 2>(m0, m, n0, n);
  309. break;
  310. case 0x24:
  311. mc = 2;
  312. nc = 4;
  313. gemm<2, 4>(m0, m, n0, n);
  314. break;
  315. case 0x32:
  316. mc = 3;
  317. nc = 2;
  318. gemm<3, 2>(m0, m, n0, n);
  319. break;
  320. case 0x23:
  321. mc = 2;
  322. nc = 3;
  323. gemm<2, 3>(m0, m, n0, n);
  324. break;
  325. case 0x51:
  326. mc = 5;
  327. nc = 1;
  328. gemm<5, 1>(m0, m, n0, n);
  329. break;
  330. case 0x41:
  331. mc = 4;
  332. nc = 1;
  333. gemm<4, 1>(m0, m, n0, n);
  334. break;
  335. case 0x22:
  336. mc = 2;
  337. nc = 2;
  338. gemm<2, 2>(m0, m, n0, n);
  339. break;
  340. case 0x15:
  341. mc = 1;
  342. nc = 5;
  343. gemm<1, 5>(m0, m, n0, n);
  344. break;
  345. case 0x14:
  346. mc = 1;
  347. nc = 4;
  348. gemm<1, 4>(m0, m, n0, n);
  349. break;
  350. case 0x31:
  351. mc = 3;
  352. nc = 1;
  353. gemm<3, 1>(m0, m, n0, n);
  354. break;
  355. case 0x13:
  356. mc = 1;
  357. nc = 3;
  358. gemm<1, 3>(m0, m, n0, n);
  359. break;
  360. case 0x21:
  361. mc = 2;
  362. nc = 1;
  363. gemm<2, 1>(m0, m, n0, n);
  364. break;
  365. case 0x12:
  366. mc = 1;
  367. nc = 2;
  368. gemm<1, 2>(m0, m, n0, n);
  369. break;
  370. case 0x11:
  371. mc = 1;
  372. nc = 1;
  373. gemm<1, 1>(m0, m, n0, n);
  374. break;
  375. default:
  376. return;
  377. }
  378. mp = m0 + (m - m0) / mc * mc;
  379. np = n0 + (n - n0) / nc * nc;
  380. mnpack(mp, m, n0, np);
  381. mnpack(m0, m, np, n);
  382. }
  383. template <int RM, int RN>
  384. NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  385. int64_t ytiles = (m - m0) / RM;
  386. int64_t xtiles = (n - n0) / RN;
  387. int64_t tiles = xtiles * ytiles;
  388. int64_t duty = (tiles + nth - 1) / nth;
  389. int64_t start = duty * ith;
  390. int64_t end = start + duty;
  391. if (end > tiles)
  392. end = tiles;
  393. for (int64_t job = start; job < end; ++job) {
  394. int64_t ii = m0 + job / xtiles * RM;
  395. int64_t jj = n0 + job % xtiles * RN;
  396. D Cv[RN][RM] = {};
  397. for (int64_t l = 0; l < k; l += KN)
  398. for (int64_t j = 0; j < RN; ++j)
  399. for (int64_t i = 0; i < RM; ++i)
  400. Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
  401. load<V>(B + ldb * (jj + j) + l),
  402. Cv[j][i]);
  403. for (int64_t j = 0; j < RN; ++j)
  404. for (int64_t i = 0; i < RM; ++i)
  405. C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
  406. }
  407. }
  408. const TA *const A;
  409. const TB *const B;
  410. TC *const C;
  411. const int64_t k;
  412. const int64_t lda;
  413. const int64_t ldb;
  414. const int64_t ldc;
  415. const int ith;
  416. const int nth;
  417. };
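// Every gemm() variant in this file splits work across threads the same way:
// the output is cut into RM x RN tiles, the tiles are numbered row-major, and
// each of the nth threads claims a contiguous range of ceil(tiles / nth) tile
// ids starting at ith * ceil(tiles / nth). A standalone sketch of that
// schedule (the helper name `tile_range` is hypothetical, not part of this
// file):
//
//     static void tile_range(int64_t tiles, int ith, int nth,
//                            int64_t *start, int64_t *end) {
//         int64_t duty = (tiles + nth - 1) / nth;  // tiles per thread, rounded up
//         *start = duty * ith;
//         *end = *start + duty;
//         if (*end > tiles)
//             *end = tiles;                        // trailing threads may get fewer
//     }
//
// For example, tiles = 10 and nth = 4 gives duty = 3 and the per-thread
// ranges [0, 3), [3, 6), [6, 9) and [9, 10).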
  418. //////////////////////////////////////////////////////////////////////////////////////////
  419. // QUANT ZERO MATRIX MULTIPLICATION
  420. #if defined(__ARM_FEATURE_DOTPROD)
  421. template <typename TA>
  422. class tinyBLAS_Q0_ARM {
  423. public:
  424. tinyBLAS_Q0_ARM(int64_t k,
  425. const TA *A, int64_t lda,
  426. const block_q8_0 *B, int64_t ldb,
  427. float *C, int64_t ldc,
  428. int ith, int nth)
  429. : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
  430. }
  431. void matmul(int64_t m, int64_t n) {
  432. mnpack(0, m, 0, n);
  433. }
  434. private:
  435. NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  436. int64_t mc, nc, mp, np;
  437. switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
  438. case 0x33:
  439. mc = 3;
  440. nc = 3;
  441. gemm<3, 3>(m0, m, n0, n);
  442. break;
  443. case 0x32:
  444. mc = 3;
  445. nc = 2;
  446. gemm<3, 2>(m0, m, n0, n);
  447. break;
  448. case 0x23:
  449. mc = 2;
  450. nc = 3;
  451. gemm<2, 3>(m0, m, n0, n);
  452. break;
  453. case 0x22:
  454. mc = 2;
  455. nc = 2;
  456. gemm<2, 2>(m0, m, n0, n);
  457. break;
  458. case 0x31:
  459. mc = 3;
  460. nc = 1;
  461. gemm<3, 1>(m0, m, n0, n);
  462. break;
  463. case 0x13:
  464. mc = 1;
  465. nc = 3;
  466. gemm<1, 3>(m0, m, n0, n);
  467. break;
  468. case 0x21:
  469. mc = 2;
  470. nc = 1;
  471. gemm<2, 1>(m0, m, n0, n);
  472. break;
  473. case 0x12:
  474. mc = 1;
  475. nc = 2;
  476. gemm<1, 2>(m0, m, n0, n);
  477. break;
  478. case 0x11:
  479. mc = 1;
  480. nc = 1;
  481. gemm<1, 1>(m0, m, n0, n);
  482. break;
  483. default:
  484. return;
  485. }
  486. mp = m0 + (m - m0) / mc * mc;
  487. np = n0 + (n - n0) / nc * nc;
  488. mnpack(mp, m, n0, np);
  489. mnpack(m0, m, np, n);
  490. }
  491. template <int RM, int RN>
  492. NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  493. int64_t ytiles = (m - m0) / RM;
  494. int64_t xtiles = (n - n0) / RN;
  495. int64_t tiles = xtiles * ytiles;
  496. int64_t duty = (tiles + nth - 1) / nth;
  497. int64_t start = duty * ith;
  498. int64_t end = start + duty;
  499. if (end > tiles)
  500. end = tiles;
  501. for (int64_t job = start; job < end; ++job) {
  502. int64_t ii = m0 + job / xtiles * RM;
  503. int64_t jj = n0 + job % xtiles * RN;
  504. float32x4_t Cv[RN][RM] = {};
  505. for (int64_t l = 0; l < k; ++l)
  506. for (int64_t j = 0; j < RN; ++j)
  507. for (int64_t i = 0; i < RM; ++i)
  508. Cv[j][i] = vmlaq_n_f32(Cv[j][i],
  509. vcvtq_f32_s32(vdotq_s32(
  510. vdotq_s32(vdupq_n_s32(0),
  511. load_lo(A + lda * (ii + i) + l),
  512. load_lo(B + ldb * (jj + j) + l)),
  513. load_hi(A + lda * (ii + i) + l),
  514. load_hi(B + ldb * (jj + j) + l))),
  515. unhalf(A[lda * (ii + i) + l].d) *
  516. unhalf(B[ldb * (jj + j) + l].d));
  517. for (int64_t j = 0; j < RN; ++j)
  518. for (int64_t i = 0; i < RM; ++i)
  519. C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
  520. }
  521. }
  522. inline int8x16_t load_lo(const block_q8_0 *b) {
  523. return vld1q_s8(b->qs);
  524. }
  525. inline int8x16_t load_hi(const block_q8_0 *b) {
  526. return vld1q_s8(b->qs + 16);
  527. }
  528. inline int8x16_t load_lo(const block_q4_0 *b) {
  529. return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),
  530. vdupq_n_u8(0x0f))),
  531. vdupq_n_s8(0x8));
  532. }
  533. inline int8x16_t load_hi(const block_q4_0 *b) {
  534. return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
  535. vdupq_n_s8(0x8));
  536. }
  537. const TA *const A;
  538. const block_q8_0 *const B;
  539. float *const C;
  540. const int64_t k;
  541. const int64_t lda;
  542. const int64_t ldb;
  543. const int64_t ldc;
  544. const int ith;
  545. const int nth;
  546. };
  547. #endif // __ARM_FEATURE_DOTPROD
  548. #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
  549. template <typename TA, typename TB, typename TC>
  550. class tinyBLAS_Q0_AVX {
  551. public:
  552. tinyBLAS_Q0_AVX(int64_t k,
  553. const TA *A, int64_t lda,
  554. const TB *B, int64_t ldb,
  555. TC *C, int64_t ldc,
  556. int ith, int nth)
  557. : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
  558. }
  559. void matmul(int64_t m, int64_t n) {
  560. mnpack(0, m, 0, n);
  561. }
  562. private:
  563. void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  564. int64_t mc, nc, mp, np;
  565. switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
  566. #if VECTOR_REGISTERS == 32
  567. case 0x44:
  568. mc = 4;
  569. nc = 4;
  570. #if defined(__AVX2__) && defined(__F16C__)
  571. gemm4xN<4>(m0, m, n0, n);
  572. #else
  573. gemm<4, 4>(m0, m, n0, n);
  574. #endif
  575. break;
  576. case 0x43:
  577. mc = 4;
  578. nc = 3;
  579. #if defined(__AVX2__) && defined(__F16C__)
  580. gemm4xN<3>(m0, m, n0, n);
  581. #else
  582. gemm<4, 3>(m0, m, n0, n);
  583. #endif
  584. break;
  585. case 0x34:
  586. mc = 3;
  587. nc = 4;
  588. #if defined(__AVX2__) && defined(__F16C__)
  589. gemmMx4<3>(m0, m, n0, n);
  590. #else
  591. gemm<3, 4>(m0, m, n0, n);
  592. #endif
  593. break;
  594. case 0x33:
  595. mc = 3;
  596. nc = 3;
  597. gemm<3, 3>(m0, m, n0, n);
  598. break;
  599. case 0x42:
  600. mc = 4;
  601. nc = 2;
  602. #if defined(__AVX2__) && defined(__F16C__)
  603. gemm4xN<2>(m0, m, n0, n);
  604. #else
  605. gemm<4, 2>(m0, m, n0, n);
  606. #endif
  607. break;
  608. case 0x24:
  609. mc = 2;
  610. nc = 4;
  611. #if defined(__AVX2__) && defined(__F16C__)
  612. gemmMx4<2>(m0, m, n0, n);
  613. #else
  614. gemm<2, 4>(m0, m, n0, n);
  615. #endif
  616. break;
  617. #else
  618. case 0x44:
  619. case 0x43:
  620. case 0x42:
  621. mc = 4;
  622. nc = 2;
  623. #if defined(__AVX2__) && defined(__F16C__)
  624. gemm4xN<2>(m0, m, n0, n);
  625. #else
  626. gemm<4, 2>(m0, m, n0, n);
  627. #endif
  628. break;
  629. case 0x34:
  630. case 0x24:
  631. mc = 2;
  632. nc = 4;
  633. #if defined(__AVX2__) && defined(__F16C__)
  634. gemmMx4<2>(m0, m, n0, n);
  635. #else
  636. gemm<2, 4>(m0, m, n0, n);
  637. #endif
  638. break;
  639. case 0x33:
  640. #endif
  641. case 0x32:
  642. mc = 3;
  643. nc = 2;
  644. gemm<3, 2>(m0, m, n0, n);
  645. break;
  646. case 0x23:
  647. mc = 2;
  648. nc = 3;
  649. gemm<2, 3>(m0, m, n0, n);
  650. break;
  651. case 0x41:
  652. mc = 4;
  653. nc = 1;
  654. #if defined(__AVX2__) && defined(__F16C__)
  655. gemm4xN<1>(m0, m, n0, n);
  656. #else
  657. gemm<4, 1>(m0, m, n0, n);
  658. #endif
  659. break;
  660. case 0x22:
  661. mc = 2;
  662. nc = 2;
  663. gemm<2, 2>(m0, m, n0, n);
  664. break;
  665. case 0x14:
  666. mc = 1;
  667. nc = 4;
  668. #if defined(__AVX2__) && defined(__F16C__)
  669. gemmMx4<1>(m0, m, n0, n);
  670. #else
  671. gemm<1, 4>(m0, m, n0, n);
  672. #endif
  673. break;
  674. case 0x31:
  675. mc = 3;
  676. nc = 1;
  677. gemm<3, 1>(m0, m, n0, n);
  678. break;
  679. case 0x13:
  680. mc = 1;
  681. nc = 3;
  682. gemm<1, 3>(m0, m, n0, n);
  683. break;
  684. case 0x21:
  685. mc = 2;
  686. nc = 1;
  687. gemm<2, 1>(m0, m, n0, n);
  688. break;
  689. case 0x12:
  690. mc = 1;
  691. nc = 2;
  692. gemm<1, 2>(m0, m, n0, n);
  693. break;
  694. case 0x11:
  695. mc = 1;
  696. nc = 1;
  697. gemm<1, 1>(m0, m, n0, n);
  698. break;
  699. default:
  700. return;
  701. }
  702. mp = m0 + (m - m0) / mc * mc;
  703. np = n0 + (n - n0) / nc * nc;
  704. mnpack(mp, m, n0, np);
  705. mnpack(m0, m, np, n);
  706. }
  707. #if defined(__AVX2__) && defined(__F16C__)
  708. // Templated functions for gemm of dimensions 4xN
  709. template <int RN>
  710. NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  711. int64_t ytiles = (m - m0) / 4;
  712. int64_t xtiles = (n - n0) / RN;
  713. int64_t tiles = xtiles * ytiles;
  714. int64_t duty = (tiles + nth - 1) / nth;
  715. int64_t start = duty * ith;
  716. int64_t end = start + duty;
  717. if (end > tiles)
  718. end = tiles;
  719. for (int64_t job = start; job < end; ++job) {
  720. int64_t ii = m0 + job / xtiles * 4;
  721. int64_t jj = n0 + job % xtiles * RN;
  722. __m256 Cv[RN][4] = {};
  723. for (int64_t l = 0; l < k; ++l) {
  724. uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d);
  725. // Convert delta values for four blocks to float values
  726. __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta));
  727. __m256i avec0 = load(A + lda * (ii + 0) + l);
  728. __m256i avec1 = load(A + lda * (ii + 1) + l);
  729. __m256i avec2 = load(A + lda * (ii + 2) + l);
  730. __m256i avec3 = load(A + lda * (ii + 3) + l);
  731. for (int64_t j = 0; j < RN; ++j) {
  732. __m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d));
  733. // Compute the product of the delta values for the four blocks and replicate it across the 256-bit lane
  734. __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
  735. dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
  736. // Compute the dot products and accumulate them scaled by the matching delta products
  737. Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
  738. updot(_mm256_sign_epi8(avec0, avec0),
  739. _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)),
  740. Cv[j][0]);
  741. Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
  742. updot(_mm256_sign_epi8(avec1, avec1),
  743. _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)),
  744. Cv[j][1]);
  745. Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
  746. updot(_mm256_sign_epi8(avec2, avec2),
  747. _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)),
  748. Cv[j][2]);
  749. Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
  750. updot(_mm256_sign_epi8(avec3, avec3),
  751. _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)),
  752. Cv[j][3]);
  753. }
  754. }
  755. for (int64_t j = 0; j < RN; ++j)
  756. for (int64_t i = 0; i < 4; ++i)
  757. C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
  758. }
  759. }
  760. // Templated functions for gemm of dimensions Mx4
  761. template <int RM>
  762. NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  763. int64_t ytiles = (m - m0) / RM;
  764. int64_t xtiles = (n - n0) / 4;
  765. int64_t tiles = xtiles * ytiles;
  766. int64_t duty = (tiles + nth - 1) / nth;
  767. int64_t start = duty * ith;
  768. int64_t end = start + duty;
  769. if (end > tiles)
  770. end = tiles;
  771. for (int64_t job = start; job < end; ++job) {
  772. int64_t ii = m0 + job / xtiles * RM;
  773. int64_t jj = n0 + job % xtiles * 4;
  774. __m256 Cv[4][RM] = {};
  775. for (int64_t l = 0; l < k; ++l) {
  776. uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d);
  777. // Convert delta values for four blocks to float values
  778. __m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta));
  779. __m256i bvec0 = load(B + ldb * (jj + 0) + l);
  780. __m256i bvec1 = load(B + ldb * (jj + 1) + l);
  781. __m256i bvec2 = load(B + ldb * (jj + 2) + l);
  782. __m256i bvec3 = load(B + ldb * (jj + 3) + l);
  783. for (int64_t i = 0; i < RM; ++i) {
  784. __m128 da = _mm_set1_ps(unhalf((A[lda * (ii + i) + l].d)));
  785. // Compute the product of the delta values for the four blocks and replicate it across the 256-bit lane
  786. __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
  787. dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
  788. // Compute the dot products and accumulate them scaled by the matching delta products
  789. Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
  790. updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
  791. load(A + lda * (ii + i) + l)),
  792. _mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))),
  793. Cv[0][i]);
  794. Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
  795. updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
  796. load(A + lda * (ii + i) + l)),
  797. _mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))),
  798. Cv[1][i]);
  799. Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
  800. updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
  801. load(A + lda * (ii + i) + l)),
  802. _mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))),
  803. Cv[2][i]);
  804. Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
  805. updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
  806. load(A + lda * (ii + i) + l)),
  807. _mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))),
  808. Cv[3][i]);
  809. }
  810. }
  811. for (int64_t j = 0; j < 4; ++j)
  812. for (int64_t i = 0; i < RM; ++i)
  813. C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
  814. }
  815. }
  816. #endif
  817. template <int RM, int RN>
  818. NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  819. int64_t ytiles = (m - m0) / RM;
  820. int64_t xtiles = (n - n0) / RN;
  821. int64_t tiles = xtiles * ytiles;
  822. int64_t duty = (tiles + nth - 1) / nth;
  823. int64_t start = duty * ith;
  824. int64_t end = start + duty;
  825. if (end > tiles)
  826. end = tiles;
  827. for (int64_t job = start; job < end; ++job) {
  828. int64_t ii = m0 + job / xtiles * RM;
  829. int64_t jj = n0 + job % xtiles * RN;
  830. __m256 Cv[RN][RM] = {};
  831. for (int64_t l = 0; l < k; ++l)
  832. for (int64_t j = 0; j < RN; ++j)
  833. for (int64_t i = 0; i < RM; ++i) {
  834. #if defined(__AVX2__)
  835. __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
  836. load(A + lda * (ii + i) + l)),
  837. _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
  838. load(A + lda * (ii + i) + l)));
  839. #else
  840. __m128i ali0 = load0(A + lda * (ii + i) + l);
  841. __m128i ali1 = load1(A + lda * (ii + i) + l);
  842. __m128i blj0 = load0(B + ldb * (jj + j) + l);
  843. __m128i blj1 = load1(B + ldb * (jj + j) + l);
  844. __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
  845. __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
  846. __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
  847. __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
  848. // updot
  849. const __m128i oneFill = _mm_set1_epi16(1);
  850. __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
  851. __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
  852. __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
  853. #endif
  854. Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
  855. unhalf(B[ldb * (jj + j) + l].d)),
  856. udTmp,
  857. Cv[j][i]);
  858. }
  859. for (int64_t j = 0; j < RN; ++j)
  860. for (int64_t i = 0; i < RM; ++i)
  861. C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
  862. }
  863. }
  864. inline __m256i load(const block_q8_0 *b) {
  865. return _mm256_loadu_si256((const __m256i *)b->qs);
  866. }
  867. inline __m128i load0(const block_q8_0 *b) {
  868. return _mm_loadu_si128((const __m128i *)b->qs);
  869. }
  870. inline __m128i load1(const block_q8_0 *b) {
  871. return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
  872. }
  873. inline __m256i load(const block_q4_0 *b) {
  874. return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
  875. }
  876. inline __m128i load0(const block_q4_0 *b) {
  877. const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
  878. return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
  879. }
  880. inline __m128i load1(const block_q4_0 *b) {
  881. const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
  882. return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
  883. }
  884. inline __m256i load(const block_q5_0 *b) {
  885. return _mm256_or_si256(denibble(b->qs), bittobyte(b->qh));
  886. }
  887. inline __m128i load0(const block_q5_0* b) {
  888. const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
  889. uint32_t x32;
  890. memcpy(&x32, b->qh, sizeof(uint32_t));
  891. __m128i qxl = _mm_and_si128(_mm_set1_epi8(15), x);
  892. __m128i bytesl = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
  893. _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
  894. _mm_shuffle_epi8(_mm_set1_epi32(x32),
  895. _mm_set_epi64x(0x0101010101010101, 0x0000000000000000))));
  896. bytesl = _mm_andnot_si128(bytesl, _mm_set1_epi8((char)0xF0));
  897. return _mm_or_si128(qxl, bytesl);
  898. }
  899. inline __m128i load1(const block_q5_0* b) {
  900. const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
  901. uint32_t x32;
  902. memcpy(&x32, b->qh, sizeof(uint32_t));
  903. __m128i qxh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4));
  904. __m128i bytesh = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
  905. _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
  906. _mm_shuffle_epi8(_mm_set1_epi32(x32),
  907. _mm_set_epi64x(0x0303030303030303, 0x0202020202020202))));
  908. bytesh = _mm_andnot_si128(bytesh, _mm_set1_epi8((char)0xF0));
  909. return _mm_or_si128(qxh, bytesh);
  910. }
  911. inline __m256i load(const block_iq4_nl *b) {
  912. return MM256_SET_M128I(load1(b), load0(b));
  913. }
  914. inline __m128i load0(const block_iq4_nl *b) {
  915. const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
  916. return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), x));
  917. }
  918. inline __m128i load1(const block_iq4_nl *b) {
  919. const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
  920. return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)));
  921. }
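// updot() leans on the usual sign trick for the unsigned x signed dot-product
// instructions: callers pass u = |a| (computed with _mm256_sign_epi8(a, a))
// and s = b with a's signs copied onto it, so maddubs/dpbusd still yield the
// signed dot product a·b for each group of int8 lanes.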
  922. inline __m256 updot(__m256i u, __m256i s) {
  923. __m256i res;
  924. #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
  925. res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
  926. #else
  927. res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
  928. #endif
  929. return _mm256_cvtepi32_ps(res);
  930. }
  931. static inline __m256i denibble(const uint8_t *p) {
  932. __m128i x = _mm_loadu_si128((const __m128i *)p);
  933. return _mm256_and_si256(_mm256_set1_epi8(15),
  934. _mm256_insertf128_si256(_mm256_castsi128_si256(x),
  935. _mm_srli_epi16(x, 4), 1));
  936. }
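// bittobyte() spreads the 32 bits of qh across 32 byte lanes: the broadcast
// and shuffle hand every lane its source byte, the OR mask leaves exactly one
// bit per lane unset, and the compare against -1 detects whether that bit was
// set. Lanes whose high bit is clear come back as 0xF0, so OR-ing the result
// with the low nibble sign-extends the 5-bit quant to (q - 16) as an int8.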
  937. static inline __m256i bittobyte(const uint8_t *p) {
  938. uint32_t x32;
  939. memcpy(&x32, p, sizeof(uint32_t));
  940. __m256i bytes = _mm256_cmpeq_epi8(_mm256_set1_epi64x(-1),
  941. _mm256_or_si256(_mm256_set1_epi64x(0x7fbfdfeff7fbfdfe),
  942. _mm256_shuffle_epi8(_mm256_set1_epi32(x32),
  943. _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202,
  944. 0x0101010101010101, 0x0000000000000000))));
  945. return _mm256_andnot_si256(bytes, _mm256_set1_epi8((char)0xF0));
  946. }
  947. const TA *const A;
  948. const TB *const B;
  949. TC *const C;
  950. const int64_t k;
  951. const int64_t lda;
  952. const int64_t ldb;
  953. const int64_t ldc;
  954. const int ith;
  955. const int nth;
  956. };
  957. #endif // __AVX__
  958. // PPC implementation
  959. #if defined(__MMA__)
  960. #define SAVE_ACC(ACC, ii, jj) \
  961. __builtin_mma_disassemble_acc(vec_C, ACC); \
  962. for (int I = 0; I < 4; I++) { \
  963. for (int J = 0; J < 4; J++) { \
  964. *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); \
  965. } \
  966. }
  967. template <typename TA, typename TB, typename TC>
  968. class tinyBLAS_PPC {
  969. public:
  970. tinyBLAS_PPC(int64_t k,
  971. const TA *A, int64_t lda,
  972. const TB *B, int64_t ldb,
  973. TC *C, int64_t ldc,
  974. int ith, int nth)
  975. : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
  976. }
  977. void matmul(int64_t m, int64_t n) {
  978. mnpack(0, m, 0, n);
  979. }
  980. private:
  981. void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
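// READ_BLOCK packs a rows x cols tile of the row-major source matrix into the
// contiguous scratch buffer `vec`, transposing it panel by panel with
// vec_mergeh/vec_mergel/vec_xxpermdi so that each packed vector holds one
// column of the tile, in the layout the MMA outer-product kernels expect.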
  982. void READ_BLOCK(const float* a, int64_t lda, int rows, int cols, float* vec) {
  983. int64_t i, j;
  984. float *aoffset = NULL, *boffset = NULL;
  985. float *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
  986. float *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
  987. aoffset = const_cast<float*>(a);
  988. boffset = vec;
  989. j = (rows >> 3);
  990. if (j > 0) {
  991. do {
  992. aoffset1 = aoffset;
  993. aoffset2 = aoffset1 + lda;
  994. aoffset3 = aoffset2 + lda;
  995. aoffset4 = aoffset3 + lda;
  996. aoffset5 = aoffset4 + lda;
  997. aoffset6 = aoffset5 + lda;
  998. aoffset7 = aoffset6 + lda;
  999. aoffset8 = aoffset7 + lda;
  1000. aoffset += 8 * lda;
  1001. i = (cols >> 3);
  1002. if (i > 0) {
  1003. __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
  1004. vector float c1[2], c2[2], c3[2], c4[2], c5[2], c6[2], c7[2], c8[2];
  1005. vector float t1, t2, t3, t4, t5, t6, t7, t8;
  1006. do {
  1007. C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
  1008. C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
  1009. C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
  1010. C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
  1011. C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5);
  1012. C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6);
  1013. C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7);
  1014. C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8);
  1015. __builtin_vsx_disassemble_pair(c1, &C1);
  1016. __builtin_vsx_disassemble_pair(c2, &C2);
  1017. __builtin_vsx_disassemble_pair(c3, &C3);
  1018. __builtin_vsx_disassemble_pair(c4, &C4);
  1019. __builtin_vsx_disassemble_pair(c5, &C5);
  1020. __builtin_vsx_disassemble_pair(c6, &C6);
  1021. __builtin_vsx_disassemble_pair(c7, &C7);
  1022. __builtin_vsx_disassemble_pair(c8, &C8);
  1023. t1 = vec_mergeh(c1[0], c2[0]);
  1024. t2 = vec_mergeh(c3[0], c4[0]);
  1025. t3 = vec_mergeh(c5[0], c6[0]);
  1026. t4 = vec_mergeh(c7[0], c8[0]);
  1027. t5 = vec_xxpermdi(t1, t2, 0);
  1028. t6 = vec_xxpermdi(t3, t4, 0);
  1029. t7 = vec_xxpermdi(t1, t2, 3);
  1030. t8 = vec_xxpermdi(t3, t4, 3);
  1031. vec_xst(t5, 0, boffset);
  1032. vec_xst(t6, 0, boffset+4);
  1033. vec_xst(t7, 0, boffset+8);
  1034. vec_xst(t8, 0, boffset+12);
  1035. t1 = vec_mergel(c1[0], c2[0]);
  1036. t2 = vec_mergel(c3[0], c4[0]);
  1037. t3 = vec_mergel(c5[0], c6[0]);
  1038. t4 = vec_mergel(c7[0], c8[0]);
  1039. t5 = vec_xxpermdi(t1, t2, 0);
  1040. t6 = vec_xxpermdi(t3, t4, 0);
  1041. t7 = vec_xxpermdi(t1, t2, 3);
  1042. t8 = vec_xxpermdi(t3, t4, 3);
  1043. vec_xst(t5, 0, boffset+16);
  1044. vec_xst(t6, 0, boffset+20);
  1045. vec_xst(t7, 0, boffset+24);
  1046. vec_xst(t8, 0, boffset+28);
  1047. t1 = vec_mergeh(c1[1], c2[1]);
  1048. t2 = vec_mergeh(c3[1], c4[1]);
  1049. t3 = vec_mergeh(c5[1], c6[1]);
  1050. t4 = vec_mergeh(c7[1], c8[1]);
  1051. t5 = vec_xxpermdi(t1, t2, 0);
  1052. t6 = vec_xxpermdi(t3, t4, 0);
  1053. t7 = vec_xxpermdi(t1, t2, 3);
  1054. t8 = vec_xxpermdi(t3, t4, 3);
  1055. vec_xst(t5, 0, boffset+32);
  1056. vec_xst(t6, 0, boffset+36);
  1057. vec_xst(t7, 0, boffset+40);
  1058. vec_xst(t8, 0, boffset+44);
  1059. t1 = vec_mergel(c1[1], c2[1]);
  1060. t2 = vec_mergel(c3[1], c4[1]);
  1061. t3 = vec_mergel(c5[1], c6[1]);
  1062. t4 = vec_mergel(c7[1], c8[1]);
  1063. t5 = vec_xxpermdi(t1, t2, 0);
  1064. t6 = vec_xxpermdi(t3, t4, 0);
  1065. t7 = vec_xxpermdi(t1, t2, 3);
  1066. t8 = vec_xxpermdi(t3, t4, 3);
  1067. vec_xst(t5, 0, boffset+48);
  1068. vec_xst(t6, 0, boffset+52);
  1069. vec_xst(t7, 0, boffset+56);
  1070. vec_xst(t8, 0, boffset+60);
  1071. aoffset1 += 8*lda;
  1072. aoffset2 += 8*lda;
  1073. aoffset3 += 8*lda;
  1074. aoffset4 += 8*lda;
  1075. boffset += 64;
  1076. i--;
  1077. } while(i > 0);
  1078. }
  1079. if (cols & 4) {
  1080. vector float c1, c2, c3, c4, c5, c6, c7, c8;
  1081. vector float t1, t2, t3, t4, t5, t6, t7, t8;
  1082. c1 = vec_xl(0, aoffset1);
  1083. c2 = vec_xl(0, aoffset2);
  1084. c3 = vec_xl(0, aoffset3);
  1085. c4 = vec_xl(0, aoffset4);
  1086. c5 = vec_xl(0, aoffset5);
  1087. c6 = vec_xl(0, aoffset6);
  1088. c7 = vec_xl(0, aoffset7);
  1089. c8 = vec_xl(0, aoffset8);
  1090. t1 = vec_mergeh(c1, c2);
  1091. t2 = vec_mergeh(c3, c4);
  1092. t3 = vec_mergeh(c5, c6);
  1093. t4 = vec_mergeh(c7, c8);
  1094. t5 = vec_xxpermdi(t1, t2, 0);
  1095. t6 = vec_xxpermdi(t3, t4, 0);
  1096. t7 = vec_xxpermdi(t1, t2, 3);
  1097. t8 = vec_xxpermdi(t3, t4, 3);
  1098. vec_xst(t5, 0, boffset);
  1099. vec_xst(t6, 0, boffset+4);
  1100. vec_xst(t7, 0, boffset+8);
  1101. vec_xst(t8, 0, boffset+12);
  1102. t1 = vec_mergel(c1, c2);
  1103. t2 = vec_mergel(c3, c4);
  1104. t3 = vec_mergel(c5, c6);
  1105. t4 = vec_mergel(c7, c8);
  1106. t5 = vec_xxpermdi(t1, t2, 0);
  1107. t6 = vec_xxpermdi(t3, t4, 0);
  1108. t7 = vec_xxpermdi(t1, t2, 3);
  1109. t8 = vec_xxpermdi(t3, t4, 3);
  1110. vec_xst(t5, 0, boffset+16);
  1111. vec_xst(t6, 0, boffset+20);
  1112. vec_xst(t7, 0, boffset+24);
  1113. vec_xst(t8, 0, boffset+28);
  1114. }
  1115. j--;
  1116. } while(j > 0);
  1117. }
  1118. if (rows & 4) {
  1119. aoffset1 = aoffset;
  1120. aoffset2 = aoffset1 + lda;
  1121. aoffset3 = aoffset2 + lda;
  1122. aoffset4 = aoffset3 + lda;
  1123. aoffset += 4 * lda;
  1124. i = (cols >> 3);
  1125. if (i > 0) {
  1126. __vector_pair C1, C2, C3, C4;
  1127. vector float c1[2], c2[2], c3[2], c4[2];
  1128. vector float t1, t2, t3, t4, t5, t6, t7, t8;
  1129. do {
  1130. C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
  1131. C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
  1132. C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
  1133. C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
  1134. __builtin_vsx_disassemble_pair(c1, &C1);
  1135. __builtin_vsx_disassemble_pair(c2, &C2);
  1136. __builtin_vsx_disassemble_pair(c3, &C3);
  1137. __builtin_vsx_disassemble_pair(c4, &C4);
  1138. t1 = vec_mergeh(c1[0], c2[0]);
  1139. t2 = vec_mergeh(c3[0], c4[0]);
  1140. t3 = vec_mergel(c1[0], c2[0]);
  1141. t4 = vec_mergel(c3[0], c4[0]);
  1142. t5 = vec_xxpermdi(t1, t2, 0);
  1143. t6 = vec_xxpermdi(t1, t2, 3);
  1144. t7 = vec_xxpermdi(t3, t4, 0);
  1145. t8 = vec_xxpermdi(t3, t4, 3);
  1146. vec_xst(t5, 0, boffset);
  1147. vec_xst(t6, 0, boffset+4);
  1148. vec_xst(t7, 0, boffset+8);
  1149. vec_xst(t8, 0, boffset+12);
  1150. t1 = vec_mergeh(c1[1], c2[1]);
  1151. t2 = vec_mergeh(c3[1], c4[1]);
  1152. t3 = vec_mergel(c1[1], c2[1]);
  1153. t4 = vec_mergel(c3[1], c4[1]);
  1154. t5 = vec_xxpermdi(t1, t2, 0);
  1155. t6 = vec_xxpermdi(t1, t2, 3);
  1156. t7 = vec_xxpermdi(t3, t4, 0);
  1157. t8 = vec_xxpermdi(t3, t4, 3);
  1158. vec_xst(t5, 0, boffset+16);
  1159. vec_xst(t6, 0, boffset+20);
  1160. vec_xst(t7, 0, boffset+24);
  1161. vec_xst(t8, 0, boffset+28);
  1162. aoffset1 += 8*lda;
  1163. aoffset2 += 8*lda;
  1164. aoffset3 += 8*lda;
  1165. aoffset4 += 8*lda;
  1166. boffset += 32;
  1167. i--;
  1168. } while(i > 0);
  1169. }
  1170. if (cols & 4) {
  1171. vector float c1, c2, c3, c4;
  1172. vector float t1, t2, t3, t4;
  1173. c1 = vec_xl(0, aoffset1);
  1174. c2 = vec_xl(0, aoffset2);
  1175. c3 = vec_xl(0, aoffset3);
  1176. c4 = vec_xl(0, aoffset4);
  1177. t1 = vec_mergeh(c1, c2);
  1178. t2 = vec_mergeh(c3, c4);
  1179. t3 = vec_xxpermdi(t1, t2, 0);
  1180. t4 = vec_xxpermdi(t1, t2, 3);
  1181. vec_xst(t3, 0, boffset);
  1182. vec_xst(t4, 0, boffset+4);
  1183. t1 = vec_mergel(c1, c2);
  1184. t2 = vec_mergel(c3, c4);
  1185. t3 = vec_xxpermdi(t1, t2, 0);
  1186. t4 = vec_xxpermdi(t1, t2, 3);
  1187. vec_xst(t3, 0, boffset+8);
  1188. vec_xst(t4, 0, boffset+12);
  1189. }
  1190. }
  1191. if (rows & 3) {
  1192. aoffset1 = aoffset;
  1193. aoffset2 = aoffset1 + lda;
  1194. aoffset3 = aoffset2 + lda;
  1195. if (cols & 4) {
  1196. vector float c1, c2, c3, c4 = {0};
  1197. vector float t1, t2, t3, t4;
  1198. c1 = vec_xl(0, aoffset1);
  1199. c2 = vec_xl(0, aoffset2);
  1200. c3 = vec_xl(0, aoffset3);
  1201. t1 = vec_mergeh(c1, c2);
  1202. t2 = vec_mergeh(c3, c4);
  1203. t3 = vec_xxpermdi(t1, t2, 0);
  1204. t4 = vec_xxpermdi(t1, t2, 3);
  1205. vec_xst(t3, 0, boffset);
  1206. vec_xst(t4, 0, boffset+4);
  1207. t1 = vec_mergel(c1, c2);
  1208. t2 = vec_mergel(c3, c4);
  1209. t3 = vec_xxpermdi(t1, t2, 0);
  1210. t4 = vec_xxpermdi(t1, t2, 3);
  1211. vec_xst(t3, 0, boffset+8);
  1212. vec_xst(t4, 0, boffset+12);
  1213. }
  1214. }
  1215. }
  1216. void KERNEL_4x4(int64_t ii, int64_t jj) {
  1217. vec_t vec_A[4], vec_B[4], vec_C[4];
  1218. acc_t acc_0;
  1219. __builtin_mma_xxsetaccz(&acc_0);
  1220. for (int l = 0; l < k; l+=4) {
  1221. READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
  1222. READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
  1223. __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
  1224. __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
  1225. __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
  1226. __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
  1227. }
  1228. SAVE_ACC(&acc_0, ii, jj);
  1229. }
  1230. void KERNEL_4x8(int64_t ii, int64_t jj) {
  1231. vec_t vec_A[4], vec_B[8], vec_C[4];
  1232. acc_t acc_0, acc_1;
  1233. __builtin_mma_xxsetaccz(&acc_0);
  1234. __builtin_mma_xxsetaccz(&acc_1);
  1235. for (int64_t l = 0; l < k; l+=4) {
  1236. READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
  1237. READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
  1238. __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
  1239. __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
  1240. __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
  1241. __builtin_mma_xvf32gerpp(&acc_1, vec_A[1], (vec_t)vec_B[3]);
  1242. __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], (vec_t)vec_B[4]);
  1243. __builtin_mma_xvf32gerpp(&acc_1, vec_A[2], (vec_t)vec_B[5]);
  1244. __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
  1245. __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
  1246. }
  1247. SAVE_ACC(&acc_0, ii, jj);
  1248. SAVE_ACC(&acc_1, ii, jj+4);
  1249. }
  1250. void KERNEL_8x4(int64_t ii, int64_t jj) {
  1251. vec_t vec_A[8], vec_B[4], vec_C[4];
  1252. acc_t acc_0, acc_1;
  1253. __builtin_mma_xxsetaccz(&acc_0);
  1254. __builtin_mma_xxsetaccz(&acc_1);
  1255. for (int64_t l = 0; l < k; l+=4) {
  1256. READ_BLOCK(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
  1257. READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
  1258. __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
  1259. __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
  1260. __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
  1261. __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[3], vec_B[1]);
  1262. __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[4], vec_B[2]);
  1263. __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[5], vec_B[2]);
  1264. __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
  1265. __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
  1266. }
  1267. SAVE_ACC(&acc_0, ii, jj);
  1268. SAVE_ACC(&acc_1, ii+4, jj);
  1269. }
  1270. void KERNEL_8x8(int64_t ii, int64_t jj) {
  1271. vec_t vec_A[16], vec_B[16], vec_C[4];
  1272. acc_t acc_0, acc_1, acc_2, acc_3;
  1273. __builtin_mma_xxsetaccz(&acc_0);
  1274. __builtin_mma_xxsetaccz(&acc_1);
  1275. __builtin_mma_xxsetaccz(&acc_2);
  1276. __builtin_mma_xxsetaccz(&acc_3);
  1277. for (int l = 0; l < k; l+=8) {
  1278. READ_BLOCK(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
  1279. READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
  1280. for(int x = 0; x < 16; x+=2) {
  1281. __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
  1282. __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
  1283. __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]);
  1284. __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]);
  1285. }
  1286. }
  1287. SAVE_ACC(&acc_0, ii, jj);
  1288. SAVE_ACC(&acc_1, ii, jj+4);
  1289. SAVE_ACC(&acc_2, ii+4, jj);
  1290. SAVE_ACC(&acc_3, ii+4, jj+4);
  1291. }
  1292. void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
  1293. int64_t mc, nc, mp, np;
  1294. int m_rem = MIN(m - m0, 16);
  1295. int n_rem = MIN(n - n0, 16);
  1296. if (m_rem >= 16 && n_rem >= 8) {
  1297. mc = 8;
  1298. nc = 8;
  1299. gemm<8,8>(m0, m, n0, n);
  1300. } else if(m_rem >= 8 && n_rem >= 16) {
  1301. mc = 8;
  1302. nc = 8;
  1303. gemm<8,8>(m0, m, n0, n);
  1304. } else if (m_rem >= 8 && n_rem >= 8) {
  1305. mc = 8;
  1306. nc = 8;
  1307. gemm<8,8>(m0, m, n0, n);
  1308. } else if (m_rem >= 4 && n_rem >= 8) {
  1309. mc = 4;
  1310. nc = 8;
  1311. gemm<4,8>(m0, m, n0, n);
  1312. } else if (m_rem >= 8 && n_rem >= 4) {
  1313. mc = 8;
  1314. nc = 4;
  1315. gemm<8,4>(m0, m, n0, n);
  1316. } else if (m_rem >= 4 && n_rem >= 4) {
  1317. mc = 4;
  1318. nc = 4;
  1319. gemm<4,4>(m0, m, n0, n);
  1320. } else if ((m_rem < 4) && (n_rem > 4)) {
  1321. nc = 4;
  1322. switch(m_rem) {
  1323. case 1:
  1324. mc = 1;
  1325. gemm_small(m0, m, n0, n, mc, nc);
  1326. break;
  1327. case 2:
  1328. mc = 2;
  1329. gemm_small(m0, m, n0, n, mc, nc);
  1330. break;
  1331. case 3:
  1332. mc = 3;
  1333. gemm_small(m0, m, n0, n, mc, nc);
  1334. break;
  1335. default:
  1336. return;
  1337. }
  1338. } else if ((m_rem > 4) && (n_rem < 4)) {
  1339. mc = 4;
  1340. switch(n_rem) {
  1341. case 1:
  1342. nc = 1;
  1343. gemm_small(m0, m, n0, n, mc, nc);
  1344. break;
  1345. case 2:
  1346. nc = 2;
  1347. gemm_small(m0, m, n0, n, mc, nc);
  1348. break;
  1349. case 3:
  1350. nc = 3;
  1351. gemm_small(m0, m, n0, n, mc, nc);
  1352. break;
  1353. default:
  1354. return;
  1355. }
  1356. } else {
  1357. switch((m_rem << 4) | n_rem) {
  1358. case 0x43:
  1359. mc = 4;
  1360. nc = 3;
  1361. gemm_small(m0, m, n0, n, mc, nc);
  1362. break;
  1363. case 0x42:
  1364. mc = 4;
  1365. nc = 2;
  1366. gemm_small(m0, m, n0, n, mc, nc);
  1367. break;
  1368. case 0x41:
  1369. mc = 4;
  1370. nc = 1;
  1371. gemm_small(m0, m, n0, n, mc, nc);
  1372. break;
  1373. case 0x34:
  1374. mc = 3;
  1375. nc = 4;
  1376. gemm_small(m0, m, n0, n, mc, nc);
  1377. break;
  1378. case 0x33:
  1379. mc = 3;
  1380. nc = 3;
  1381. gemm_small(m0, m, n0, n, mc, nc);
  1382. break;
  1383. case 0x32:
  1384. mc = 3;
  1385. nc = 2;
  1386. gemm_small(m0, m, n0, n, mc, nc);
  1387. break;
  1388. case 0x31:
  1389. mc = 3;
  1390. nc = 1;
  1391. gemm_small(m0, m, n0, n, mc, nc);
  1392. break;
  1393. case 0x24:
  1394. mc = 2;
  1395. nc = 4;
  1396. gemm_small(m0, m, n0, n, mc, nc);
  1397. break;
  1398. case 0x23:
  1399. mc = 2;
  1400. nc = 3;
  1401. gemm_small(m0, m, n0, n, mc, nc);
  1402. break;
  1403. case 0x22:
  1404. mc = 2;
  1405. nc = 2;
  1406. gemm_small(m0, m, n0, n, mc, nc);
  1407. break;
  1408. case 0x21:
  1409. mc = 2;
  1410. nc = 1;
  1411. gemm_small(m0, m, n0, n, mc, nc);
  1412. break;
  1413. case 0x14:
  1414. mc = 1;
  1415. nc = 4;
  1416. gemm_small(m0, m, n0, n, mc, nc);
  1417. break;
  1418. case 0x13:
  1419. mc = 1;
  1420. nc = 3;
  1421. gemm_small(m0, m, n0, n, mc, nc);
  1422. break;
  1423. case 0x12:
  1424. mc = 1;
  1425. nc = 2;
  1426. gemm_small(m0, m, n0, n, mc, nc);
  1427. break;
  1428. case 0x11:
  1429. mc = 1;
  1430. nc = 1;
  1431. gemm_small(m0, m, n0, n, mc, nc);
  1432. break;
  1433. default:
  1434. return;
  1435. }
  1436. }
  1437. mp = m0 + (m - m0) / mc * mc;
  1438. np = n0 + (n - n0) / nc * nc;
  1439. mnpack(mp, m, n0, np);
  1440. mnpack(m0, m, np, n);
  1441. }
  1442. void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
  1443. int64_t ytiles = (m - m0) / RM;
  1444. int64_t xtiles = (n - n0) / RN;
  1445. int64_t tiles = xtiles * ytiles;
  1446. int64_t duty = (tiles + nth - 1) / nth;
  1447. int64_t start = duty * ith;
  1448. int64_t end = start + duty;
  1449. if (end > tiles)
  1450. end = tiles;
  1451. for (int64_t job = start; job < end; ++job) {
  1452. int64_t ii = m0 + job / xtiles * RM;
  1453. int64_t jj = n0 + job % xtiles * RN;
  1454. vec_t vec_C[4];
  1455. acc_t acc_0;
  1456. __builtin_mma_xxsetaccz(&acc_0);
  1457. vec_t vec_A[4], vec_B[4];
  1458. for (int l=0; l<k; l+=4) {
                if (RN >= 4 && RM == 1) {
                    float* a = const_cast<float*>(A+(ii)*lda+l);
                    READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
                    vec_A[0] = (vec_t)vec_xl(0,a);
                    vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
                    vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
                    vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
                } else {
                    READ_BLOCK(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
                    READ_BLOCK(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
                }
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
            }
            __builtin_mma_disassemble_acc(vec_C, &acc_0);
            for (int I = 0; I < RM; I++) {
                for (int J = 0; J < RN; J++) {
                    *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
                }
            }
        }
    }
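    // Tiled driver for the fixed-shape MMA kernels: picks the member kernel
    // matching the RM x RN template arguments, splits the tile grid evenly
    // across nth threads, and runs this thread's contiguous share of tiles.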
    template <int RM, int RN>
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (RM == 4 && RN == 4) {
            kernel = &tinyBLAS_PPC::KERNEL_4x4;
        } else if (RM == 4 && RN == 8) {
            kernel = &tinyBLAS_PPC::KERNEL_4x8;
        } else if (RM == 8 && RN == 4) {
            kernel = &tinyBLAS_PPC::KERNEL_8x4;
        } else if (RM == 8 && RN == 8) {
            kernel = &tinyBLAS_PPC::KERNEL_8x8;
        }
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            (this->*kernel)(ii, jj);
        }
    }

    const TA *const A;
    const TB *const B;
    TC *C;
    TA *At;
    TB *Bt;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
    const int ith;
    const int nth;
};
#endif
} // namespace
/**
 * Performs optimized matrix multiplication on CPU.
 *
 * This subroutine may compute C = Aᵀ * B with column major ordering.
 * Despite its name, this isn't a generalized implementation. Work is
 * only performed when a handwritten kernel is available for the given
 * data types and CPU features; otherwise the caller should fall back
 * to a general matmul routine.
 *
 * For example, for single-threaded single-precision GEMM you can say
 *
 *     llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
 *                     0, 1,
 *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
 *
 * (See the usage sketch at the end of this file for a fuller example.)
 *
 * @param m is rows in `A` and `C`
 * @param n is cols in `B` and `C`
 * @param k is cols in `A` and rows in `B`
 * @param A is first input matrix (always transposed)
 * @param lda is row stride of `A`
 * @param B is second input matrix (never transposed)
 * @param ldb is row stride of `B`
 * @param C is input/output array of output matrices
 * @param ldc is row stride of `C`
 * @param ith is thread id (must be less than `nth`)
 * @param nth is number of threads (must be greater than zero)
 * @param Atype is GGML data type of `A`
 * @param Btype is GGML data type of `B`
 * @param Ctype is GGML data type of `C`
 * @return true if this function was able to service the matmul request
 */
bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
                     int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) {

    assert(m >= 0);
    assert(n >= 0);
    assert(k >= 0);
    assert(lda >= k);
    assert(ldb >= k);
    assert(ldc >= m);
    assert(nth > 0);
    assert(ith < nth);

    // only enable sgemm for prompt processing
    if (n < 2)
        return false;

    if (Ctype != GGML_TYPE_F32)
        return false;
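    // Dispatch on the data type of A and the compiled-in ISA: each case
    // instantiates a tinyBLAS variant specialized for that input format, or
    // returns false so the caller falls back to a general matmul routine.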
    switch (Atype) {

    case GGML_TYPE_F32: {
        if (Btype != GGML_TYPE_F32)
            return false;
#if defined(__AVX512F__)
        if (k % 16)
            return false;
        tinyBLAS<16, __m512, __m512, float, float, float> tb{
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n);
        return true;
#elif defined(__AVX__) || defined(__AVX2__)
        if (k % 8)
            return false;
        tinyBLAS<8, __m256, __m256, float, float, float> tb{
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n);
        return true;
#elif defined(__ARM_NEON)
        if (n < 4)
            return false;
        if (k % 4)
            return false;
        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n);
        return true;
#elif defined(__MMA__)
        if (k % 8)
            return false;
        tinyBLAS_PPC<float, float, float> tb{
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_F16: {
#if defined(__AVX512F__)
        if (k % 16)
            return false;
        if (Btype != GGML_TYPE_F32)
            return false;
        tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{
            k, (const ggml_fp16_t *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n);
        return true;
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
        if (k % 8)
            return false;
        if (Btype != GGML_TYPE_F32)
            return false;
        tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
            k, (const ggml_fp16_t *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n);
        return true;
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
        if (n < 8)
            return false;
        if (k % 8)
            return false;
        if (Btype != GGML_TYPE_F16)
            return false;
        tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
            k, (const ggml_fp16_t *)A, lda,
            (const ggml_fp16_t *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n);
        return true;
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
        if (k % 4)
            return false;
        if (Btype != GGML_TYPE_F32)
            return false;
        tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
            k, (const ggml_fp16_t *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n);
        return true;
#else
        return false;
#endif
    }
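    // The quantized paths below all expect B to have been pre-quantized to
    // Q8_0; they use the AVX or ARM dot-product tinyBLAS_Q0 variants.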
    case GGML_TYPE_Q8_0: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
        tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
            k, (const block_q8_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n);
        return true;
#elif defined(__ARM_FEATURE_DOTPROD)
        tinyBLAS_Q0_ARM<block_q8_0> tb{
            k, (const block_q8_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_Q4_0: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
        tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
            k, (const block_q4_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n);
        return true;
#elif defined(__ARM_FEATURE_DOTPROD)
        tinyBLAS_Q0_ARM<block_q4_0> tb{
            k, (const block_q4_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_Q5_0: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
        tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float> tb{
            k, (const block_q5_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_IQ4_NL: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
        tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float> tb{
            k, (const block_iq4_nl *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n);
        return true;
#else
        return false;
#endif
    }

    default:
        return false;
    }

    (void)m;
    (void)n;
    (void)k;
    (void)A;
    (void)lda;
    (void)B;
    (void)ldb;
    (void)C;
    (void)ldc;
    (void)ith;
    (void)nth;
    (void)Atype;
    (void)Btype;
    (void)Ctype;
}
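
// ----------------------------------------------------------------------------
// Illustrative usage sketch (not part of the library API): a minimal
// single-threaded fp32 call of llamafile_sgemm(), assuming the ggml type
// enums included by this file and row strides equal to the matrix
// dimensions (lda = ldb = k, ldc = m), as the asserts above require. The
// helper name example_f32_matmul is hypothetical, and the block is compiled
// out with #if 0 so it cannot affect the build. On success C holds Aᵀ * B in
// column-major order; on failure the caller should use a general matmul
// routine instead.
#if 0
#include <vector>

static void example_f32_matmul(int64_t m, int64_t n, int64_t k) {
    // A is m x k and B is n x k (both stored with k-element rows); C is m x n.
    // Note: n must be >= 2 and k a multiple of the SIMD width (e.g. 16 for
    // AVX512, 8 for AVX, 4 for NEON) or llamafile_sgemm() declines the job.
    std::vector<float> A(m * k), B(n * k), C(m * n);
    // ... fill A and B ...
    bool ok = llamafile_sgemm(m, n, k,
                              A.data(), /*lda=*/k,
                              B.data(), /*ldb=*/k,
                              C.data(), /*ldc=*/m,
                              /*ith=*/0, /*nth=*/1,
                              GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
    if (!ok) {
        // fall back to a general matmul implementation
    }
}
#endif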