mmq.cuh 112 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962
  1. /**
  2. * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
  3. *
  4. * MIT License
  5. *
  6. * Copyright (c) 2023-2024 The ggml authors
  7. *
  8. * Permission is hereby granted, free of charge, to any person obtaining a copy
  9. * of this software and associated documentation files (the "Software"), to deal
  10. * in the Software without restriction, including without limitation the rights
  11. * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  12. * copies of the Software, and to permit persons to whom the Software is
  13. * furnished to do so, subject to the following conditions:
  14. *
  15. * The above copyright notice and this permission notice shall be included in all
  16. * copies or substantial portions of the Software.
  17. *
  18. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  19. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  20. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  21. * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  22. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  23. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  24. * SOFTWARE.
  25. */
  26. #pragma once
  27. #include "common.cuh"
  28. #include "vecdotq.cuh"
  29. #include "mma.cuh"
  30. #include <climits>
  31. #include <cstdint>
  #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
  // NOTE(review): presumably the number of k values handled per iteration of the MMQ main loop;
  // it is not referenced within this part of the file — confirm at its use sites.
  #define MMQ_ITER_K 256
  // NOTE(review): presumably the default number of warps per MMQ thread block (the `nwarps`
  // template argument of the functions below); not referenced within this part of the file — confirm.
  #define MMQ_NWARPS 8
  // Function-pointer type matching the load_tiles_* templates in this file: copy a tile of quantized
  // x data, starting at block offset kbx0 with a row stride of `stride` blocks, into x_tile;
  // i_max is the largest valid row index.
  typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride);
  // Function-pointer type matching the vec_dot_*_q8_1_* templates in this file: accumulate dot products
  // of an x tile and a y tile into `sum`, with k00 as the k offset into the x tile.
  typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00);
  // Function-pointer type for writing accumulated results to dst. The implementations are outside this
  // view; presumably i_max/j_max bound the rows/columns that are written — confirm.
  typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);
  // Selects which member of the block_q8_1_mmq union carries the y scale data for a given x type
  // (the mapping is done by mmq_get_q8_1_ds_layout).
  enum mmq_q8_1_ds_layout {
      MMQ_Q8_1_DS_LAYOUT_D4,   // d4:   one float scale per 32 values.
      MMQ_Q8_1_DS_LAYOUT_DS4,  // ds4:  one half scale + one half partial sum per 32 values.
      MMQ_Q8_1_DS_LAYOUT_D2S6, // d2s6: one half scale per 64 values + half partial sums per 16 values (first 96 values).
  };
  // One block of quantized y data for the MMQ kernels: 4*QK8_1 (128) 8-bit values preceded by a
  // header of scale / partial-sum data whose interpretation depends on mmq_q8_1_ds_layout.
  struct block_q8_1_mmq {
      // The y float data is converted to a data layout that can simply be copied to shared memory as a contiguous block.
      // The y float data is first grouped as blocks of 128 values.
      // These blocks are then treated as individual data values and transposed.
      //
      // To avoid shared memory bank conflicts each block is padded with 16 bytes.
      // This padding is also used to store block scales/partial sums.
      // The scales multiplied with the quantized data are equal to the unquantized values.
      // The partial sums are obtained by summing up a subgroup of the contained values (prior to quantization)
      // and are only needed for performance reasons.
      //
      // The exact data stored depends on the x data type.
      union {
          float d4[4];  // 1 32 bit scale per 32 values, stored as d0,d1,d2,d3
          half2 ds4[4]; // 1 16 bit scale + 1 16 bit partial sum per 32 values, stored as d0,s0,d1,s1,d2,s2,d3,s3
          half d2s6[8]; // 1 16 bit scale per 64 values + 1 16 bit partial sum per 16 values for the first 96 values,
                        // stored as d0,d1,s1,s2,s3,s4,s5,s6 (2 scales + 96/16 = 6 partial sums = 8 halves)
      };
      int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
  };
  // The header must occupy exactly 4*sizeof(half2) bytes, so that one block_q8_1_mmq has the same size
  // as four block_q8_1 blocks.
  static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
  static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
  // Returns the y-block scale layout that the MMQ kernel for x type `type_x` expects.
  // Aborts for x types that MMQ does not support.
  static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
      switch (type_x) {
          // One float scale per 32 y values.
          case GGML_TYPE_Q5_0:
          case GGML_TYPE_Q8_0:
          case GGML_TYPE_Q3_K:
          case GGML_TYPE_Q6_K:
          case GGML_TYPE_IQ2_XXS:
          case GGML_TYPE_IQ2_XS:
          case GGML_TYPE_IQ2_S:
          case GGML_TYPE_IQ3_XXS:
          case GGML_TYPE_IQ3_S:
          case GGML_TYPE_IQ4_XS:
          case GGML_TYPE_IQ4_NL:
              return MMQ_Q8_1_DS_LAYOUT_D4;

          // Half scale + half partial sum per 32 y values.
          case GGML_TYPE_Q4_0:
          case GGML_TYPE_Q4_1:
          case GGML_TYPE_Q5_1:
          case GGML_TYPE_Q4_K:
          case GGML_TYPE_Q5_K:
          case GGML_TYPE_IQ1_S:
              return MMQ_Q8_1_DS_LAYOUT_DS4;

          // Half scale per 64 values + partial sums per 16 values.
          case GGML_TYPE_Q2_K:
              return MMQ_Q8_1_DS_LAYOUT_D2S6;

          default:
              GGML_ABORT("fatal error");
              break;
      }
  }
  // Sizes of the regions of an x tile in the dp4a code path, in 32-bit words. The load_tiles_* and
  // vec_dot_* functions place the scale region directly after `qs` words of quantized data
  // (x_qs + txs.qs).
  struct tile_x_sizes {
      int qs; // words of quantized x data (including per-row padding)
      int dm; // words of block scales (float) or scale/min pairs (half2)
      int sc; // words of sub-block scale data; 0 for types that do not use it (see MMQ_DP4A_TXS_*)
  };
  // Host-side upper bound on the MMQ tile width mmq_x for a device of compute capability cc.
  //   - int8 MMA available:             128
  //   - pre-Volta NVIDIA or AMD:         64
  //   - Volta+ NVIDIA using dp4a:       128 with GGML_CUDA_FORCE_MMQ, else MMQ_DP4A_MAX_BATCH_SIZE
  static constexpr int get_mmq_x_max_host(const int cc) {
      return int8_mma_available(cc) ? 128 :
          (cc < CC_VOLTA || cc >= CC_OFFSET_AMD) ? 64 :
  #ifdef GGML_CUDA_FORCE_MMQ
          128;
  #else
          MMQ_DP4A_MAX_BATCH_SIZE;
  #endif // GGML_CUDA_FORCE_MMQ
  }
  // Device-side upper bound on the MMQ tile width mmq_x, resolved at compile time for the architecture
  // being compiled.
  //
  // This must agree with get_mmq_x_max_host(): on Volta+ NVIDIA without int8 MMA, the host allows
  // mmq_x up to 128 when GGML_CUDA_FORCE_MMQ is defined, and up to MMQ_DP4A_MAX_BATCH_SIZE otherwise.
  // The FORCE_MMQ branches below therefore return the same values as the host. With the two values
  // swapped, the device bound (64) would be smaller than the mmq_x the host can choose under
  // GGML_CUDA_FORCE_MMQ.
  // NOTE(review): this assumes the device bound must be >= the host bound for every configuration.
  // On AMD without int8 MMA the host returns 64 and the device returns 128, which satisfies it.
  static constexpr __device__ int get_mmq_x_max_device() {
  #ifdef INT8_MMA_AVAILABLE
      return 128;
  #else // INT8_MMA_AVAILABLE
  #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
      return 128;
  #else // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  #if __CUDA_ARCH__ >= CC_VOLTA
  #ifdef GGML_CUDA_FORCE_MMQ
      return 128;                     // MMQ is used for all batch sizes.
  #else // GGML_CUDA_FORCE_MMQ
      return MMQ_DP4A_MAX_BATCH_SIZE; // Larger batches are not routed to dp4a MMQ.
  #endif // GGML_CUDA_FORCE_MMQ
  #else // __CUDA_ARCH__ >= CC_VOLTA
      return 64;
  #endif // __CUDA_ARCH__ >= CC_VOLTA
  #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  #endif // INT8_MMA_AVAILABLE
  }
  // Host-side MMQ tile height mmq_y for compute capability cc. It is 128 except on AMD RDNA1 and on
  // NVIDIA below Volta, where it is 64.
  static constexpr int get_mmq_y_host(const int cc) {
      return (cc >= CC_OFFSET_AMD ? cc != CC_RDNA1 : cc >= CC_VOLTA) ? 128 : 64;
  }
  // Device-side MMQ tile height mmq_y, resolved at compile time. It mirrors get_mmq_y_host: 64 on
  // RDNA1 and pre-Volta NVIDIA, and 128 otherwise.
  static constexpr __device__ int get_mmq_y_device() {
  #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && defined(RDNA1)
      return 64;
  #elif defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
      return 128;
  #elif __CUDA_ARCH__ >= CC_VOLTA
      return 128;
  #else
      return 64;
  #endif
  }
  // dp4a-path x tile layouts (see tile_x_sizes; all sizes are in 32-bit words). They must be expanded
  // in a scope where `mmq_y` (tile rows) is defined.
  // Each size is the payload plus padding. For example, Q4_0 stores rows of WARP_SIZE packed ints
  // with a row stride of WARP_SIZE + 1, which gives the "+ mmq_y" term. Its scales get one extra word
  // every QI4_0 rows, which gives the "+ mmq_y/QI4_0" term.
  // NOTE(review): the padding is presumably there to avoid shared-memory bank conflicts — confirm.
  // The *_Q8_0 / *_Q8_1 layouts hold 2*WARP_SIZE unpacked ints per row. *_Q8_0_16 reserves twice as
  // many scale words as *_Q8_0.
  #define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
  #define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
  #define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}
  #define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*4/QI8_0 + mmq_y/(QI8_0/4), 0}
  #define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0}
  #define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0}
  #define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y, mmq_y*WARP_SIZE/8 + mmq_y/8}
  #define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
  #define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
  #define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
  // Returns the dp4a-path x tile layout for quantization type `type` with `mmq_y` tile rows.
  // Types that are unpacked to 8-bit values share the Q8_0 layout, or the Q8_0_16 layout (twice the
  // scale words) for IQ2_XS / IQ2_S. Unsupported types yield an all-zero layout.
  static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
      return
          // Types with their own layout.
          type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
          type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
          type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 :
          type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
          type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
          type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
          type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
          type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
          // Q8_0 layout with doubled scale storage.
          (type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_S) ? MMQ_DP4A_TXS_Q8_0_16 :
          // Plain Q8_0 layout.
          (type == GGML_TYPE_Q5_0    || type == GGML_TYPE_Q8_0    ||
           type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ3_XXS ||
           type == GGML_TYPE_IQ3_S   || type == GGML_TYPE_IQ1_S   ||
           type == GGML_TYPE_IQ4_XS  || type == GGML_TYPE_IQ4_NL) ? MMQ_DP4A_TXS_Q8_0 :
          // Not supported by MMQ.
          tile_x_sizes{0, 0, 0};
  }
  // Row stride of the x tile in the int8-MMA path, in 32-bit words. Each row holds 2*WARP_SIZE words
  // of 8-bit quants, followed by the scale data for that layout, plus padding.
  // The static_asserts require every stride to be congruent to 4 modulo 8.
  // NOTE(review): presumably this offsets consecutive rows to avoid shared-memory bank conflicts
  // for the MMA loads — confirm.
  #define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
  #define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
  #define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE + 4)
  #define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/2 + 4)
  #define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7)
  static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
  static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
  static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
  static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
  static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
  // Returns the int8-MMA x tile row stride (in 32-bit words) used for quantization type `type`,
  // or 0 for types MMQ does not support.
  static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
      return
          // Float scale per block.
          (type == GGML_TYPE_Q4_0    || type == GGML_TYPE_Q5_0    || type == GGML_TYPE_Q8_0  ||
           type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S ||
           type == GGML_TYPE_IQ1_S   || type == GGML_TYPE_IQ4_XS  || type == GGML_TYPE_IQ4_NL) ? MMQ_MMA_TILE_X_K_Q8_0 :
          // half2 scale/min per block.
          (type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_1 ||
           type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K) ? MMQ_MMA_TILE_X_K_Q8_1 :
          (type == GGML_TYPE_Q3_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_S) ? MMQ_MMA_TILE_X_K_Q3_K :
          type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
          type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
          0;
  }
  // Row stride of the y tile, in 32-bit words. Each row holds WARP_SIZE words of 8-bit quants plus
  // WARP_SIZE/QI8_1 words of scale / partial-sum data. The vec_dot_* functions read that data from the
  // start of the row and the quants from word 4 onward, which matches the block_q8_1_mmq layout.
  #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
  // Host-side granularity for mmq_x on compute capability cc. It is 16 when int8 MMA is available and
  // mmq_x >= 48, and 8 otherwise. This mirrors mmq_get_granularity_device.
  static int mmq_get_granularity_host(const int mmq_x, const int cc) {
      const bool coarse = int8_mma_available(cc) && mmq_x >= 48;
      if (coarse) {
          return 16;
      }
      return 8;
  }
  // Device-side granularity for mmq_x. This must match mmq_get_granularity_host for the same device.
  #ifdef INT8_MMA_AVAILABLE
  // int8 MMA builds: 16 for mmq_x >= 48, otherwise 8.
  static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
      return mmq_x >= 48 ? 16 : 8;
  }
  #else
  // Builds without int8 MMA always use 8. The parameter is unnamed to avoid an unused-parameter warning.
  static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) {
      return 8;
  }
  #endif // INT8_MMA_AVAILABLE
  225. // ------------------------------------------------------------
  // Copies mmq_y rows of Q4_0 x data into the tile buffer x_tile, including quants and block scales.
  //
  // Template parameters:
  //   mmq_y      - number of tile rows.
  //   nwarps     - row step of the quant loop (thread row threadIdx.y handles rows threadIdx.y, +nwarps, ...).
  //   need_check - if true, row indices are clamped to i_max, so out-of-range rows re-read row i_max.
  // Parameters:
  //   x      - quantized x data, read as an array of block_q4_0.
  //   x_tile - destination; its layout depends on INT8_MMA_AVAILABLE (see below).
  //   kbx0   - offset, in blocks, of the first block of the tile.
  //   i_max  - largest valid row index (used only when need_check).
  //   stride - distance, in blocks, between consecutive rows of x.
  //
  // NOTE(review): the indexing assumes threadIdx.x < WARP_SIZE and mmq_y % (nwarps*QI4_0) == 0.
  // Otherwise the scale loop would write rows >= mmq_y — confirm against the launch configuration.
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
      const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
  #ifdef INT8_MMA_AVAILABLE
      // MMA layout: rows of MMQ_MMA_TILE_X_K_Q8_0 words. The first 2*WARP_SIZE words of a row hold
      // unpacked 8-bit quants; the float scales follow them.
      int * x_qs = (int *) x_tile;
      float * x_df = (float *) (x_qs + 2*WARP_SIZE);
  #else
      // dp4a layout: all packed quants first (txs.qs words), then all scales.
      constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
      int * x_qs = (int *) x_tile;
      float * x_df = (float *) (x_qs + txs.qs);
  #endif // INT8_MMA_AVAILABLE
      // Each thread reads int kqsx of block kbx within a row.
      const int kbx = threadIdx.x / QI4_0;
      const int kqsx = threadIdx.x % QI4_0;
  #pragma unroll
      for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
          int i = i0 + threadIdx.y;
          if (need_check) {
              i = min(i, i_max);
          }
          const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
          const int qs0 = get_int_b2(bxi->qs, kqsx);
  #ifdef INT8_MMA_AVAILABLE
          // Split the packed nibbles. Low nibbles go to the first QI4_0 words of the block's 2*QI4_0-word
          // slot and high nibbles to the next QI4_0. 8 is subtracted from every byte (saturating) to
          // make the values signed.
          x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
          x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
  #else
          // Keep the nibbles packed. The row stride is WARP_SIZE + 1 (one word of padding per row).
          x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
  #endif // INT8_MMA_AVAILABLE
      }
      // Second pass: each thread loads the scale d of block kbxd for one row.
      const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
      const int kbxd = threadIdx.x % blocks_per_tile_x_row;
  #pragma unroll
      for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
          int i = i0 + threadIdx.y * QI4_0 + threadIdx.x / blocks_per_tile_x_row;
          if (need_check) {
              i = min(i, i_max);
          }
          const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
  #ifdef INT8_MMA_AVAILABLE
          x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
  #else
          // WARP_SIZE/QI4_0 scales per row, plus one padding word every QI4_0 rows.
          x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
  #endif // INT8_MMA_AVAILABLE
      }
  }
  // dp4a path: for each (x row i, y row j) pair owned by this thread, accumulates the Q4_0 x tile ·
  // q8_1 y tile dot product over the WARP_SIZE packed x words that start at k00.
  //
  // Template parameters: mmq_x / mmq_y are the y / x tile row counts. nwarps is the j step per
  // threadIdx.y; i steps by WARP_SIZE per threadIdx.x.
  // Parameters:
  //   x   - x tile in the dp4a Q4_0 layout: packed quants with row stride WARP_SIZE + 1, then float scales
  //         at txs.qs. This is the same layout that load_tiles_q4_0 writes without INT8_MMA_AVAILABLE.
  //   y   - y tile with rows of MMQ_TILE_Y_K words. Each row starts with half2 scale/partial-sum values,
  //         and the quants begin 4 words in.
  //   sum - per-thread accumulators, indexed [j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE].
  //   k00 - k offset into the x tile; the y tile is indexed by k01 alone.
  template <int mmq_x, int mmq_y, int nwarps>
  static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
      const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
      constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
      const int * x_qs = (const int *) x;
      const float * x_df = (const float *) x_qs + txs.qs;
      const int * y_qs = (const int *) y + 4;
      const half2 * y_ds = (const half2 *) y;
      // (The unroll pragma for this k01 loop is commented out.)
      // #pragma unroll
      for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
          const int k0 = k00 + k01;
  #pragma unroll
          for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
              const int j = j0 + threadIdx.y;
  #pragma unroll
              for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
                  const int i = i0 + threadIdx.x;
                  // Word index of the first y quants matching the current low nibbles. The matching
                  // high nibbles pair with the y words QI4_0 further on (u[2*l+1] below).
                  const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
                  int u[2*VDR_Q4_0_Q8_1_MMQ];
  #pragma unroll
                  for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
                      u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
                      u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
                  }
                  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
                      (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_0], u,
                       x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
              }
          }
      }
  }
  // Copies mmq_y rows of Q4_1 x data into the tile buffer x_tile: quants plus one half2 `dm` value per
  // block (presumably scale and min — see block_q4_1).
  //
  // The parameters and thread mapping are the same as load_tiles_q4_0. Unlike Q4_0, no offset is
  // subtracted from the nibbles, and the quants are read with get_int_b4 instead of get_int_b2.
  // NOTE(review): assumes threadIdx.x < WARP_SIZE and mmq_y % (nwarps*QI4_1) == 0 — confirm.
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
      const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
  #ifdef INT8_MMA_AVAILABLE
      // MMA layout: rows of MMQ_MMA_TILE_X_K_Q8_1 words, quants first, half2 dm values from word 2*WARP_SIZE.
      int * x_qs = (int *) x_tile;
      half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
  #else
      // dp4a layout: all packed quants (txs.qs words), then all dm values.
      constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
      int * x_qs = (int *) x_tile;
      half2 * x_dm = (half2 *) (x_qs + txs.qs);
  #endif // INT8_MMA_AVAILABLE
      const int kbx = threadIdx.x / QI4_1;
      const int kqsx = threadIdx.x % QI4_1;
  #pragma unroll
      for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
          int i = i0 + threadIdx.y;
          if (need_check) {
              i = min(i, i_max);
          }
          const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
          const int qs0 = get_int_b4(bxi->qs, kqsx);
  #ifdef INT8_MMA_AVAILABLE
          // Low nibbles, then high nibbles, unpacked to one byte each; they stay unsigned (0..15).
          x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F;
          x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
  #else
          // Packed nibbles with a padded row stride of WARP_SIZE + 1.
          x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
  #endif // INT8_MMA_AVAILABLE
      }
      // Second pass: one dm value per (row, block).
      const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
      const int kbxd = threadIdx.x % blocks_per_tile_x_row;
  #pragma unroll
      for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
          int i = i0 + threadIdx.y * QI4_1 + threadIdx.x / blocks_per_tile_x_row;
          if (need_check) {
              i = min(i, i_max);
          }
          const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
  #ifdef INT8_MMA_AVAILABLE
          x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
  #else
          // WARP_SIZE/QI4_1 values per row plus one padding slot every QI4_1 rows.
          x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
  #endif // INT8_MMA_AVAILABLE
      }
  }
  // dp4a path: accumulates the Q4_1 x tile · q8_1 y tile dot products into `sum`. It is structured like
  // vec_dot_q4_0_q8_1_dp4a, except that the x tile carries half2 dm values instead of float scales
  // (the dp4a layout written by load_tiles_q4_1 without INT8_MMA_AVAILABLE).
  //   x   - packed quants with row stride WARP_SIZE + 1, then half2 dm values at txs.qs.
  //   y   - y tile with rows of MMQ_TILE_Y_K words: half2 scale/partial-sum values first, quants 4 words in.
  //   sum - per-thread accumulators, indexed [j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE].
  //   k00 - k offset into the x tile; the y tile is indexed by k01 alone.
  template <int mmq_x, int mmq_y, int nwarps>
  static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
      const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
      constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
      const int * x_qs = (const int *) x;
      const half2 * x_dm = (const half2 *) x_qs + txs.qs;
      const int * y_qs = (const int *) y + 4;
      const half2 * y_ds = (const half2 *) y;
      // (The unroll pragma for this k01 loop is commented out.)
      // #pragma unroll
      for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
          const int k0 = k00 + k01;
  #pragma unroll
          for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
              const int j = j0 + threadIdx.y;
  #pragma unroll
              for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
                  const int i = i0 + threadIdx.x;
                  // y words for the low nibbles start at kyqs; the high nibbles pair with words QI4_1 further on.
                  const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
                  int u[2*VDR_Q4_1_Q8_1_MMQ];
  #pragma unroll
                  for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
                      u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
                      u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
                  }
                  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
                      (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_1], u,
                       x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
              }
          }
      }
  }
  // Copies mmq_y rows of Q5_0 x data into the tile buffer x_tile. Each 5-bit value is rebuilt from its
  // low nibble (qs) and its high bit (qh), then shifted to a signed value by subtracting 16.
  //
  // The parameters and thread mapping are the same as load_tiles_q4_0. In both code paths the
  // quants are stored unpacked (one byte per value, 2*QI5_0 words per block).
  // NOTE(review): assumes threadIdx.x < WARP_SIZE and mmq_y % (nwarps*QI5_0) == 0 — confirm.
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
      const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
  #ifdef INT8_MMA_AVAILABLE
      // MMA layout: rows of MMQ_MMA_TILE_X_K_Q8_0 words, quants first, float scales from word 2*WARP_SIZE.
      int * x_qs = (int *) x_tile;
      float * x_df = (float *) (x_qs + WARP_SIZE*2);
  #else
      // dp4a layout (shared with Q8_0): all quants (txs.qs words), then all scales.
      constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
      int * x_qs = (int *) x_tile;
      float * x_df = (float *) (x_qs + txs.qs);
  #endif // INT8_MMA_AVAILABLE
      const int kbx = threadIdx.x / QI5_0;
      const int kqsx = threadIdx.x % QI5_0;
  #pragma unroll
      for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
          int i = i0 + threadIdx.y;
          if (need_check) {
              i = min(i, i_max);
          }
          const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;
          const int ql = get_int_b2(bxi->qs, kqsx);
          // Shift the high-bit word so that bits 0..3 and 16..19 belong to this thread's values
          // (threadIdx.x % QI5_0 == kqsx).
          const int qh = get_int_b2(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));
          // Low nibbles + high bits 0..3 placed at bit 4 of each byte (comments: qh bit -> result bit).
          int qs0 = (ql >> 0) & 0x0F0F0F0F;
          qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
          qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
          qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
          qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
          qs0 = __vsubss4(qs0, 0x10101010); // subtract 16
          // High nibbles + high bits 16..19.
          int qs1 = (ql >> 4) & 0x0F0F0F0F;
          qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
          qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
          qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
          qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
          qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
  #ifdef INT8_MMA_AVAILABLE
          x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
          x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
  #else
          // Row stride 2*WARP_SIZE + 1 (one word of padding per row).
          x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
          x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
  #endif // INT8_MMA_AVAILABLE
      }
      // Second pass: one scale d per (row, block).
      const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
      const int kbxd = threadIdx.x % blocks_per_tile_x_row;
  #pragma unroll
      for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
          int i = i0 + threadIdx.y * QI5_0 + threadIdx.x / blocks_per_tile_x_row;
          if (need_check) {
              i = min(i, i_max);
          }
          const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
  #ifdef INT8_MMA_AVAILABLE
          x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
  #else
          x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
  #endif // INT8_MMA_AVAILABLE
      }
  }
  431. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
  432. const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
  433. #ifdef INT8_MMA_AVAILABLE
  434. int * x_qs = (int *) x_tile;
  435. half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
  436. #else
  437. constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
  438. int * x_qs = (int *) x_tile;
  439. half2 * x_dm = (half2 *) (x_qs + txs.qs);
  440. #endif // INT8_MMA_AVAILABLE
  441. const int kbx = threadIdx.x / QI5_1;
  442. const int kqsx = threadIdx.x % QI5_1;
  443. #pragma unroll
  444. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  445. int i = i0 + threadIdx.y;
  446. if (need_check) {
  447. i = min(i, i_max);
  448. }
  449. const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;
  450. const int ql = get_int_b4(bxi->qs, kqsx);
  451. const int qh = get_int_b4(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));
  452. int qs0 = (ql >> 0) & 0x0F0F0F0F;
  453. qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
  454. qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
  455. qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
  456. qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
  457. int qs1 = (ql >> 4) & 0x0F0F0F0F;
  458. qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
  459. qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
  460. qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
  461. qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
  462. #ifdef INT8_MMA_AVAILABLE
  463. x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
  464. x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
  465. #else
  466. x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
  467. x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
  468. #endif // INT8_MMA_AVAILABLE
  469. }
  470. const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
  471. const int kbxd = threadIdx.x % blocks_per_tile_x_row;
  472. #pragma unroll
  473. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
  474. int i = i0 + threadIdx.y * QI5_1 + threadIdx.x / blocks_per_tile_x_row;
  475. if (need_check) {
  476. i = min(i, i_max);
  477. }
  478. const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
  479. #ifdef INT8_MMA_AVAILABLE
  480. x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
  481. #else
  482. x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
  483. #endif // INT8_MMA_AVAILABLE
  484. }
  485. }
  486. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
  487. const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
  488. #ifdef INT8_MMA_AVAILABLE
  489. int * x_qs = (int *) x_tile;
  490. float * x_df = (float *) (x_tile + 2*WARP_SIZE);
  491. #else
  492. constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
  493. int * x_qs = (int *) x_tile;
  494. float * x_df = (float *) (x_qs + txs.qs);
  495. #endif // INT8_MMA_AVAILABLE
  496. const int kbx = threadIdx.x / QI8_0;
  497. const int kqsx = threadIdx.x % QI8_0;
  498. #pragma unroll
  499. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  500. int i = i0 + threadIdx.y;
  501. if (need_check) {
  502. i = min(i, i_max);
  503. }
  504. const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
  505. #ifdef INT8_MMA_AVAILABLE
  506. x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
  507. x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
  508. #else
  509. x_qs[i*(2*WARP_SIZE + 1) + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
  510. x_qs[i*(2*WARP_SIZE + 1) + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
  511. #endif // INT8_MMA_AVAILABLE
  512. }
  513. const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0;
  514. const int kbxd = threadIdx.x % blocks_per_tile_x_row;
  515. #pragma unroll
  516. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0/2) {
  517. int i = i0 + threadIdx.y * (QI8_0/2) + threadIdx.x / blocks_per_tile_x_row;
  518. if (need_check) {
  519. i = min(i, i_max);
  520. }
  521. const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
  522. #ifdef INT8_MMA_AVAILABLE
  523. x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
  524. #else
  525. x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
  526. #endif // INT8_MMA_AVAILABLE
  527. }
  528. }
  529. template <int mmq_x, int mmq_y, int nwarps>
  530. static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
  531. const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
  532. constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
  533. const int * x_qs = (const int *) x;
  534. const float * x_df = (const float *) x_qs + txs.qs;
  535. const int * y_qs = (const int *) y + 4;
  536. const float * y_df = (const float *) y;
  537. // #pragma unroll
  538. for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
  539. const int k0 = k00 + k01;
  540. #pragma unroll
  541. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  542. const int j = j0 + threadIdx.y;
  543. #pragma unroll
  544. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  545. const int i = i0 + threadIdx.x;
  546. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
  547. (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % WARP_SIZE],
  548. x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]);
  549. }
  550. }
  551. }
  552. }
  553. template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
  554. static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
  555. const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
  556. typedef mma_int_A_I16K8 mma_A;
  557. typedef mma_int_B_J8K8 mma_B;
  558. typedef mma_int_C_I16J8 mma_C;
  559. constexpr int granularity = mmq_get_granularity_device(mmq_x);
  560. constexpr int rows_per_warp = 2 * granularity;
  561. constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
  562. y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
  563. const int * x_qs = (const int *) x;
  564. const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
  565. const int * y_qs = (const int *) y + 4;
  566. const float * y_df = (const float *) y;
  567. const half2 * y_ds = (const half2 *) y;
  568. mma_A A[ntx][WARP_SIZE/QI8_0];
  569. float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
  570. const int i0 = (threadIdx.y/ntx)*rows_per_warp;
  571. #pragma unroll
  572. for (int n = 0; n < ntx; ++n) {
  573. #pragma unroll
  574. for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
  575. const int k0 = k00 + k01;
  576. A[n][k01/QI8_0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
  577. }
  578. #pragma unroll
  579. for (int l = 0; l < mma_C::ne/2; ++l) {
  580. const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
  581. #pragma unroll
  582. for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
  583. const int k0 = k00 + k01;
  584. dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
  585. }
  586. }
  587. }
  588. #pragma unroll
  589. for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
  590. #pragma unroll
  591. for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
  592. mma_B B;
  593. float dB[mma_C::ne/2];
  594. B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
  595. #pragma unroll
  596. for (int l = 0; l < mma_C::ne/2; ++l) {
  597. const int j = j0 + mma_C::get_j(l);
  598. if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
  599. dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
  600. } else {
  601. dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
  602. }
  603. }
  604. #pragma unroll
  605. for (int n = 0; n < ntx; ++n) {
  606. mma_C C;
  607. C.mma_K8(A[n][k01/QI8_0], B);
  608. #pragma unroll
  609. for (int l = 0; l < mma_C::ne; ++l) {
  610. sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
  611. }
  612. }
  613. }
  614. }
  615. }
  616. template <int mmq_x, int mmq_y, int nwarps>
  617. static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
  618. const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
  619. constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
  620. const int * x_qs = (const int *) x;
  621. const half2 * x_dm = (const half2 *) x_qs + txs.qs;
  622. const int * y_qs = (const int *) y + 4;
  623. const half2 * y_ds = (const half2 *) y;
  624. // #pragma unroll
  625. for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
  626. const int k0 = k00 + k01;
  627. #pragma unroll
  628. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  629. const int j = j0 + threadIdx.y;
  630. #pragma unroll
  631. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  632. const int i = i0 + threadIdx.x;
  633. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
  634. (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
  635. x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
  636. }
  637. }
  638. }
  639. }
  640. template <int mmq_x, int mmq_y, int nwarps>
  641. static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
  642. const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
  643. typedef mma_int_A_I16K8 mma_A;
  644. typedef mma_int_B_J8K8 mma_B;
  645. typedef mma_int_C_I16J8 mma_C;
  646. constexpr int granularity = mmq_get_granularity_device(mmq_x);
  647. constexpr int rows_per_warp = 2 * granularity;
  648. constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
  649. y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
  650. const int * x_qs = (const int *) x;
  651. const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
  652. const int * y_qs = (const int *) y + 4;
  653. const half2 * y_dm = (const half2 *) y;
  654. mma_A A[ntx][WARP_SIZE/QI8_1];
  655. float2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
  656. const int i0 = (threadIdx.y/ntx)*rows_per_warp;
  657. #pragma unroll
  658. for (int n = 0; n < ntx; ++n) {
  659. #pragma unroll
  660. for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
  661. const int k0 = k00 + k01;
  662. A[n][k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
  663. }
  664. #pragma unroll
  665. for (int l = 0; l < mma_C::ne/2; ++l) {
  666. const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
  667. #pragma unroll
  668. for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
  669. const int k0 = k00 + k01;
  670. dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
  671. }
  672. }
  673. }
  674. #pragma unroll
  675. for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
  676. #pragma unroll
  677. for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
  678. mma_B B;
  679. float2 dsB[mma_C::ne/2];
  680. B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
  681. #pragma unroll
  682. for (int l = 0; l < mma_C::ne/2; ++l) {
  683. const int j = j0 + mma_C::get_j(l);
  684. dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
  685. }
  686. #pragma unroll
  687. for (int n = 0; n < ntx; ++n) {
  688. mma_C C;
  689. C.mma_K8(A[n][k01/QI8_1], B);
  690. #pragma unroll
  691. for (int l = 0; l < mma_C::ne; ++l) {
  692. sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
  693. sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
  694. }
  695. }
  696. }
  697. }
  698. }
  699. template <int mmq_x, int mmq_y, int nwarps>
  700. static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
  701. const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
  702. constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
  703. const int * x_qs = (const int *) x;
  704. const float * x_df = (const float *) x_qs + txs.qs;
  705. const int * y_qs = (const int *) y + 4;
  706. const float * y_df = (const float *) y;
  707. // #pragma unroll
  708. for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
  709. const int k0 = k00 + k01;
  710. #pragma unroll
  711. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  712. const int j = j0 + threadIdx.y;
  713. #pragma unroll
  714. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  715. const int i = i0 + threadIdx.x;
  716. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_16_q8_1_impl<QI8_0>(
  717. &x_qs[i*(2*WARP_SIZE + 1) + k0],
  718. &y_qs[j*MMQ_TILE_Y_K + k01],
  719. &x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],
  720. y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
  721. }
  722. }
  723. }
  724. }
  725. template <int mmq_x, int mmq_y, int nwarps>
  726. static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
  727. const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
  728. #ifdef INT8_MMA_AVAILABLE
  729. typedef mma_int_A_I16K4 mma_A;
  730. typedef mma_int_A_I16K8 mma_A_K8;
  731. typedef mma_int_B_J8K4 mma_B;
  732. typedef mma_int_C_I16J8 mma_C;
  733. constexpr int granularity = mmq_get_granularity_device(mmq_x);
  734. constexpr int rows_per_warp = 2 * granularity;
  735. constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
  736. y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
  737. const int * x_qs = (const int *) x;
  738. const float * x_df = (const float *) x_qs + WARP_SIZE*2;
  739. const int * y_qs = (const int *) y + 4;
  740. const float * y_df = (const float *) y;
  741. const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
  742. mma_A A[ntx][8];
  743. float dA[ntx][mma_C::ne/2][8];
  744. #pragma unroll
  745. for (int n = 0; n < ntx; ++n) {
  746. #pragma unroll
  747. for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
  748. const int k0 = k00 + k01;
  749. ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
  750. }
  751. #pragma unroll
  752. for (int l = 0; l < mma_C::ne/2; ++l) {
  753. const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
  754. #pragma unroll
  755. for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
  756. const int k0 = k00 + k01;
  757. dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4];
  758. }
  759. }
  760. }
  761. #pragma unroll
  762. for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
  763. #pragma unroll
  764. for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
  765. mma_B B[2];
  766. float dB[mma_C::ne/2];
  767. B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
  768. B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
  769. #pragma unroll
  770. for (int l = 0; l < mma_C::ne/2; ++l) {
  771. const int j = j0 + mma_C::get_j(l);
  772. dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
  773. }
  774. #pragma unroll
  775. for (int n = 0; n < ntx; ++n) {
  776. mma_C C[2];
  777. C[0].mma_K4(A[n][k01/4 + 0], B[0]);
  778. C[1].mma_K4(A[n][k01/4 + 1], B[1]);
  779. #pragma unroll
  780. for (int l = 0; l < mma_C::ne; ++l) {
  781. sum[(j0/mma_C::J + n)*mma_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
  782. }
  783. }
  784. }
  785. }
  786. #else
  787. GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
  788. NO_DEVICE_CODE;
  789. #endif // INT8_MMA_AVAILABLE
  790. }
  791. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
  792. const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
  793. #ifdef INT8_MMA_AVAILABLE
  794. int * x_qs = (int *) x_tile;
  795. half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
  796. #else
  797. constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
  798. int * x_qs = (int *) x_tile;
  799. half2 * x_dm = (half2 *) (x_qs + txs.qs);
  800. #endif // INT8_MMA_AVAILABLE
  801. const int kqsx = threadIdx.x % QI2_K;
  802. #pragma unroll
  803. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_K) {
  804. int i = i0 + threadIdx.y*(WARP_SIZE/QI2_K) + threadIdx.x/QI2_K;
  805. if (need_check) {
  806. i = min(i, i_max);
  807. }
  808. const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride;
  809. const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
  810. #pragma unroll
  811. for (int l = 0; l < QR2_K; ++l) {
  812. const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
  813. const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
  814. #ifdef INT8_MMA_AVAILABLE
  815. x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
  816. #else
  817. x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
  818. #endif // INT8_MMA_AVAILABLE
  819. }
  820. const int sc_m = bxi->scales[kqsx];
  821. #ifdef FAST_FP16_AVAILABLE
  822. const half2 x_dm_ik = __hmul2(bxi->dm, make_half2(sc_m & 0x0F, sc_m >> 4));
  823. #else
  824. const float2 bxi_dmf = __half22float2(bxi->dm);
  825. const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
  826. #endif // FAST_FP16_AVAILABLE
  827. #ifdef INT8_MMA_AVAILABLE
  828. x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
  829. #else
  830. x_dm[i*(WARP_SIZE + 1) + kqsx] = x_dm_ik;
  831. #endif // INT8_MMA_AVAILABLE
  832. }
  833. }
  834. template <int mmq_x, int mmq_y, int nwarps>
  835. static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
  836. const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
  837. constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
  838. const int * x_qs = (const int *) x;
  839. const half2 * x_dm = (const half2 *) x_qs + txs.qs;
  840. const int * y_qs = (const int *) y + 4;
  841. const half2 * y_ds = (const half2 *) y;
  842. float2 y_df[mmq_x/nwarps];
  843. #pragma unroll
  844. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  845. const int j = j0 + threadIdx.y;
  846. y_df[j0/nwarps] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
  847. }
  848. #pragma unroll
  849. for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
  850. const int k0 = k00 + k01;
  851. #pragma unroll
  852. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  853. const int j = j0 + threadIdx.y;
  854. #pragma unroll
  855. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  856. const int i = i0 + threadIdx.x;
  857. if (k01 < WARP_SIZE/2) {
  858. constexpr int ns = 2;
  859. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
  860. &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
  861. &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
  862. &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
  863. } else {
  864. constexpr int ns = 1;
  865. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
  866. &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
  867. &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
  868. &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
  869. }
  870. }
  871. }
  872. }
  873. }
  874. template <int mmq_x, int mmq_y, int nwarps>
  875. static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
  876. const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
  877. #ifdef INT8_MMA_AVAILABLE
  878. typedef mma_int_A_I16K4 mma_A;
  879. typedef mma_int_A_I16K8 mma_A_K8;
  880. typedef mma_int_B_J8K4 mma_B;
  881. typedef mma_int_C_I16J8 mma_C;
  882. constexpr int granularity = mmq_get_granularity_device(mmq_x);
  883. constexpr int rows_per_warp = 2 * granularity;
  884. constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
  885. y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
  886. const int * x_qs = (const int *) x;
  887. const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
  888. const int * y_qs = (const int *) y + 4;
  889. const half2 * y_ds = (const half2 *) y;
  890. const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
  891. mma_A A[ntx][8];
  892. float dA[ntx][mma_C::ne/2][8];
  893. float mA[ntx][mma_C::ne/2][8];
  894. #pragma unroll
  895. for (int n = 0; n < ntx; ++n) {
  896. #pragma unroll
  897. for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
  898. const int k0 = k00 + k01;
  899. ((mma_A_K8 *) A[n])[k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
  900. }
  901. }
  902. #pragma unroll
  903. for (int n = 0; n < ntx; ++n) {
  904. #pragma unroll
  905. for (int l = 0; l < mma_C::ne/2; ++l) {
  906. const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
  907. #pragma unroll
  908. for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
  909. const int k0 = k00 + k01;
  910. const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]);
  911. dA[n][l][k01/(QI8_1/2)] = dm.x;
  912. mA[n][l][k01/(QI8_1/2)] = dm.y;
  913. }
  914. }
  915. }
  916. #pragma unroll
  917. for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
  918. float2 dB[mma_C::ne/2];
  919. #pragma unroll
  920. for (int l = 0; l < mma_C::ne/2; ++l) {
  921. const int j = j0 + mma_C::get_j(l);
  922. dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
  923. }
  924. #pragma unroll
  925. for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
  926. mma_B B[2];
  927. B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
  928. B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
  929. mma_C Cm[2];
  930. if (k01 >= WARP_SIZE * 3/4) {
  931. mma_A A1;
  932. A1.x[0] = 0x01010101;
  933. A1.x[1] = 0x01010101;
  934. Cm[0].mma_K4(A1, B[0]);
  935. Cm[1].mma_K4(A1, B[1]);
  936. }
  937. #pragma unroll
  938. for (int n = 0; n < ntx; ++n) {
  939. mma_C Cd[2];
  940. Cd[0].mma_K4(A[n][k01/4 + 0], B[0]);
  941. Cd[1].mma_K4(A[n][k01/4 + 1], B[1]);
  942. #pragma unroll
  943. for (int l = 0; l < mma_C::ne; ++l) {
  944. float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
  945. if (k01 >= WARP_SIZE * 3/4) {
  946. tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
  947. }
  948. sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
  949. }
  950. }
  951. }
  952. #pragma unroll
  953. for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
  954. float2 sB[mma_C::ne/2];
  955. #pragma unroll
  956. for (int l = 0; l < mma_C::ne/2; ++l) {
  957. const int j = j0 + mma_C::get_j(l);
  958. sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
  959. }
  960. #pragma unroll
  961. for (int n = 0; n < ntx; ++n) {
  962. #pragma unroll
  963. for (int l = 0; l < mma_C::ne; ++l) {
  964. sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
  965. sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
  966. }
  967. }
  968. }
  969. }
  970. #else
  971. GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
  972. NO_DEVICE_CODE;
  973. #endif // INT8_MMA_AVAILABLE
  974. }
  975. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
  976. const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
  977. #ifdef INT8_MMA_AVAILABLE
  978. int * x_qs = (int *) x_tile;
  979. float * x_df = (float *) (x_qs + WARP_SIZE*2);
  980. #else
  981. constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
  982. int * x_qs = (int *) x_tile;
  983. float * x_df = (float *) (x_qs + txs.qs);
  984. int * x_sc = (int *) (x_df + txs.dm);
  985. #endif // INT8_MMA_AVAILABLE
  986. const int kqsx = threadIdx.x % QI3_K;
  987. #pragma unroll
  988. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI3_K) {
  989. int i = i0 + threadIdx.y * (WARP_SIZE/QI3_K) + threadIdx.x / QI3_K;
  990. if (need_check) {
  991. i = min(i, i_max);
  992. }
  993. const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
  994. const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
  995. const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
  996. #pragma unroll
  997. for (int l = 0; l < QR3_K; ++l) {
  998. const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
  999. const int x_ql_k = (x_ql_0 >> (2*l)) & 0x03030303;
  1000. const int x_qh_k = ((x_qh_0 >> l) << 2) & 0x04040404;
  1001. const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
  1002. #ifdef INT8_MMA_AVAILABLE
  1003. x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
  1004. #else
  1005. x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
  1006. #endif // INT8_MMA_AVAILABLE
  1007. }
  1008. }
  1009. #pragma unroll
  1010. for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
  1011. int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8);
  1012. if (need_check) {
  1013. i = min(i, i_max);
  1014. }
  1015. const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
  1016. const int ksc = threadIdx.x % (WARP_SIZE/8);
  1017. const int ksc_low = ksc % (QI3_K/8);
  1018. const int shift_low = 4 * (ksc / (QI3_K/8));
  1019. const int sc_low = (get_int_b2(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
  1020. const int ksc_high = QI3_K/8;
  1021. const int shift_high = 2 * ksc;
  1022. const int sc_high = ((get_int_b2(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
  1023. const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
  1024. #ifdef INT8_MMA_AVAILABLE
  1025. const int8_t * sc8 = (const int8_t *) &sc;
  1026. const float d = bxi->d;
  1027. #pragma unroll
  1028. for (int l = 0; l < sizeof(int); ++l) {
  1029. x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l];
  1030. }
  1031. #else
  1032. x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc;
  1033. #endif // INT8_MMA_AVAILABLE
  1034. }
  1035. #ifndef INT8_MMA_AVAILABLE
  1036. #pragma unroll
  1037. for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
  1038. int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
  1039. if (need_check) {
  1040. i = min(i, i_max);
  1041. }
  1042. const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
  1043. x_df[i] = bxi->d;
  1044. }
  1045. #endif // INT8_MMA_AVAILABLE
  1046. }
  1047. template <int mmq_x, int mmq_y, int nwarps>
  1048. static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
  1049. const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
  1050. constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
  1051. const int * x_qs = (const int *) x;
  1052. const float * x_df = (const float *) x_qs + txs.qs;
  1053. const int * x_sc = (const int *) x_df + txs.dm;
  1054. const int * y_qs = (const int *) y + 4;
  1055. const float * y_df = (const float *) y;
  1056. // #pragma unroll
  1057. for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
  1058. const int k0 = k00 + k01;
  1059. #pragma unroll
  1060. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  1061. const int j = j0 + threadIdx.y;
  1062. #pragma unroll
  1063. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  1064. const int i = i0 + threadIdx.x;
  1065. const int8_t * scales = ((const int8_t *) (x_sc + i*(WARP_SIZE/8) + i/8)) + k0/4;
  1066. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
  1067. &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
  1068. x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
  1069. }
  1070. }
  1071. }
  1072. }
  1073. static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, const int ksc) {
  1074. // scale arrangement after the following two lines:
  1075. // - ksc == 0: sc0, sc1, sc2, sc3
  1076. // - ksc == 1: sc4, sc5, sc6, sc7
  1077. // - ksc == 2: m0, m1, m2, m3
  1078. // - ksc == 3: m4, m5, m6, m7
  1079. return ((scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F) | // lower 4 bits
  1080. ((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits
  1081. }
  1082. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
  1083. const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
  1084. #ifdef INT8_MMA_AVAILABLE
  1085. int * x_qs = (int *) x_tile;
  1086. half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
  1087. #else
  1088. constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
  1089. int * x_qs = (int *) x_tile;
  1090. half2 * x_dm = (half2 *) (x_qs + txs.qs);
  1091. int * x_sc = (int *) (x_dm + txs.dm);
  1092. #endif // INT8_MMA_AVAILABLE
  1093. #pragma unroll
  1094. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  1095. int i = i0 + threadIdx.y;
  1096. if (need_check) {
  1097. i = min(i, i_max);
  1098. }
  1099. const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
  1100. const int qs0 = get_int_b4(bxi->qs, threadIdx.x);
  1101. #ifdef INT8_MMA_AVAILABLE
  1102. x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
  1103. x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
  1104. #else
  1105. x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
  1106. #endif // INT8_MMA_AVAILABLE
  1107. }
  1108. #ifdef INT8_MMA_AVAILABLE
  1109. #pragma unroll
  1110. for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
  1111. int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y;
  1112. if (need_check) {
  1113. i = min(i, i_max);
  1114. }
  1115. const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
  1116. const int * scales = (const int *) bxi->scales;
  1117. const int ksc = threadIdx.x % (WARP_SIZE/16);
  1118. const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
  1119. const int m32 = unpack_scales_q45_K(scales, ksc + 2);
  1120. const uint8_t * sc8 = (const uint8_t *) &sc32;
  1121. const uint8_t * m8 = (const uint8_t *) &m32;
  1122. const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
  1123. #pragma unroll
  1124. for (int l = 0; l < sizeof(int); ++l) {
  1125. x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
  1126. }
  1127. }
  1128. #else
  1129. #pragma unroll
  1130. for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI4_K) {
  1131. int i = (i0 + threadIdx.y*QI4_K + threadIdx.x) % mmq_y;
  1132. if (need_check) {
  1133. i = min(i, i_max);
  1134. }
  1135. const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
  1136. x_dm[i] = bxi->dm;
  1137. }
  1138. #pragma unroll
  1139. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
  1140. int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
  1141. if (need_check) {
  1142. i = min(i, i_max);
  1143. }
  1144. const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8);
  1145. const int * scales = (const int *) bxi->scales;
  1146. const int ksc = threadIdx.x % (WARP_SIZE/8);
  1147. const int scales8 = unpack_scales_q45_K(scales, ksc);
  1148. x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
  1149. }
  1150. #endif // INT8_MMA_AVAILABLE
  1151. }
  1152. template <int mmq_x, int mmq_y, int nwarps>
  1153. static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
  1154. const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
  1155. constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
  1156. const int * x_qs = (const int *) x;
  1157. const half2 * x_dm = (const half2 *) x_qs + txs.qs;
  1158. const int * x_sc = (const int *) x_dm + txs.dm;
  1159. const int * y_qs = (const int *) y + 4;
  1160. const half2 * y_ds = (const half2 *) y;
  1161. // #pragma unroll
  1162. for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
  1163. const int k0 = k00 + k01;
  1164. #pragma unroll
  1165. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  1166. const int j = j0 + threadIdx.y;
  1167. #pragma unroll
  1168. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  1169. const int i = i0 + threadIdx.x;
  1170. const uint8_t * sc = (const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/32] + 2*(k01/16);
  1171. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
  1172. &x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
  1173. x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
  1174. }
  1175. }
  1176. }
  1177. }
  1178. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
  1179. const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
  1180. #ifdef INT8_MMA_AVAILABLE
  1181. int * x_qs = (int *) x_tile;
  1182. half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
  1183. #else
  1184. constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
  1185. int * x_qs = (int *) x_tile;
  1186. half2 * x_dm = (half2 *) (x_qs + txs.qs);
  1187. int * x_sc = (int *) (x_dm + txs.dm);
  1188. #endif // INT8_MMA_AVAILABLE
  1189. #pragma unroll
  1190. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  1191. int i = i0 + threadIdx.y;
  1192. if (need_check) {
  1193. i = min(i, i_max);
  1194. }
  1195. const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
  1196. const int ky = QR5_K*threadIdx.x;
  1197. const int ql = get_int_b4(bxi->qs, threadIdx.x);
  1198. const int ql0 = (ql >> 0) & 0x0F0F0F0F;
  1199. const int ql1 = (ql >> 4) & 0x0F0F0F0F;
  1200. const int qh = get_int_b4(bxi->qh, threadIdx.x % (QI5_K/4));
  1201. const int qh0 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 0)) << 4) & 0x10101010;
  1202. const int qh1 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 1)) << 4) & 0x10101010;
  1203. const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
  1204. const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4;
  1205. #ifdef INT8_MMA_AVAILABLE
  1206. x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
  1207. x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
  1208. #else
  1209. x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
  1210. x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
  1211. #endif // INT8_MMA_AVAILABLE
  1212. }
  1213. #ifdef INT8_MMA_AVAILABLE
  1214. #pragma unroll
  1215. for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
  1216. int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y;
  1217. if (need_check) {
  1218. i = min(i, i_max);
  1219. }
  1220. const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
  1221. const int * scales = (const int *) bxi->scales;
  1222. const int ksc = threadIdx.x % (WARP_SIZE/16);
  1223. const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
  1224. const int m32 = unpack_scales_q45_K(scales, ksc + 2);
  1225. const uint8_t * sc8 = (const uint8_t *) &sc32;
  1226. const uint8_t * m8 = (const uint8_t *) &m32;
  1227. const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
  1228. #pragma unroll
  1229. for (int l = 0; l < sizeof(int); ++l) {
  1230. x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
  1231. }
  1232. }
  1233. #else
  1234. #pragma unroll
  1235. for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI5_K) {
  1236. int i = (i0 + threadIdx.y*QI5_K + threadIdx.x) % mmq_y;
  1237. if (need_check) {
  1238. i = min(i, i_max);
  1239. }
  1240. const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
  1241. x_dm[i] = bxi->dm;
  1242. }
  1243. #pragma unroll
  1244. for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
  1245. int i = (i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8)) % mmq_y;
  1246. if (need_check) {
  1247. i = min(i, i_max);
  1248. }
  1249. const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
  1250. const int * scales = (const int *) bxi->scales;
  1251. const int ksc = threadIdx.x % (WARP_SIZE/8);
  1252. const int scales8 = unpack_scales_q45_K(scales, ksc);
  1253. x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
  1254. }
  1255. #endif // INT8_MMA_AVAILABLE
  1256. }
  1257. template <int mmq_x, int mmq_y, int nwarps>
  1258. static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
  1259. const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
  1260. constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
  1261. const int * x_qs = (const int *) x;
  1262. const half2 * x_dm = (const half2 *) x_qs + txs.qs;
  1263. const int * x_sc = (const int *) x_dm + txs.dm;
  1264. const int * y_qs = (const int *) y + 4;
  1265. const half2 * y_ds = (const half2 *) y;
  1266. // #pragma unroll
  1267. for (int k01 = 0; k01 < WARP_SIZE; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
  1268. const int k0 = k00 + k01;
  1269. #pragma unroll
  1270. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  1271. const int j = j0 + threadIdx.y;
  1272. #pragma unroll
  1273. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  1274. const int i = i0 + threadIdx.x;
  1275. const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16);
  1276. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
  1277. &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
  1278. x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
  1279. }
  1280. }
  1281. }
  1282. }
  1283. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
  1284. const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
  1285. #ifdef INT8_MMA_AVAILABLE
  1286. int * x_qs = (int *) x_tile;
  1287. float * x_df = (float *) (x_qs + WARP_SIZE*2);
  1288. int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K);
  1289. #else
  1290. constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
  1291. int * x_qs = (int *) x_tile;
  1292. float * x_df = (float *) (x_qs + txs.qs);
  1293. int * x_sc = (int *) (x_df + txs.dm);
  1294. #endif // INT8_MMA_AVAILABLE
  1295. #pragma unroll
  1296. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  1297. int i = i0 + threadIdx.y;
  1298. if (need_check) {
  1299. i = min(i, i_max);
  1300. }
  1301. const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
  1302. const int ql = get_int_b2(bxi->ql, threadIdx.x);
  1303. const int ql0 = (ql >> 0) & 0x0F0F0F0F;
  1304. const int ql1 = (ql >> 4) & 0x0F0F0F0F;
  1305. const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (threadIdx.x / (QI6_K/2)) + threadIdx.x % (QI6_K/4));
  1306. const int qh0 = ((qh >> ((threadIdx.x & 0x08) >> 2)) << 4) & 0x30303030;
  1307. const int qh1 = (qh >> ((threadIdx.x & 0x08) >> 2)) & 0x30303030;
  1308. const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0;
  1309. const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2;
  1310. #ifdef INT8_MMA_AVAILABLE
  1311. x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
  1312. x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
  1313. #else
  1314. x_qs[i*(2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
  1315. x_qs[i*(2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
  1316. #endif // INT8_MMA_AVAILABLE
  1317. }
  1318. const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
  1319. const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
  1320. #pragma unroll
  1321. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
  1322. int i = (i0 + threadIdx.y * QI6_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
  1323. if (need_check) {
  1324. i = min(i, i_max);
  1325. }
  1326. const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;
  1327. #ifdef INT8_MMA_AVAILABLE
  1328. x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d;
  1329. #else
  1330. x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d;
  1331. #endif // INT8_MMA_AVAILABLE
  1332. }
  1333. #pragma unroll
  1334. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
  1335. int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
  1336. if (need_check) {
  1337. i = min(i, i_max);
  1338. }
  1339. const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
  1340. #ifdef INT8_MMA_AVAILABLE
  1341. x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
  1342. #else
  1343. x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
  1344. #endif // INT8_MMA_AVAILABLE
  1345. }
  1346. }
  1347. template <int mmq_x, int mmq_y, int nwarps>
  1348. static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
  1349. const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
  1350. constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
  1351. const int * x_qs = (const int *) x;
  1352. const float * x_df = (const float *) x_qs + txs.qs;
  1353. const int * x_sc = (const int *) x_df + txs.dm;
  1354. const int * y_qs = (const int *) y + 4;
  1355. const float * y_df = (const float *) y;
  1356. // #pragma unroll
  1357. for (int k01 = 0; k01 < WARP_SIZE; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
  1358. const int k0 = k00 + k01;
  1359. #pragma unroll
  1360. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  1361. const int j = j0 + threadIdx.y;
  1362. #pragma unroll
  1363. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  1364. const int i = i0 + threadIdx.x;
  1365. const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]);
  1366. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
  1367. &x_qs[i*(QR6_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,
  1368. x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
  1369. }
  1370. }
  1371. }
  1372. }
  1373. template <int mmq_x, int mmq_y, int nwarps>
  1374. static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
  1375. const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
  1376. #ifdef INT8_MMA_AVAILABLE
  1377. typedef mma_int_A_I16K4 mma_A;
  1378. typedef mma_int_B_J8K4 mma_B;
  1379. typedef mma_int_C_I16J8 mma_C;
  1380. constexpr int granularity = mmq_get_granularity_device(mmq_x);
  1381. constexpr int rows_per_warp = 2 * granularity;
  1382. constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
  1383. y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
  1384. const int * x_qs = (const int *) x;
  1385. const float * x_df = (const float *) x_qs + WARP_SIZE*2;
  1386. const int * x_sc = (const int *) x_df + WARP_SIZE/QI6_K;
  1387. const int * y_qs = (const int *) y + 4;
  1388. const float * y_df = (const float *) y;
  1389. const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
  1390. mma_A A[ntx][8];
  1391. int scA[ntx][mma_C::ne/2][8];
  1392. float dA[ntx][mma_C::ne/2];
  1393. #pragma unroll
  1394. for (int n = 0; n < ntx; ++n) {
  1395. #pragma unroll
  1396. for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
  1397. const int k0 = k00 + k01;
  1398. A[n][k01/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
  1399. A[n][k01/4 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
  1400. }
  1401. #pragma unroll
  1402. for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
  1403. const int k0 = k00 + k01;
  1404. #pragma unroll
  1405. for (int l = 0; l < mma_C::ne/2; ++l) {
  1406. const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
  1407. const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
  1408. const int8_t * sc = (const int8_t *) &sc_packed;
  1409. #pragma unroll
  1410. for (int ksc = 0; ksc < sizeof(int); ++ksc) {
  1411. scA[n][l][k01/4 + ksc] = sc[ksc];
  1412. }
  1413. }
  1414. }
  1415. #pragma unroll
  1416. for (int l = 0; l < mma_C::ne/2; ++l) {
  1417. const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
  1418. dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
  1419. }
  1420. }
  1421. #pragma unroll
  1422. for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
  1423. float tmp[ntx][mma_C::ne] = {{0.0f}};
  1424. #pragma unroll
  1425. for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
  1426. mma_B B[2];
  1427. float dB[mma_C::ne/2];
  1428. B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
  1429. B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
  1430. #pragma unroll
  1431. for (int l = 0; l < mma_C::ne/2; ++l) {
  1432. const int j = j0 + mma_C::get_j(l);
  1433. dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
  1434. }
  1435. #pragma unroll
  1436. for (int n = 0; n < ntx; ++n) {
  1437. mma_C C[2];
  1438. C[0].mma_K4(A[n][k01/4 + 0], B[0]);
  1439. C[1].mma_K4(A[n][k01/4 + 1], B[1]);
  1440. #pragma unroll
  1441. for (int l = 0; l < mma_C::ne; ++l) {
  1442. tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
  1443. }
  1444. }
  1445. }
  1446. #pragma unroll
  1447. for (int n = 0; n < ntx; ++n) {
  1448. #pragma unroll
  1449. for (int l = 0; l < mma_C::ne; ++l) {
  1450. sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2];
  1451. }
  1452. }
  1453. }
  1454. #else
  1455. GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
  1456. NO_DEVICE_CODE;
  1457. #endif // INT8_MMA_AVAILABLE
  1458. }
  1459. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
  1460. const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
  1461. #ifdef INT8_MMA_AVAILABLE
  1462. int * x_qs = (int *) x_tile;
  1463. float * x_df = (float *) (x_qs + WARP_SIZE*2);
  1464. #else
  1465. constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
  1466. int * x_qs = (int *) x_tile;
  1467. float * x_df = (float *) (x_qs + txs.qs);
  1468. #endif // INT8_MMA_AVAILABLE
  1469. const int kbx = threadIdx.x / QI4_NL;
  1470. const int kqsx = threadIdx.x % QI4_NL;
  1471. #pragma unroll
  1472. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  1473. int i = i0 + threadIdx.y;
  1474. if (need_check) {
  1475. i = min(i, i_max);
  1476. }
  1477. const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
  1478. const int aux_q4 = get_int_b2(bxi->qs, kqsx);
  1479. const int2 v = get_int_from_table_16(aux_q4);
  1480. const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
  1481. #ifdef INT8_MMA_AVAILABLE
  1482. x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
  1483. x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
  1484. #else
  1485. x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
  1486. x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
  1487. #endif // INT8_MMA_AVAILABLE
  1488. }
  1489. const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL;
  1490. const int kbxd = threadIdx.x % blocks_per_tile_x_row;
  1491. #pragma unroll
  1492. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) {
  1493. int i = i0 + threadIdx.y * QI4_NL + threadIdx.x / blocks_per_tile_x_row;
  1494. if (need_check) {
  1495. i = min(i, i_max);
  1496. }
  1497. const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
  1498. #ifdef INT8_MMA_AVAILABLE
  1499. x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
  1500. #else
  1501. x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d);
  1502. #endif // INT8_MMA_AVAILABLE
  1503. }
  1504. }
  1505. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
  1506. const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
  1507. #ifdef INT8_MMA_AVAILABLE
  1508. int * x_qs = (int *) x_tile;
  1509. float * x_df = (float *) (x_qs + WARP_SIZE*2);
  1510. #else
  1511. constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
  1512. int * x_qs = (int *) x_tile;
  1513. float * x_df = (float *) (x_qs + txs.qs);
  1514. #endif // INT8_MMA_AVAILABLE
  1515. const int kqsx = threadIdx.x % (QI2_XXS/2);
  1516. #pragma unroll
  1517. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XXS/2)) {
  1518. int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XXS) + threadIdx.x/(QI2_XXS/2);
  1519. if (need_check) {
  1520. i = min(i, i_max);
  1521. }
  1522. const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + i*stride;
  1523. const int q2 = get_int_b2(bxi->qs, 2*kqsx+0);
  1524. const uint8_t * aux8 = (const uint8_t *) &q2;
  1525. const uint32_t aux32 = get_int_b2(bxi->qs, 2*kqsx+1);
  1526. #pragma unroll
  1527. for (int l = 0; l < QR2_XXS; ++l) {
  1528. const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]);
  1529. const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F];
  1530. const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
  1531. const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
  1532. const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
  1533. const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
  1534. #ifdef INT8_MMA_AVAILABLE
  1535. x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
  1536. x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
  1537. #else
  1538. x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid0;
  1539. x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid1;
  1540. #endif // INT8_MMA_AVAILABLE
  1541. }
  1542. const int ls = aux32 >> 28;
  1543. const float d = bxi->d;
  1544. #ifdef INT8_MMA_AVAILABLE
  1545. x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
  1546. #else
  1547. x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/4;
  1548. #endif // INT8_MMA_AVAILABLE
  1549. }
  1550. }
  1551. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
  1552. const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
  1553. #ifdef INT8_MMA_AVAILABLE
  1554. int * x_qs = (int *) x_tile;
  1555. float * x_df = (float *) (x_qs + WARP_SIZE*2);
  1556. #else
  1557. constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
  1558. int * x_qs = (int *) x_tile;
  1559. float * x_df = (float *) (x_qs + txs.qs);
  1560. #endif // INT8_MMA_AVAILABLE
  1561. const int kqsx = threadIdx.x % (QI2_XS/2);
  1562. #pragma unroll
  1563. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XS/2)) {
  1564. int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XS) + threadIdx.x/(QI2_XS/2);
  1565. if (need_check) {
  1566. i = min(i, i_max);
  1567. }
  1568. const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + i*stride;
  1569. const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
  1570. const uint16_t * q2 = (const uint16_t *) &q2_packed;
  1571. #pragma unroll
  1572. for (int l = 0; l < QR2_XS; ++l) {
  1573. const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF));
  1574. const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
  1575. const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
  1576. const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
  1577. #ifdef INT8_MMA_AVAILABLE
  1578. x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
  1579. x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
  1580. #else
  1581. x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
  1582. x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
  1583. #endif // INT8_MMA_AVAILABLE
  1584. }
  1585. const int ls = bxi->scales[kqsx];
  1586. const float d = bxi->d;
  1587. #ifdef INT8_MMA_AVAILABLE
  1588. x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
  1589. x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
  1590. #else
  1591. x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
  1592. x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
  1593. #endif // INT8_MMA_AVAILABLE
  1594. }
  1595. }
  1596. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
  1597. const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
  1598. #ifdef INT8_MMA_AVAILABLE
  1599. int * x_qs = (int *) x_tile;
  1600. float * x_df = (float *) (x_qs + WARP_SIZE*2);
  1601. #else
  1602. constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
  1603. int * x_qs = (int *) x_tile;
  1604. float * x_df = (float *) (x_qs + txs.qs);
  1605. #endif // INT8_MMA_AVAILABLE
  1606. const int kqsx = threadIdx.x % (QI2_S/2);
  1607. #pragma unroll
  1608. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_S/2)) {
  1609. int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_S) + threadIdx.x/(QI2_S/2);
  1610. if (need_check) {
  1611. i = min(i, i_max);
  1612. }
  1613. const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + i*stride;
  1614. const int qs_packed = get_int_b2(bxi->qs, kqsx);
  1615. const uint8_t * qs = (const uint8_t *) &qs_packed;
  1616. const int qh = bxi->qh[kqsx];
  1617. const int signs_packed_32 = get_int_b2(bxi->qs, QK_K/32 + kqsx);
  1618. const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32;
  1619. #pragma unroll
  1620. for (int l = 0; l < QR2_S; ++l) {
  1621. const int * grid_pos = (const int *)(iq2s_grid + (qs[l] | ((qh << (8-2*l)) & 0x300)));
  1622. const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000);
  1623. const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000);
  1624. const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
  1625. const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
  1626. #ifdef INT8_MMA_AVAILABLE
  1627. x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
  1628. x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
  1629. #else
  1630. x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
  1631. x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
  1632. #endif // INT8_MMA_AVAILABLE
  1633. }
  1634. const int ls = bxi->scales[kqsx];
  1635. const float d = bxi->d;
  1636. #ifdef INT8_MMA_AVAILABLE
  1637. x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
  1638. x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
  1639. #else
  1640. x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
  1641. x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
  1642. #endif // INT8_MMA_AVAILABLE
  1643. }
  1644. }
  1645. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
  1646. const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
  1647. #ifdef INT8_MMA_AVAILABLE
  1648. int * x_qs = (int *) x_tile;
  1649. float * x_df = (float *) (x_qs + WARP_SIZE*2);
  1650. #else
  1651. constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
  1652. int * x_qs = (int *) x_tile;
  1653. float * x_df = (float *) (x_qs + txs.qs);
  1654. #endif // INT8_MMA_AVAILABLE
  1655. const int kqsx = threadIdx.x % (QI3_XXS/2);
  1656. #pragma unroll
  1657. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_XXS/2)) {
  1658. int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_XXS) + threadIdx.x/(QI3_XXS/2);
  1659. if (need_check) {
  1660. i = min(i, i_max);
  1661. }
  1662. const block_iq3_xxs * bxi = (const block_iq3_xxs *) x + kbx0 + i*stride;
  1663. const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
  1664. const uint8_t * q3 = (const uint8_t *) &q3_packed;
  1665. const uint32_t aux32 = get_int_b2(bxi->qs, QK_K/16 + kqsx);
  1666. #pragma unroll
  1667. for (int l = 0; l < QR3_XXS; ++l) {
  1668. const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]);
  1669. const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F));
  1670. const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
  1671. const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
  1672. #ifdef INT8_MMA_AVAILABLE
  1673. x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
  1674. x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
  1675. #else
  1676. x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
  1677. x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
  1678. #endif // INT8_MMA_AVAILABLE
  1679. }
  1680. const int ls = aux32 >> 28;
  1681. const float d = bxi->d;
  1682. #ifdef INT8_MMA_AVAILABLE
  1683. x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
  1684. #else
  1685. x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/2;
  1686. #endif // INT8_MMA_AVAILABLE
  1687. }
  1688. }
  1689. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
  1690. const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
  1691. #ifdef INT8_MMA_AVAILABLE
  1692. int * x_qs = (int *) x_tile;
  1693. float * x_df = (float *) (x_qs + WARP_SIZE*2);
  1694. #else
  1695. constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
  1696. int * x_qs = (int *) x_tile;
  1697. float * x_df = (float *) (x_qs + txs.qs);
  1698. #endif // INT8_MMA_AVAILABLE
  1699. const int kqsx = threadIdx.x % (QI3_S/2);
  1700. #pragma unroll
  1701. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_S/2)) {
  1702. int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_S) + threadIdx.x/(QI3_S/2);
  1703. if (need_check) {
  1704. i = min(i, i_max);
  1705. }
  1706. const block_iq3_s * bxi = (const block_iq3_s *) x + kbx0 + i*stride;
  1707. const int2 qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
  1708. const uint8_t * qs = (const uint8_t *) &qs_packed;
  1709. const int qh = bxi->qh[kqsx];
  1710. const int signs_packed_32 = get_int_b2(bxi->signs, kqsx);
  1711. const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32;
  1712. #pragma unroll
  1713. for (int l = 0; l < QR3_S; ++l) {
  1714. const int2 grid_pos = make_int2(
  1715. iq3s_grid[qs[2*l+0] | ((qh << (8 - 2*l)) & 0x100)],
  1716. iq3s_grid[qs[2*l+1] | ((qh << (7 - 2*l)) & 0x100)]);
  1717. const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000);
  1718. const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000);
  1719. const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
  1720. const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
  1721. #ifdef INT8_MMA_AVAILABLE
  1722. x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
  1723. x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
  1724. #else
  1725. x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid_l;
  1726. x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid_h;
  1727. #endif // INT8_MMA_AVAILABLE
  1728. }
  1729. const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
  1730. const float d = bxi->d;
  1731. #ifdef INT8_MMA_AVAILABLE
  1732. x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
  1733. #else
  1734. x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = ls*d;
  1735. #endif // INT8_MMA_AVAILABLE
  1736. }
  1737. }
  1738. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
  1739. const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
  1740. #ifdef INT8_MMA_AVAILABLE
  1741. int * x_qs = (int *) x_tile;
  1742. half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2);
  1743. #else
  1744. constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
  1745. int * x_qs = (int *) x_tile;
  1746. half2 * x_ds = (half2 *) (x_qs + txs.qs);
  1747. #endif // INT8_MMA_AVAILABLE
  1748. const int kqsx = threadIdx.x % QI1_S;
  1749. #pragma unroll
  1750. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI1_S) {
  1751. int i = i0 + threadIdx.y*(WARP_SIZE/QI1_S) + threadIdx.x/QI1_S;
  1752. if (need_check) {
  1753. i = min(i, i_max);
  1754. }
  1755. const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride;
  1756. const int qs_packed = get_int_b2(bxi->qs, kqsx);
  1757. const uint8_t * qs = (const uint8_t *) &qs_packed;
  1758. const int qh = bxi->qh[kqsx];
  1759. #pragma unroll
  1760. for (int l = 0; l < QR1_S/2; ++l) {
  1761. const int grid = iq1s_grid_gpu[qs[l] | (((qh >> (3*l)) & 0x07) << 8)];
  1762. const int grid0 = (grid >> 0) & 0x0F0F0F0F;
  1763. const int grid1 = (grid >> 4) & 0x0F0F0F0F;
  1764. #ifdef INT8_MMA_AVAILABLE
  1765. x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
  1766. x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
  1767. #else
  1768. x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid0;
  1769. x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid1;
  1770. #endif // INT8_MMA_AVAILABLE
  1771. }
  1772. const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
  1773. const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
  1774. #ifdef INT8_MMA_AVAILABLE
  1775. x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
  1776. #else
  1777. x_ds[i*(WARP_SIZE/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
  1778. #endif // INT8_MMA_AVAILABLE
  1779. }
  1780. }
  1781. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
  1782. const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
  1783. #ifdef INT8_MMA_AVAILABLE
  1784. int * x_qs = (int *) x_tile;
  1785. float * x_df = (float *) (x_qs + WARP_SIZE*2);
  1786. #else
  1787. constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
  1788. int * x_qs = (int *) x_tile;
  1789. float * x_df = (float *) (x_qs + txs.qs);
  1790. #endif // INT8_MMA_AVAILABLE
  1791. const int kbx = 0; // threadIdx.x / QI4_XS
  1792. const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
  1793. #pragma unroll
  1794. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  1795. int i = i0 + threadIdx.y;
  1796. if (need_check) {
  1797. i = min(i, i_max);
  1798. }
  1799. const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride + kbx;
  1800. const int aux_q4 = get_int_b4(bxi->qs, kqsx);
  1801. const int2 v = get_int_from_table_16(aux_q4);
  1802. const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
  1803. #ifdef INT8_MMA_AVAILABLE
  1804. x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
  1805. x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
  1806. #else
  1807. x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
  1808. x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
  1809. #endif // INT8_MMA_AVAILABLE
  1810. }
  1811. #pragma unroll
  1812. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
  1813. int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
  1814. if (need_check) {
  1815. i = min(i, i_max);
  1816. }
  1817. const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
  1818. const float d = __half2float(bxi->d);
  1819. const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
  1820. | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
  1821. #ifdef INT8_MMA_AVAILABLE
  1822. x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
  1823. #else
  1824. x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
  1825. #endif // INT8_MMA_AVAILABLE
  1826. }
  1827. }
  // Copies the per-thread dp4a accumulators in `sum` into the output tile `dst`.
  //
  // Each thread owns rows   row = istep*WARP_SIZE + threadIdx.x  (istep*WARP_SIZE < mmq_y)
  // and columns            col = jstep*nwarps    + threadIdx.y  (jstep*nwarps    < mmq_x);
  // the accumulator for (istep, jstep) is sum[jstep*(mmq_y/WARP_SIZE) + istep].
  //
  //   sum    - per-thread accumulators, (mmq_x/nwarps)*(mmq_y/WARP_SIZE) values
  //   dst    - tile base pointer; element (row, col) goes to dst[col*stride + row]
  //   stride - distance between consecutive columns of dst
  //   i_max  - largest row index that may be written (only checked if need_check)
  //   j_max  - largest column index that may be written (always checked)
  template<int mmq_x, int mmq_y, int nwarps, bool need_check>
  static __device__ __forceinline__ void mmq_write_back_dp4a(
          const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
      constexpr int n_i_steps = mmq_y / WARP_SIZE; // accumulators per column step

  #pragma unroll
      for (int jstep = 0; jstep*nwarps < mmq_x; ++jstep) {
          const int col = jstep*nwarps + threadIdx.y;

          // col only grows with jstep, so once it is out of range every later one is too.
          if (col > j_max) {
              return;
          }

  #pragma unroll
          for (int istep = 0; istep*WARP_SIZE < mmq_y; ++istep) {
              const int  row    = istep*WARP_SIZE + threadIdx.x;
              const bool row_ok = !need_check || row <= i_max;

              if (row_ok) {
                  dst[col*stride + row] = sum[jstep*n_i_steps + istep];
              }
          }
      }
  }
  // Copies the per-thread int8 mma accumulators in `sum` into the output tile `dst`.
  //
  // Accumulators are laid out as mma_int_C_I16J8 fragments: for every column step j0
  // (stride ntx*mma_C::J) and every one of the ntx x minitiles handled by this warp,
  // mma_C::ne consecutive floats hold the fragment elements whose coordinates are given by
  // mma_C::get_i(l) / mma_C::get_j(l). The warp's first row is (threadIdx.y/ntx)*ntx*mma_C::I
  // and its column offset inside each step is (threadIdx.y%ntx)*mma_C::J.
  //
  // Element (i, j) is written to dst[j*stride + i]. Columns with j > j_max are skipped;
  // rows with i > i_max are skipped only when need_check is true.
  template<int mmq_x, int mmq_y, int nwarps, bool need_check>
  static __device__ __forceinline__ void mmq_write_back_mma(
          const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
      typedef mma_int_C_I16J8 mma_C;

      constexpr int granularity   = mmq_get_granularity_device(mmq_x);
      constexpr int rows_per_warp = 2 * granularity;
      constexpr int ntx           = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
      constexpr int j_step        = ntx*mma_C::J;

      // Row origin of this warp and its column offset within one j step.
      const int warp_i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
      const int warp_j0 = (threadIdx.y % ntx) * mma_C::J;

  #ifdef INT8_MMA_AVAILABLE
      static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
  #endif // INT8_MMA_AVAILABLE

  #pragma unroll
      for (int j0 = 0; j0 < mmq_x; j0 += j_step) {
  #pragma unroll
          for (int n = 0; n < ntx; ++n) {
  #pragma unroll
              for (int l = 0; l < mma_C::ne; ++l) {
                  const int j = j0 + warp_j0 + mma_C::get_j(l);
                  const int i = warp_i0 + n*mma_C::I + mma_C::get_i(l);

                  const bool in_bounds = j <= j_max && (!need_check || i <= i_max);
                  if (in_bounds) {
                      dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l];
                  }
              }
          }
      }
  }
  1877. // -------------------------------------------------------------------------------------------------------------------------------------
  // Compile-time dispatch table that mul_mat_q_process_tile uses to pick the kernels for a
  // ggml quantization type.
  //
  // Every specialization provides:
  //   vdr          - the type's VDR_*_Q8_1_MMQ constant,
  //   load_tiles   - loads a tile of quantized x data into the shared-memory x tile,
  //   vec_dot_mma  - tile dot product used when INT8_MMA_AVAILABLE is defined,
  //   vec_dot_dp4a - tile dot product used otherwise.
  // The primary template is declared but never defined, so requesting the traits of a type
  // without a specialization is a compile-time error.
  template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
  struct mmq_type_traits;

  // ---- Q4_0, Q4_1, Q5_0, Q5_1, Q8_0 ----

  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
      static constexpr int              vdr          = VDR_Q4_0_Q8_1_MMQ;
      static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_0<mmq_y, nwarps, need_check>;
      static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_DS4>;
      static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  };

  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
      static constexpr int              vdr          = VDR_Q4_1_Q8_1_MMQ;
      static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_1<mmq_y, nwarps, need_check>;
      static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
      static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  };

  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
      static constexpr int              vdr          = VDR_Q5_0_Q8_1_MMQ;
      static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_0<mmq_y, nwarps, need_check>;
      static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
      static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  };

  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
      static constexpr int              vdr          = VDR_Q5_1_Q8_1_MMQ;
      static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_1<mmq_y, nwarps, need_check>;
      static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
      static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  };

  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
      static constexpr int              vdr          = VDR_Q8_0_Q8_1_MMQ;
      static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q8_0<mmq_y, nwarps, need_check>;
      static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
      static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  };

  // ---- K-quants ----

  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
      static constexpr int              vdr          = VDR_Q2_K_Q8_1_MMQ;
      static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q2_K<mmq_y, nwarps, need_check>;
      static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
      static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  };

  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
      static constexpr int              vdr          = VDR_Q3_K_Q8_1_MMQ;
      static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q3_K<mmq_y, nwarps, need_check>;
      static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
      static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  };

  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
      static constexpr int              vdr          = VDR_Q4_K_Q8_1_MMQ;
      static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_K<mmq_y, nwarps, need_check>;
      static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
      static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  };

  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
      static constexpr int              vdr          = VDR_Q5_K_Q8_1_MMQ;
      static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_K<mmq_y, nwarps, need_check>;
      static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
      static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  };

  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
      static constexpr int              vdr          = VDR_Q6_K_Q8_1_MMQ;
      static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q6_K<mmq_y, nwarps, need_check>;
      static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
      static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  };

  // ---- i-quants ----

  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XXS> {
      static constexpr int              vdr          = VDR_IQ2_XXS_Q8_1_MMQ;
      static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_xxs<mmq_y, nwarps, need_check>;
      static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
      static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  };

  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XS> {
      static constexpr int              vdr          = VDR_IQ2_XS_Q8_1_MMQ;
      static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_xs<mmq_y, nwarps, need_check>;
      static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
      static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  };

  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_S> {
      static constexpr int              vdr          = VDR_IQ2_S_Q8_1_MMQ;
      static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_s<mmq_y, nwarps, need_check>;
      static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
      static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  };

  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_XXS> {
      static constexpr int              vdr          = VDR_IQ3_XXS_Q8_1_MMQ;
      static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq3_xxs<mmq_y, nwarps, need_check>;
      static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
      static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  };

  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_S> {
      static constexpr int              vdr          = VDR_IQ3_S_Q8_1_MMQ;
      static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq3_s<mmq_y, nwarps, need_check>;
      static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
      static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  };

  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_S> {
      static constexpr int              vdr          = VDR_IQ1_S_Q8_1_MMQ;
      static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq1_s<mmq_y, nwarps, need_check>;
      static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
      static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  };

  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
      static constexpr int              vdr          = VDR_IQ4_NL_Q8_1_MMQ;
      static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq4_nl<mmq_y, nwarps, need_check>;
      static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
      static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  };

  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
      static constexpr int              vdr          = VDR_IQ4_XS_Q8_1_MMQ;
      static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
      static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
      static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  };
  // Accumulates the output tile (it, jt) of dst over the k blocks [kb0_start, kb0_stop) and
  // writes the result back. `it` indexes tiles of mmq_y rows (bounded by ne01), `jt` tiles of
  // mmq_x columns (bounded by ne11).
  //
  // Dynamic shared memory (`data_mul_mat_q`) holds the y tile first and the x tile after it.
  // Each loop iteration covers MMQ_ITER_K values of k (blocks_per_iter blocks of qk): the x tile
  // is loaded once, then two consecutive y parts are loaded in turn into the same y buffer and
  // each part is consumed by one vec_dot call (k offsets 0 and WARP_SIZE).
  //
  // fixup == true : the (partial) tile is written densely (stride mmq_y) into this CUDA block's
  //                 slot tmp_fixup[blockIdx.x*mmq_x*mmq_y ...], for mul_mat_q_stream_k_fixup.
  // fixup == false: the tile is written to dst (column stride ne0) with row/column bounds checks.
  // ne00 and ne10 are not used by this function.
  template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
  static __device__ void mul_mat_q_process_tile(
          const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
          const int & ne00, const int & ne01, const int & stride01, const int & ne10, const int & ne11, const int & stride11, const int & ne0,
          const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) {
      constexpr int qk    = ggml_cuda_type_traits<type>::qk;
      constexpr int mmq_y = get_mmq_y_device();

      typedef mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type> traits;

      constexpr load_tiles_mmq_t load_tiles = traits::load_tiles;
  #ifdef INT8_MMA_AVAILABLE
      constexpr vec_dot_mmq_t    vec_dot    = traits::vec_dot_mma;
      constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
  #else
      constexpr vec_dot_mmq_t    vec_dot    = traits::vec_dot_dp4a;
      constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
  #endif // INT8_MMA_AVAILABLE

      // y tile at the start of dynamic shared memory, x tile right after its (padded) end.
      extern __shared__ char data_mul_mat_q[];
      int * tile_y = (int *) data_mul_mat_q;
      int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);

      constexpr int blocks_per_iter = MMQ_ITER_K / qk;

      float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};

      // Largest valid row / column index relative to this tile.
      const int tile_x_max_i = ne01 - it*mmq_y - 1;
      const int tile_y_max_j = ne11 - jt*mmq_x - 1;

      // y data belonging to the columns of this tile.
      const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));

      for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
          load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);

  #pragma unroll
          for (int part = 0; part < 2; ++part) {
              const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + part*sizeof(block_q8_1_mmq)/sizeof(int));

  #pragma unroll
              for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
                  const int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
                  tile_y[l] = by0[l];
              }

              __syncthreads(); // shared tiles fully written before anyone reads them

              vec_dot(tile_x, tile_y, sum, part*WARP_SIZE);

              __syncthreads(); // all reads done before tile_y / tile_x are overwritten
          }
      }

      if (fixup) {
          write_back(sum, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
      } else {
          write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j);
      }
  }
  // The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
  //
  // All work is viewed as one continuous index kbc = (jt*nty + it)*blocks_per_ne00 + kb0:
  // output tiles ordered by jt (tiles of mmq_x columns along ne11), then it (tiles of mmq_y rows
  // along ne01), each split into blocks_per_ne00 k blocks. CUDA block b handles the slice
  // [b*total/gridDim.x, (b + 1)*total/gridDim.x) with total = blocks_per_ne00*ntx*nty; both ends
  // have their offset within the tile rounded down to a multiple of blocks_per_iter.
  //   - Every tile that the slice covers through its last k block is written straight to dst
  //     (this includes a first tile that the slice enters partway through).
  //   - If the slice ends partway through a tile, that partial sum goes to this block's slot in
  //     tmp_fixup instead; mul_mat_q_stream_k_fixup later adds it onto dst.
  //
  // On AMD (HIP) or when compiled for __CUDA_ARCH__ < CC_VOLTA the kernel uses conventional tiling
  // instead: one output tile per CUDA block with it = blockIdx.x, jt = blockIdx.y; tmp_fixup unused.
  //
  // launch_mul_mat_q launches it with blockDim = (WARP_SIZE, MMQ_NWARPS) and mmq_get_shmem<type>()
  // bytes of dynamic shared memory.
  template <ggml_type type, int mmq_x, int nwarps, bool need_check>
  #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  #if defined(RDNA3) || defined(RDNA2)
      __launch_bounds__(WARP_SIZE*nwarps, 2)
  #endif // defined(RDNA3) || defined(RDNA2)
  #else
  #if __CUDA_ARCH__ >= CC_VOLTA
      __launch_bounds__(WARP_SIZE*nwarps, 1)
  #else
      __launch_bounds__(WARP_SIZE*nwarps, 2)
  #endif // __CUDA_ARCH__ >= CC_VOLTA
  #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  static __global__ void mul_mat_q(
          const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
          const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {

      // Skip unused template specializations for faster compilation:
      if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
          NO_DEVICE_CODE;
          return;
      }

      constexpr int qk = ggml_cuda_type_traits<type>::qk;
      constexpr int mmq_y = get_mmq_y_device();

      // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
  #if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
      {
          constexpr bool fixup = false;
          mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
              (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
               blockIdx.x, blockIdx.y, 0, ne00/qk);
          return;
      }
  #endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA

      const int64_t blocks_per_ne00 = ne00 / qk;
      constexpr int blocks_per_iter = MMQ_ITER_K / qk;

      const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x (along ne11, mmq_x columns each)
      const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y (along ne01, mmq_y rows each)

      // kbc == k block continuous, current index in continuous ijk space.
      int64_t kbc = (int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x;
      int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x;

      // Round the in-tile offset of both slice ends down to a multiple of blocks_per_iter.
      // mul_mat_q_stream_k_fixup repeats this exact arithmetic to find the fixup slots.
      kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
      kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;

      // kb0 == k index when doing the matrix multiplication for an output tile.
      int kb0_start = kbc % blocks_per_ne00;
      int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);

      // Tiles that this slice covers through their last k block: write directly to dst.
      while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
          const int jt = kbc / (blocks_per_ne00*nty); // j index of current tile.
          const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; // i index of current tile.

          constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
          mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
              (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
               it, jt, kb0_start, kb0_stop);

          // Advance to the start of the next tile.
          kbc += blocks_per_ne00;
          kbc -= kbc % blocks_per_ne00;

          kb0_start = 0;
          kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
      }

      if (kbc >= kbc_stop) {
          return;
      }

      // The slice ends partway through this tile: the partial sum goes to the fixup buffer.
      const int jt = kbc / (blocks_per_ne00*nty);
      const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;

      constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
      mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
          (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
           it, jt, kb0_start, kb0_stop);
  }
  // Adds the partial tile results that the stream-k path of mul_mat_q left in its fixup buffer
  // (tmp_last_tile) onto dst.
  //
  // launch_mul_mat_q launches this kernel with one CUDA block per output tile:
  //   - blockIdx.x is the tile index along ne01 (mmq_y rows per tile),
  //   - blockIdx.y is the tile index along ne11 (mmq_x columns per tile),
  //   - blockDim is (WARP_SIZE, nwarps).
  // block_num_mmq must equal the gridDim.x that mul_mat_q was launched with. The k slice of every
  // mul_mat_q block is recomputed here with the same arithmetic. Every block whose slice ends
  // partway through this tile contributes its slot tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i].
  // The contributions are summed in registers and then added (+=) onto dst. The kernel is
  // enqueued after mul_mat_q on the same stream.
  template <ggml_type type, int mmq_x, int nwarps, bool need_check>
  static __global__ void mul_mat_q_stream_k_fixup(
          float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ne00, const int ne01, const int ne11, const int ne0, const int block_num_mmq) {

      constexpr int mmq_y = get_mmq_y_device();
      constexpr int qk = ggml_cuda_type_traits<type>::qk;
      constexpr int blocks_per_iter = MMQ_ITER_K / qk;
      const int64_t blocks_per_ne00 = ne00 / qk;

      float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};

      const int ntx = (ne11 + mmq_x - 1) / mmq_x;
      const int nty = (ne01 + mmq_y - 1) / mmq_y;

      bool any_fixup = false;

      // Flat tile index t = blockIdx.y*nty + blockIdx.x (= jt*nty + it, the order used by mul_mat_q).
      // Slices are proportional to the flat index, so only mul_mat_q blocks in
      // [floor(t*N/T), ceil((t+1)*N/T)) can end inside this tile (N = block_num_mmq, T = number of tiles).
      // The checks in the loop below pick out the exact ones.
      const int bidx_start = ((blockIdx.y*nty + blockIdx.x) * block_num_mmq) / (gridDim.y*gridDim.x);
      const int bidx_stop = ((blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq + gridDim.y*gridDim.x - 1) / (gridDim.y*gridDim.x);

      int64_t kbc_0;
      int64_t kbc_stop_0 = (int64_t) bidx_start*blocks_per_ne00*ntx*nty / block_num_mmq;

      for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
          // The unrounded start of block bidx equals the unrounded stop of block bidx - 1.
          kbc_0 = kbc_stop_0;
          kbc_stop_0 = (int64_t) (bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq;

          // Same rounding as in mul_mat_q.
          const int64_t kbc = kbc_0 - (kbc_0 % blocks_per_ne00) % blocks_per_iter;
          const int64_t kbc_stop = kbc_stop_0 - (kbc_stop_0 % blocks_per_ne00) % blocks_per_iter;

          // Skip fixup tile if the MMQ CUDA block never wrote anything to it:
          if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {
              continue;
          }

          const int jt = kbc_stop / (blocks_per_ne00*nty);
          const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;

          // Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block:
          if (it != blockIdx.x || jt != blockIdx.y) {
              continue;
          }

          any_fixup = true;

          // Fixup slots are dense mmq_y x mmq_x tiles: element (i, j) at j*mmq_y + i.
  #pragma unroll
          for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
              const int j = j0 + threadIdx.y;

  #pragma unroll
              for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
                  const int i = i0 + threadIdx.x;

                  sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
              }
          }
      }

      if (!any_fixup) {
          return;
      }

      dst += blockIdx.y*mmq_x*ne0 + blockIdx.x*mmq_y;

      const int i_max = ne01 - blockIdx.x*mmq_y - 1;
      const int j_max = ne11 - blockIdx.y*mmq_x - 1;

      // Same bounds handling as mmq_write_back_dp4a, but accumulating (+=) into dst.
  #pragma unroll
      for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
          const int j = j0 + threadIdx.y;

          if (j > j_max) {
              return;
          }

  #pragma unroll
          for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
              const int i = i0 + threadIdx.x;

              if (need_check && i > i_max) {
                  continue;
              }

              dst[j*ne0 + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
          }
      }
  }
  // Problem description passed from mul_mat_q_case to launch_mul_mat_q and on to the kernels.
  // Keep the field order: callers may aggregate-initialize this struct positionally.
  struct mmq_args {
      const char * x;        // quantized x data (indexed in whole quantized blocks by load_tiles)
      const char * y;        // y data, passed to mul_mat_q as `yc` and read in block_q8_1_mmq units
      float      * dst;      // output matrix
      int64_t      ne00;     // length of the shared k dimension in values (kernels use ne00/qk blocks)
      int64_t      ne01;     // number of x rows == number of dst rows
      int64_t      stride01; // distance between x rows, in quantized blocks
      int64_t      ne10;     // forwarded to the kernels; not used by the code visible here
      int64_t      ne11;     // number of y columns == number of dst columns
      int64_t      stride11; // y stride forwarded to mul_mat_q (scales the y block offset)
      int64_t      ne0;      // distance between dst columns (dst[j*ne0 + i])
  };
  // Bytes of dynamic shared memory mul_mat_q needs for `type` with an mmq_y x mmq_x output tile
  // on a device with compute capability cc.
  //   x part: mma tile layout if int8 mma is available, otherwise the dp4a tile sizes
  //           (qs and sc counted as ints, dm as half2).
  //   y part: mmq_x block_q8_1_mmq, padded to a multiple of MMQ_NWARPS*WARP_SIZE ints
  //           (presumably mirrors the tile_y/tile_x split in mul_mat_q_process_tile — keep in sync).
  template<ggml_type type>
  static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
      const tile_x_sizes txs          = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
      const int          mma_tile_x_k = mmq_get_mma_tile_x_k(type);

      int x_bytes;
      if (int8_mma_available(cc)) {
          x_bytes = mmq_y*mma_tile_x_k*sizeof(int);
      } else {
          x_bytes = txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
      }

      const int y_bytes = mmq_x*sizeof(block_q8_1_mmq);

      return x_bytes + GGML_PAD(y_bytes, MMQ_NWARPS*WARP_SIZE*sizeof(int));
  }
  // Enqueues mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check> for `args` on `stream`.
  //
  // need_check is true exactly when args.ne01 is not a multiple of mmq_y.
  // Stream-k (cc >= CC_VOLTA and not AMD):
  //   - launches nsm CUDA blocks (the device's nsm value, presumably its SM count);
  //   - partial tiles go to a pool-allocated buffer of nsm*mmq_x*mmq_y floats;
  //   - mul_mat_q_stream_k_fixup, launched right after on the same stream with one block per
  //     output tile, adds them onto dst.
  // Otherwise: conventional tiling with one CUDA block per output tile and no fixup buffer.
  template <ggml_type type, int mmq_x>
  static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
      const int id    = ggml_cuda_get_device();
      const int cc    = ggml_cuda_info().devices[id].cc;
      const int nsm   = ggml_cuda_info().devices[id].nsm;
      const int mmq_y = get_mmq_y_host(cc);

      const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);

      const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);

  #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
      // Raise the dynamic shared memory limit of both kernel variants to `shmem`,
      // once per device for this (type, mmq_x) instantiation.
      static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
      if (!shmem_limit_raised[id]) {
          CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
          CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>,  cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
          shmem_limit_raised[id] = true;
      }
  #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))

      const int  nty = (args.ne01 + mmq_y - 1) / mmq_y; // output tiles along ne01
      const int  ntx = (args.ne11 + mmq_x - 1) / mmq_x; // output tiles along ne11
      const dim3 block_nums_xy_tiling(nty, ntx, 1);

      const bool use_stream_k     = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
      const bool rows_tile_evenly = args.ne01 % mmq_y == 0;

      if (use_stream_k) {
          const dim3 block_nums_mmq(nsm, 1, 1);

          ggml_cuda_pool & pool = ctx.pool(id);
          ggml_cuda_pool_alloc<float> tmp_fixup(pool, block_nums_mmq.x * mmq_x*mmq_y);

          if (rows_tile_evenly) {
              mul_mat_q<type, mmq_x, MMQ_NWARPS, false><<<block_nums_mmq, block_dims, shmem, stream>>>
                  (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
              mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, false><<<block_nums_xy_tiling, block_dims, 0, stream>>>
                  (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
          } else {
              mul_mat_q<type, mmq_x, MMQ_NWARPS, true><<<block_nums_mmq, block_dims, shmem, stream>>>
                  (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
              mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, true><<<block_nums_xy_tiling, block_dims, 0, stream>>>
                  (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
          }
          return;
      }

      if (rows_tile_evenly) {
          mul_mat_q<type, mmq_x, MMQ_NWARPS, false><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
              (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
      } else {
          mul_mat_q<type, mmq_x, MMQ_NWARPS, true><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
              (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
      }
  }
  // Chooses the tile width mmq_x for `type` and dispatches to launch_mul_mat_q<type, mmq_x>.
  //
  // Candidates are mmq_x = 8, 16, ..., get_mmq_x_max_host(cc). A candidate is skipped when it is
  // not a multiple of the device granularity or when its shared-memory requirement exceeds smpbo.
  // The candidate with the fewest "parts" wins; on ties the smallest mmq_x is kept.
  //   - stream-k (cc >= CC_VOLTA and not AMD): parts = number of tiles along ne11
  //   - conventional tiling:                   parts = total number of output tiles
  // The search stops early once a single part is reached. If no candidate fits, the function
  // prints mmq_x_best and aborts.
  template <ggml_type type>
  void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
      const int id    = ggml_cuda_get_device();
      const int cc    = ggml_cuda_info().devices[id].cc;
      const int smpbo = ggml_cuda_info().devices[id].smpbo;

      const int  mmq_x_max    = get_mmq_x_max_host(cc);
      const int  mmq_y        = get_mmq_y_host(cc);
      const int  block_num_y  = (args.ne01 + mmq_y - 1) / mmq_y;
      const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;

      int mmq_x_best  = 0;
      int nparts_best = INT_MAX;

      for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
          const int granularity = mmq_get_granularity_host(mmq_x, cc);

          if (mmq_x % granularity != 0 || mmq_get_shmem<type>(mmq_x, mmq_y, cc) > smpbo) {
              continue;
          }

          const int ntiles_x  = (args.ne11 + mmq_x - 1) / mmq_x;
          const int ntiles_xy = ntiles_x*block_num_y; // total output tiles (not waves)
          const int nparts    = use_stream_k ? ntiles_x : ntiles_xy;

          if (nparts < nparts_best) {
              mmq_x_best  = mmq_x;
              nparts_best = nparts;
          }
      }

      // mmq_x is a template parameter, so every candidate needs its own instantiation.
      switch (mmq_x_best) {
          case   8:
              launch_mul_mat_q<type,   8>(ctx, args, stream);
              break;
          case  16:
              launch_mul_mat_q<type,  16>(ctx, args, stream);
              break;
          case  24:
              launch_mul_mat_q<type,  24>(ctx, args, stream);
              break;
          case  32:
              launch_mul_mat_q<type,  32>(ctx, args, stream);
              break;
          case  40:
              launch_mul_mat_q<type,  40>(ctx, args, stream);
              break;
          case  48:
              launch_mul_mat_q<type,  48>(ctx, args, stream);
              break;
          case  56:
              launch_mul_mat_q<type,  56>(ctx, args, stream);
              break;
          case  64:
              launch_mul_mat_q<type,  64>(ctx, args, stream);
              break;
          case  72:
              launch_mul_mat_q<type,  72>(ctx, args, stream);
              break;
          case  80:
              launch_mul_mat_q<type,  80>(ctx, args, stream);
              break;
          case  88:
              launch_mul_mat_q<type,  88>(ctx, args, stream);
              break;
          case  96:
              launch_mul_mat_q<type,  96>(ctx, args, stream);
              break;
          case 104:
              launch_mul_mat_q<type, 104>(ctx, args, stream);
              break;
          case 112:
              launch_mul_mat_q<type, 112>(ctx, args, stream);
              break;
          case 120:
              launch_mul_mat_q<type, 120>(ctx, args, stream);
              break;
          case 128:
              launch_mul_mat_q<type, 128>(ctx, args, stream);
              break;
          default:
              // Reached when no candidate passed the granularity / shared-memory checks.
              fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);
              GGML_ABORT("fatal error");
              break;
      }
  }
  // Expands to the explicit instantiation of mul_mat_q_case<type>.
  // With the leading `extern` used below, each line is an explicit instantiation *declaration*,
  // so translation units including this header do not instantiate the function themselves.
  // NOTE(review): the matching explicit instantiation definitions are presumably provided by
  // per-type translation units elsewhere in the build — not visible here, confirm.
  //
  // The replacement list must NOT end with a line-continuation backslash. Backslash-newline is
  // spliced before preprocessing, so the next line (here the Q4_0 declaration) would be swallowed
  // into the macro body.
  #define DECL_MMQ_CASE(type) \
      template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream)

  extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
  extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
  extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
  extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
  extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
  extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
  extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
  extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
  extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
  extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
  extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
  extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
  extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
  extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
  extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
  extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
  extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
  extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
  2353. // -------------------------------------------------------------------------------------------------------------------------
  // MMQ matrix-multiplication op (declaration only; defined elsewhere).
  // NOTE(review): judging by the parameter names it presumably multiplies src0 rows
  // [row_low, row_high) (raw data src0_dd_i) by src1_ncols columns of pre-quantized src1
  // (src1_ddq_i) into dst_dd_i on `stream` — confirm against the definition.
  void ggml_cuda_op_mul_mat_q(
      ggml_backend_cuda_context & ctx,
      const ggml_tensor * src0,
      const ggml_tensor * src1,
      ggml_tensor * dst,
      const char * src0_dd_i,
      const float * src1_ddf_i,
      const char * src1_ddq_i,
      float * dst_dd_i,
      const int64_t row_low,
      const int64_t row_high,
      const int64_t src1_ncols,
      const int64_t src1_padded_row_size,
      cudaStream_t stream);

  // Whether MMQ should be used for quantization `type` on a device of compute capability `cc`
  // (presumably with ne11 = number of src1 columns). Declaration only; the policy lives in the definition.
  bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11);