#include "common.cuh"

static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) {
    const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment

    int x32 = 0;
    x32 |= x16[0] << 0;
    x32 |= x16[1] << 16;

    return x32;
}

static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) {
    const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment

    int x32 = 0;
    x32 |= x16[0] << 0;
    x32 |= x16[1] << 16;

    return x32;
}

static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) {
    return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
}

static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) {
    return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
}

// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q

#define VDR_Q4_0_Q8_1_MMVQ 2
#define VDR_Q4_0_Q8_1_MMQ 4
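
// Each int in v packs 8 q4_0 nibbles (4 low + 4 high); u holds the matching q8_1 bytes.
// The integer products are accumulated with __dp4a and scaled by d4 (q4_0 scale) and
// ds8.x (q8_1 scale). q4_0 values carry an implicit -8 offset, so 8*ds8.y (ds8.y being the
// q8_1 block sum) must be subtracted once per block; the (8*vdr/QI4_0) factor splits that
// correction evenly across the QI4_0/vdr threads working on the same block.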
template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl(
    const int * v, const int * u, const float & d4, const half2 & ds8) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;

        // SIMD dot product of quantized values
        sumi = __dp4a(vi0, u[2*i+0], sumi);
        sumi = __dp4a(vi1, u[2*i+1], sumi);
    }

    const float2 ds8f = __half22float2(ds8);

    // second part effectively subtracts 8 from each quant value
    return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y);
#else
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}

#define VDR_Q4_1_Q8_1_MMVQ 2
#define VDR_Q4_1_Q8_1_MMQ 4
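
// q4_1 blocks store a scale and a min (dm4 = {d4, m4}); the result is sumi*d4*d8 plus
// m4 times the q8_1 block sum ds8.y. The min term is divided by QI8_1/(vdr*QR4_1)
// because that many threads add it for the same block.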
template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl(
    const int * v, const int * u, const half2 & dm4, const half2 & ds8) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;

        // SIMD dot product of quantized values
        sumi = __dp4a(vi0, u[2*i+0], sumi);
        sumi = __dp4a(vi1, u[2*i+1], sumi);
    }

#ifdef GGML_CUDA_F16
    const float2 tmp = __half22float2(__hmul2(dm4, ds8));
    const float d4d8 = tmp.x;
    const float m4s8 = tmp.y;
#else
    const float2 dm4f = __half22float2(dm4);
    const float2 ds8f = __half22float2(ds8);
    const float d4d8 = dm4f.x * ds8f.x;
    const float m4s8 = dm4f.y * ds8f.y;
#endif // GGML_CUDA_F16

    // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
    return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
#else
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}

#define VDR_Q5_0_Q8_1_MMVQ 2
#define VDR_Q5_0_Q8_1_MMQ 4
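
// q5_0 keeps the low 4 bits of each quant in vl and the 5th bit in vh; the shifts below
// scatter those high bits into bit 4 of the matching byte lanes before the __dp4a calls.
// The implicit -16 offset is handled like the -8 offset of q4_0, via the q8_1 block sum ds8.y.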
template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl(
    const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
        vi0 |= (vh[i] << 4)  & 0x00000010; // 0 ->  4
        vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12
        vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20
        vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28
        sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values

        int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
        vi1 |= (vh[i] >> 12) & 0x00000010; // 16 ->  4
        vi1 |= (vh[i] >> 5)  & 0x00001000; // 17 -> 12
        vi1 |= (vh[i] << 2)  & 0x00100000; // 18 -> 20
        vi1 |= (vh[i] << 9)  & 0x10000000; // 19 -> 28
        sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
    }

    const float2 ds8f = __half22float2(ds8);

    // second part effectively subtracts 16 from each quant value
    return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y);
#else
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}

#define VDR_Q5_1_Q8_1_MMVQ 2
#define VDR_Q5_1_Q8_1_MMQ 4

template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl(
    const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
        vi0 |= (vh[i] << 4)  & 0x00000010; // 0 ->  4
        vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12
        vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20
        vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28
        sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values

        int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
        vi1 |= (vh[i] >> 12) & 0x00000010; // 16 ->  4
        vi1 |= (vh[i] >> 5)  & 0x00001000; // 17 -> 12
        vi1 |= (vh[i] << 2)  & 0x00100000; // 18 -> 20
        vi1 |= (vh[i] << 9)  & 0x10000000; // 19 -> 28
        sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
    }

#ifdef GGML_CUDA_F16
    const float2 tmp = __half22float2(__hmul2(dm5, ds8));
    const float d5d8 = tmp.x;
    const float m5s8 = tmp.y;
#else
    const float2 dm5f = __half22float2(dm5);
    const float2 ds8f = __half22float2(ds8);
    const float d5d8 = dm5f.x * ds8f.x;
    const float m5s8 = dm5f.y * ds8f.y;
#endif // GGML_CUDA_F16

    // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
    return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
#else
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}

#define VDR_Q8_0_Q8_1_MMVQ 2
#define VDR_Q8_0_Q8_1_MMQ 8
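
// q8_0 x q8_1 needs no unpacking: v and u go straight through __dp4a and the integer sum
// is scaled by the product of the two block scales. The q8_1 x q8_1 variant below also adds
// dm8.y * ds8.y, divided by QI8_1/vdr since that many threads contribute it per block.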
template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl(
    const int * v, const int * u, const float & d8_0, const float & d8_1) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        // SIMD dot product of quantized values
        sumi = __dp4a(v[i], u[i], sumi);
    }

    return d8_0*d8_1 * sumi;
#else
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}

template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl(
    const int * v, const int * u, const half2 & dm8, const half2 & ds8) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        // SIMD dot product of quantized values
        sumi = __dp4a(v[i], u[i], sumi);
    }

#ifdef GGML_CUDA_F16
    const float2 tmp = __half22float2(__hmul2(dm8, ds8));
    const float d8d8 = tmp.x;
    const float m8s8 = tmp.y;
#else
    const float2 dm8f = __half22float2(dm8);
    const float2 ds8f = __half22float2(ds8);
    const float d8d8 = dm8f.x * ds8f.x;
    const float m8s8 = dm8f.y * ds8f.y;
#endif // GGML_CUDA_F16

    // scale second part of sum by QI8_1 / vdr to compensate for multiple threads adding it
    return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
#else
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}

#define VDR_Q2_K_Q8_1_MMVQ 1
#define VDR_Q2_K_Q8_1_MMQ 2
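
// q2_K packs a 4-bit scale (low nibble) and a 4-bit min (high nibble) per 16-value group;
// the min is broadcast to all four byte lanes of an int so that __dp4a(m, u[i], 0) yields
// m times the sum of the corresponding q8_1 values.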
// contiguous v/x values
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
    const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
    const half2 & dm2, const float * __restrict__ d8) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    float sumf_d = 0.0f;
    float sumf_m = 0.0f;

#pragma unroll
    for (int i = 0; i < QR2_K; ++i) {
        const int sc = scales[2*i];

        const int vi = (v >> (2*i)) & 0x03030303;

        sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product

        // fill int with 4x m
        int m = sc >> 4;
        m |= m << 8;
        m |= m << 16;
        sumf_m += d8[i] * __dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values
    }

    const float2 dm2f = __half22float2(dm2);

    return dm2f.x*sumf_d - dm2f.y*sumf_m;
#else
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}

// contiguous u/y values
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
    const half2 & dm2, const float & d8) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    int sumi_d = 0;
    int sumi_m = 0;

#pragma unroll
    for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
        int sumi_d_sc = 0;

        const int sc = scales[i0 / (QI8_1/2)];

        // fill int with 4x m
        int m = sc >> 4;
        m |= m << 8;
        m |= m << 16;

#pragma unroll
        for (int i = i0; i < i0 + QI8_1/2; ++i) {
            sumi_d_sc = __dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product
            sumi_m = __dp4a(m, u[i], sumi_m); // multiply sum of q8_1 values with m
        }

        sumi_d += sumi_d_sc * (sc & 0xF);
    }

    const float2 dm2f = __half22float2(dm2);

    return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m);
#else
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}

#define VDR_Q3_K_Q8_1_MMVQ 1
#define VDR_Q3_K_Q8_1_MMQ 2
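
// q3_K uses 6-bit scales split across the scales[] array (4 low bits + 2 high bits),
// reassembled below with a -32 offset. The 2-bit quants get their effective -4 offset
// from the inverted hmask bits via __vsubss4 instead of a separate correction term.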
// contiguous v/x values
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
    const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales,
    const int & scale_offset, const float & d3, const float * __restrict__ d8) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    float sumf = 0.0f;

#pragma unroll
    for (int i = 0; i < QR3_K; ++i) {
        const int isc = scale_offset + 2*i;

        const int isc_low = isc % (QK_K/32);
        const int sc_shift_low = 4 * (isc / (QK_K/32));
        const int sc_low = (scales[isc_low] >> sc_shift_low) & 0xF;

        const int isc_high = isc % (QK_K/64);
        const int sc_shift_high = 2 * (isc / (QK_K/64));
        const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;

        const int sc = (sc_low | sc_high) - 32;

        const int vil = (vl >> (2*i)) & 0x03030303;

        const int vih = ((vh >> i) << 2) & 0x04040404;

        const int vi = __vsubss4(vil, vih);

        sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product
    }

    return d3 * sumf;
#else
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}

// contiguous u/y values
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
    const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
    const float & d3, const float & d8) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    int sumi = 0;

#pragma unroll
    for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
        int sumi_sc = 0;

        for (int i = i0; i < i0 + QI8_1/2; ++i) {
            sumi_sc = __dp4a(v[i], u[i], sumi_sc); // SIMD dot product
        }

        sumi += sumi_sc * scales[i0 / (QI8_1/2)];
    }

    return d3*d8 * sumi;
#else
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}

#define VDR_Q4_K_Q8_1_MMVQ 2
#define VDR_Q4_K_Q8_1_MMQ 8
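
// q4_K carries 6-bit sub-block scales (sc) and mins (m) plus a super-block half2 dm4 = {d, dmin}.
// dot1 accumulates q4*q8 products, dot2 accumulates the plain q8 sums (via __dp4a against
// 0x01010101) so the min contribution can be subtracted as dmin * m * sum(q8) at the end.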
// contiguous v/x values
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
    const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    float sumf_d = 0.0f;
    float sumf_m = 0.0f;

#pragma unroll
    for (int i = 0; i < QR4_K; ++i) {
        const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F;
        const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F;

        const int dot1 = __dp4a(v1i, u[2*i+1], __dp4a(v0i, u[2*i+0], 0)); // SIMD dot product
        const int dot2 = __dp4a(0x01010101, u[2*i+1], __dp4a(0x01010101, u[2*i+0], 0)); // sum of u

        sumf_d += d8[i] * (dot1 * sc[i]);
        sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values
    }

    const float2 dm4f = __half22float2(dm4);

    return dm4f.x*sumf_d - dm4f.y*sumf_m;
#else
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}

// contiguous u/y values
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
    const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    float sumf_d = 0.0f;
    float sumf_m = 0.0f;

#pragma unroll
    for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) {
        int sumi_d = 0;

#pragma unroll
        for (int j = 0; j < QI8_1; ++j) {
            sumi_d = __dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product
        }

        const float2 ds8f = __half22float2(ds8[i]);

        sumf_d += ds8f.x * (sc[i] * sumi_d);
        sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val
    }

    const float2 dm4f = __half22float2(dm4);

    return dm4f.x*sumf_d - dm4f.y*sumf_m;
#else
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}

#define VDR_Q5_K_Q8_1_MMVQ 2
#define VDR_Q5_K_Q8_1_MMQ 8

// contiguous v/x values
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
    const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc,
    const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    float sumf_d = 0.0f;
    float sumf_m = 0.0f;

#pragma unroll
    for (int i = 0; i < QR5_K; ++i) {
        const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F;
        const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F;

        const int vh0i = ((vh[0] >> i) << 4) & 0x10101010;
        const int vh1i = ((vh[1] >> i) << 4) & 0x10101010;

        const int v0i = vl0i | vh0i;
        const int v1i = vl1i | vh1i;

        const int dot1 = __dp4a(v0i, u[2*i+0], __dp4a(v1i, u[2*i+1], 0)); // SIMD dot product
        const int dot2 = __dp4a(0x01010101, u[2*i+0], __dp4a(0x01010101, u[2*i+1], 0)); // sum of u

        sumf_d += d8[i] * (dot1 * sc[i]);
        sumf_m += d8[i] * (dot2 * m[i]);
    }

    const float2 dm5f = __half22float2(dm5);

    return dm5f.x*sumf_d - dm5f.y*sumf_m;
#else
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}

// contiguous u/y values
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
    const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    float sumf_d = 0.0f;
    float sumf_m = 0.0f;

#pragma unroll
    for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) {
        int sumi_d = 0;

#pragma unroll
        for (int j = 0; j < QI8_1; ++j) {
            sumi_d = __dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product
        }

        const float2 ds8f = __half22float2(ds8[i]);

        sumf_d += ds8f.x * (sc[i] * sumi_d);
        sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val
    }

    const float2 dm4f = __half22float2(dm4);

    return dm4f.x*sumf_d - dm4f.y*sumf_m;
#else
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}

#define VDR_Q6_K_Q8_1_MMVQ 1
#define VDR_Q6_K_Q8_1_MMQ 8

// contiguous v/x values
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
    const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales,
    const float & d, const float * __restrict__ d8) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    float sumf = 0.0f;

#pragma unroll
    for (int i = 0; i < QR6_K; ++i) {
        const int sc = scales[4*i];

        const int vil = (vl >> (4*i)) & 0x0F0F0F0F;

        const int vih = ((vh >> (4*i)) << 4) & 0x30303030;

        const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32

        sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product
    }

    return d*sumf;
#else
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}

// contiguous u/y values
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
    const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
    const float & d6, const float * __restrict__ d8) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    float sumf_d = 0.0f;

#pragma unroll
    for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
        int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale

#pragma unroll
        for (int i = i0; i < i0 + 2; ++i) {
            sumi_d.x = __dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product
            sumi_d.x = __dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product

            sumi_d.y = __dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product
            sumi_d.y = __dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
        }

        sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y);
    }

    return d6 * sumf_d;
#else
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
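
// The vec_dot_*_q8_1 functions below are the per-thread entry points used by mul_mat_vec_q:
// vbq points at one quantized block of the weight matrix, bq8_1 at the corresponding
// block_q8_1 data (a group of blocks for the k-quants) of the q8_1-quantized vector, and
// iqs selects which ints within the block this particular thread processes.
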
static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {

    const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;

    int v[VDR_Q4_0_Q8_1_MMVQ];
    int u[2*VDR_Q4_0_Q8_1_MMVQ];

#pragma unroll
    for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) {
        v[i] = get_int_from_uint8(bq4_0->qs, iqs + i);
        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0);
    }

    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);
}

static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {

    const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;

    int v[VDR_Q4_1_Q8_1_MMVQ];
    int u[2*VDR_Q4_1_Q8_1_MMVQ];

#pragma unroll
    for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) {
        v[i] = get_int_from_uint8_aligned(bq4_1->qs, iqs + i);
        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1);
    }

    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);
}

static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {

    const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;

    int vl[VDR_Q5_0_Q8_1_MMVQ];
    int vh[VDR_Q5_0_Q8_1_MMVQ];
    int u[2*VDR_Q5_0_Q8_1_MMVQ];

#pragma unroll
    for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) {
        vl[i] = get_int_from_uint8(bq5_0->qs, iqs + i);
        vh[i] = get_int_from_uint8(bq5_0->qh, 0) >> (4 * (iqs + i));
        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_0);
    }

    return vec_dot_q5_0_q8_1_impl<VDR_Q5_0_Q8_1_MMVQ>(vl, vh, u, bq5_0->d, bq8_1->ds);
}

static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {

    const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;

    int vl[VDR_Q5_1_Q8_1_MMVQ];
    int vh[VDR_Q5_1_Q8_1_MMVQ];
    int u[2*VDR_Q5_1_Q8_1_MMVQ];

#pragma unroll
    for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) {
        vl[i] = get_int_from_uint8_aligned(bq5_1->qs, iqs + i);
        vh[i] = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i));
        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1);
    }

    return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, bq5_1->dm, bq8_1->ds);
}

static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {

    const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;

    int v[VDR_Q8_0_Q8_1_MMVQ];
    int u[VDR_Q8_0_Q8_1_MMVQ];

#pragma unroll
    for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) {
        v[i] = get_int_from_int8(bq8_0->qs, iqs + i);
        u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
    }

    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, __low2half(bq8_1->ds));
}

static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {

    const block_q2_K * bq2_K = (const block_q2_K *) vbq;

    const int bq8_offset = QR2_K * (iqs / QI8_1);
    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);

    const uint8_t * scales = bq2_K->scales + scale_offset;

    const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs);
    int u[QR2_K];
    float d8[QR2_K];

#pragma unroll
    for (int i = 0; i < QR2_K; ++ i) {
        u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
        d8[i] = __low2float(bq8_1[bq8_offset + i].ds);
    }

    return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
}

static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {

    const block_q3_K * bq3_K = (const block_q3_K *) vbq;

    const int bq8_offset = QR3_K * (iqs / (QI3_K/2));
    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);

    const float d = bq3_K->d;

    const int vl = get_int_from_uint8(bq3_K->qs, iqs);

    // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
    const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset;

    int u[QR3_K];
    float d8[QR3_K];

#pragma unroll
    for (int i = 0; i < QR3_K; ++i) {
        u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
        d8[i] = __low2float(bq8_1[bq8_offset + i].ds);
    }

    return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
}

static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {

#ifndef GGML_QKK_64
    const block_q4_K * bq4_K = (const block_q4_K *) vbq;

    int v[2];
    int u[2*QR4_K];
    float d8[QR4_K];

    // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6
    const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2));

    // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
    // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
    // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
    // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108

    const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
    v[0] = q4[0];
    v[1] = q4[4];

    const uint16_t * scales = (const uint16_t *)bq4_K->scales;
    uint16_t aux[2];
    const int j = bq8_offset/2;
    if (j < 2) {
        aux[0] = scales[j+0] & 0x3f3f;
        aux[1] = scales[j+2] & 0x3f3f;
    } else {
        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
    }
    const uint8_t * sc = (const uint8_t *)aux;
    const uint8_t * m = sc + 2;

    for (int i = 0; i < QR4_K; ++i) {
        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
        d8[i] = __low2float(bq8i->ds);

        const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
        u[2*i+0] = q8[0];
        u[2*i+1] = q8[4];
    }

    return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);

#else

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    const block_q4_K * bq4_K = (const block_q4_K *) vbq;

    float sumf_d = 0.0f;
    float sumf_m = 0.0f;

    uint16_t aux16[2];
    const uint8_t * s = (const uint8_t *)aux16;

    const uint16_t * a = (const uint16_t *)bq4_K->scales;
    aux16[0] = a[0] & 0x0f0f;
    aux16[1] = (a[0] >> 4) & 0x0f0f;

    const float dall = bq4_K->dm[0];
    const float dmin = bq4_K->dm[1];

    const float d8_1 = __low2float(bq8_1[0].ds);
    const float d8_2 = __low2float(bq8_1[1].ds);

    const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
    const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
    const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2));
    const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4);

    const int * q4 = (const int *)bq4_K->qs + (iqs/2);
    const int v1 = q4[0];
    const int v2 = q4[4];

    const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0));
    const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0));
    const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
    const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0));

    sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]);
    sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]);

    return dall * sumf_d - dmin * sumf_m;

#else
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A

#endif
}

static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {

#ifndef GGML_QKK_64
    const block_q5_K * bq5_K = (const block_q5_K *) vbq;

    int vl[2];
    int vh[2];
    int u[2*QR5_K];
    float d8[QR5_K];

    const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2));
    const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
    const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4));

    vl[0] = ql[0];
    vl[1] = ql[4];

    vh[0] = qh[0] >> bq8_offset;
    vh[1] = qh[4] >> bq8_offset;

    const uint16_t * scales = (const uint16_t *)bq5_K->scales;
    uint16_t aux[2];
    const int j = bq8_offset/2;
    if (j < 2) {
        aux[0] = scales[j+0] & 0x3f3f;
        aux[1] = scales[j+2] & 0x3f3f;
    } else {
        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
    }
    const uint8_t * sc = (const uint8_t *)aux;
    const uint8_t * m = sc + 2;

#pragma unroll
    for (int i = 0; i < QR5_K; ++i) {
        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
        d8[i] = __low2float(bq8i->ds);

        const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
        u[2*i+0] = q8[0];
        u[2*i+1] = q8[4];
    }

    return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8);

#else

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    const block_q5_K * bq5_K = (const block_q5_K *) vbq;

    const int8_t * s = bq5_K->scales;

    const float d = bq5_K->d;

    const float d8_1 = __low2half(bq8_1[0].ds);
    const float d8_2 = __low2half(bq8_1[1].ds);

    const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
    const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
    const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2));
    const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4);

    const int * ql = (const int *)bq5_K->qs + (iqs/2);
    const int vl1 = ql[0];
    const int vl2 = ql[4];

    const int step = 4 * (iqs/2); // 0, 4, 8, 12
    const int im = step/8; // = 0 for iqs = 0, 2, = 1 for iqs = 4, 6
    const int in = step%8; // 0, 4, 0, 4
    const int vh = (*((const int *)(bq5_K->qh + in))) >> im;

    const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f);
    const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f);
    const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f);
    const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f);

    const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1])
                       + d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]);

    return d * sumf_d;

#else
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A

#endif
}

static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {

    const block_q6_K * bq6_K = (const block_q6_K *) vbq;

    const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4);
    const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8);
    const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4));

    const int vl = get_int_from_uint8(bq6_K->ql, iqs);
    const int vh = get_int_from_uint8(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift;

    const int8_t * scales = bq6_K->scales + scale_offset;

    int u[QR6_K];
    float d8[QR6_K];

#pragma unroll
    for (int i = 0; i < QR6_K; ++i) {
        u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
        d8[i] = __low2float(bq8_1[bq8_offset + 2*i].ds);
    }

    return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
}
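
// The i-quants below (iq2_*, iq3_*, iq1_*) do not store quant values directly: each group of
// 8 values is an index into a fixed codebook (iq2xxs_grid, iq2xs_grid, iq2s_grid, iq3xxs_grid,
// iq3s_grid, iq1s_grid_gpu) plus, for the iq2/iq3 types, packed sign bits. The kernels expand
// the grid entries, apply the signs, and then use __dp4a against the q8_1 data as usual.
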
static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {

#if QK_K == 256
    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;

#if QR2_XXS == 8
    const int ib32 = iqs;
    const uint16_t * q2 = bq2->qs + 4*ib32;
    const uint8_t * aux8 = (const uint8_t *)q2;
    const int8_t * q8 = bq8_1[ib32].qs;
    uint32_t aux32 = q2[2] | (q2[3] << 16);
    int sumi = 0;
    for (int l = 0; l < 4; ++l) {
        const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
        const uint8_t signs = ksigns_iq2xs[aux32 & 127];
        for (int j = 0; j < 8; ++j) {
            sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
        }
        q8 += 8;
        aux32 >>= 7;
    }
    const float d = (float)bq2->d * (0.5f + aux32) * __low2float(bq8_1[ib32].ds) * 0.25f;
    return d * sumi;
#else
    // iqs is 0...15
    const int ib32 = iqs/2;
    const int il = iqs%2;
    const uint16_t * q2 = bq2->qs + 4*ib32;
    const uint8_t * aux8 = (const uint8_t *)q2;
    const uint8_t * grid1 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
    const uint8_t * grid2 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
    const uint32_t aux32 = q2[2] | (q2[3] << 16);
    const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * __low2float(bq8_1[ib32].ds) * 0.25f;
    const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
    const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];
    const int8_t * q8 = bq8_1[ib32].qs + 16*il;
    int sumi1 = 0, sumi2 = 0;
    for (int j = 0; j < 8; ++j) {
        sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1);
        sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1);
    }
    return d * (sumi1 + sumi2);
#endif
#else
    NO_DEVICE_CODE;
#endif
}

static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
#if QK_K == 256
    const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;

    const int ib32 = iqs;
    const uint16_t * q2 = bq2->qs + 4*ib32;
    const int8_t * q8 = bq8_1[ib32].qs;
    const uint8_t ls1 = bq2->scales[ib32] & 0xf;
    const uint8_t ls2 = bq2->scales[ib32] >> 4;
    int sumi1 = 0;
    for (int l = 0; l < 2; ++l) {
        const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511));
        const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
        const int grid_l = __vsub4(grid[0] ^ signs[0], signs[0]);
        const int grid_h = __vsub4(grid[1] ^ signs[1], signs[1]);
        sumi1 = __dp4a(grid_l, *((const int *)q8 + 0), sumi1);
        sumi1 = __dp4a(grid_h, *((const int *)q8 + 1), sumi1);
        q8 += 8;
    }
    int sumi2 = 0;
    for (int l = 2; l < 4; ++l) {
        const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511));
        const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
        const int grid_l = __vsub4(grid[0] ^ signs[0], signs[0]);
        const int grid_h = __vsub4(grid[1] ^ signs[1], signs[1]);
        sumi2 = __dp4a(grid_l, *((const int *)q8 + 0), sumi2);
        sumi2 = __dp4a(grid_h, *((const int *)q8 + 1), sumi2);
        q8 += 8;
    }
    const float d = (float)bq2->d * __low2float(bq8_1[ib32].ds) * 0.25f;
    return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
#else
    GGML_UNUSED(ksigns64);
    NO_DEVICE_CODE;
#endif
#else
    GGML_UNUSED(ksigns64);
    NO_DEVICE_CODE;
#endif
}

// TODO
static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
#if QK_K == 256
    const block_iq2_s * bq2 = (const block_iq2_s *) vbq;

    const int ib32 = iqs;
    const int8_t * q8 = bq8_1[ib32].qs;
    const uint8_t * signs = bq2->qs + QK_K/8 + 4*ib32;
    const uint8_t ls1 = bq2->scales[ib32] & 0xf;
    const uint8_t ls2 = bq2->scales[ib32] >> 4;
    int sumi1 = 0;
    for (int l = 0; l < 2; ++l) {
        const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
        const uint32_t signs0 = __vcmpeq4(((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
        const uint32_t signs1 = __vcmpeq4(((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
        const int grid_l = __vsub4(grid[0] ^ signs0, signs0);
        const int grid_h = __vsub4(grid[1] ^ signs1, signs1);
        sumi1 = __dp4a(grid_l, *((const int *)q8 + 0), sumi1);
        sumi1 = __dp4a(grid_h, *((const int *)q8 + 1), sumi1);
        q8 += 8;
    }
    int sumi2 = 0;
    for (int l = 2; l < 4; ++l) {
        const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
        const uint32_t signs0 = __vcmpeq4(((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
        const uint32_t signs1 = __vcmpeq4(((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
        const int grid_l = __vsub4(grid[0] ^ signs0, signs0);
        const int grid_h = __vsub4(grid[1] ^ signs1, signs1);
        sumi2 = __dp4a(grid_l, *((const int *)q8 + 0), sumi2);
        sumi2 = __dp4a(grid_h, *((const int *)q8 + 1), sumi2);
        q8 += 8;
    }
    const float d = (float)bq2->d * __low2float(bq8_1[ib32].ds) * 0.25f;
    return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
#else
    GGML_UNUSED(ksigns64);
    NO_DEVICE_CODE;
#endif
#else
    GGML_UNUSED(ksigns64);
    NO_DEVICE_CODE;
#endif
}

static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
#if QK_K == 256
    const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq;

    const int ib32 = iqs;
    const uint8_t * q3 = bq2->qs + 8*ib32;
    const uint16_t * gas = (const uint16_t *)(bq2->qs + QK_K/4) + 2*ib32;
    const int8_t * q8 = bq8_1[ib32].qs;
    uint32_t aux32 = gas[0] | (gas[1] << 16);
    int sumi = 0;
    for (int l = 0; l < 4; ++l) {
        const uint32_t * grid1 = iq3xxs_grid + q3[2*l+0];
        const uint32_t * grid2 = iq3xxs_grid + q3[2*l+1];
        const uint32_t * signs = (const uint32_t *)(ksigns64 + (aux32 & 127));
        const int grid_l = __vsub4(grid1[0] ^ signs[0], signs[0]);
        const int grid_h = __vsub4(grid2[0] ^ signs[1], signs[1]);
        sumi = __dp4a(grid_l, *((int *)q8+0), sumi);
        sumi = __dp4a(grid_h, *((int *)q8+1), sumi);
        q8 += 8;
        aux32 >>= 7;
    }
    const float d = (float)bq2->d * (0.5f + aux32) * __low2float(bq8_1[ib32].ds) * 0.5f;
    return d * sumi;
#else
    NO_DEVICE_CODE;
#endif
#else
    NO_DEVICE_CODE;
#endif
}

// TODO: don't use lookup table for signs
static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
#if QK_K == 256
    const block_iq3_s * bq2 = (const block_iq3_s *) vbq;

    const int ib32 = iqs;
    const uint8_t * qs = bq2->qs + 8*ib32;
    const int8_t * q8 = bq8_1[ib32].qs;
    int sumi = 0;
    for (int l = 0; l < 4; ++l) {
        const uint32_t * grid1 = iq3s_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
        const uint32_t * grid2 = iq3s_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
        uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
        uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
        const int grid_l = __vsub4(grid1[0] ^ signs0, signs0);
        const int grid_h = __vsub4(grid2[0] ^ signs1, signs1);
        sumi = __dp4a(grid_l, *((int *)q8+0), sumi);
        sumi = __dp4a(grid_h, *((int *)q8+1), sumi);
        q8 += 8;
    }
    const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds);
    return d * sumi;
#else
    NO_DEVICE_CODE;
#endif
#else
    NO_DEVICE_CODE;
#endif
}
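
// iq1_s and iq1_m additionally fold in a delta term: iq1_s uses the q8_1 block sum from ds
// for it, while iq1_m accumulates per-group sums (sumy) via __dp4a against 0x01010101.
// The 3-bit group scale lives in the high bits of qh for iq1_s and in the packed scales
// field for iq1_m, whose top bits also hold the fp16 super-block scale (iq1m_scale_t).
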
static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {

#if QK_K == 256
    const block_iq1_s * bq1 = (const block_iq1_s *) vbq;

    const int ib32 = iqs;
    int sumi = 0;
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    const int * q8 = (const int *)bq8_1[ib32].qs;
    for (int l = 0; l < 4; ++l) {
        const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
        int grid0 = grid[0] & 0x0f0f0f0f;
        int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
        sumi = __dp4a(q8[2*l+1], grid1, __dp4a(q8[2*l+0], grid0, sumi));
    }
#else
    const int8_t * q8 = bq8_1[ib32].qs;
    for (int l = 0; l < 4; ++l) {
        const uint8_t * grid = (const uint8_t *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
        for (int j = 0; j < 4; ++j) {
            sumi += q8[j] * (grid[j] & 0xf) + q8[j+4] * (grid[j] >> 4);
        }
        q8 += 8;
    }
#endif
    const float delta = bq1->qh[ib32] & 0x8000 ? -1-IQ1S_DELTA : -1+IQ1S_DELTA;
    const float d1q = (float)bq1->d * (2*((bq1->qh[ib32] >> 12) & 7) + 1);
    const float d = d1q * __low2float (bq8_1[ib32].ds);
    const float m = d1q * __high2float(bq8_1[ib32].ds);
    return d * sumi + m * delta;
#else
    NO_DEVICE_CODE;
#endif
}

static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {

#if QK_K == 256
    const block_iq1_m * bq1 = (const block_iq1_m *) vbq;

    const int ib32 = iqs;
    int sumi[2] = {0, 0};
    float sumf[2] = {0.f, 0.f};
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    const int * q8 = (const int *)bq8_1[ib32].qs;
    for (int l = 0; l < 4; ++l) {
        const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 7) << 8)));
        int grid0 = grid[0] & 0x0f0f0f0f;
        int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
        sumi[l/2] = __dp4a(q8[2*l+1], grid1, __dp4a(q8[2*l+0], grid0, sumi[l/2]));
        const float delta = (bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 0x08 ? -1-IQ1M_DELTA : -1+IQ1M_DELTA;
        const int sumy = __dp4a(q8[2*l+1], 0x01010101, __dp4a(q8[2*l+0], 0x01010101, 0));
        sumf[l/2] += delta*sumy;
    }
#else
    const int8_t * q8 = bq8_1[ib32].qs;
    for (int l = 0; l < 4; ++l) {
        const uint8_t * grid = (const uint8_t *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
        int sumy = 0;
        for (int j = 0; j < 4; ++j) {
            sumi[l/2] += q8[j] * (grid[j] & 0xf) + q8[j+4] * (grid[j] >> 4);
            sumy += q8[j] + q8[j+4];
        }
        const float delta = (bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 0x08 ? -1-IQ1M_DELTA : -1+IQ1M_DELTA;
        sumf[l/2] += delta*sumy;
        q8 += 8;
    }
#endif
    iq1m_scale_t scale;
    const uint16_t * sc = (const uint16_t *)bq1->scales;
    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
    const float d = (float)scale.f16 * __low2float (bq8_1[ib32].ds);
    return d * ((sumi[0] + sumf[0]) * (2*((sc[ib32/2] >> 6*(ib32%2)) & 0x7) + 1) + (sumi[1] + sumf[1]) * (2*((sc[ib32/2] >> (6*(ib32%2)+3)) & 0x7) + 1));
#else
    NO_DEVICE_CODE;
#endif
}
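
// iq4_nl and iq4_xs store 4-bit indices into the 16-entry non-linear table kvalues_iq4nl;
// get_int_from_table_16 expands one uint32_t of packed indices into two ints of int8 values
// (low nibbles -> val1, high nibbles -> val2) ready for __dp4a.
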
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4, const uint8_t * values,
    int & val1, int & val2) {

    uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32;
    aux32 = q4 & 0x0f0f0f0f;
    uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8);
    uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8);
    val1 = v1 | (v2 << 16);
    aux32 = (q4 >> 4) & 0x0f0f0f0f;
    v1 = values[q8[0]] | (values[q8[1]] << 8);
    v2 = values[q8[2]] | (values[q8[3]] << 8);
    val2 = v1 | (v2 << 16);
}
#endif

static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {

    const block_iq4_nl * bq = (const block_iq4_nl *) vbq;

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs;
    const int32_t * q8 = (const int32_t *)bq8_1->qs + iqs;

    const uint8_t * values = (const uint8_t *)kvalues_iq4nl;

    int v1, v2;
    int sumi1 = 0, sumi2 = 0;
    for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
        const uint32_t aux = q4[2*l] | (q4[2*l+1] << 16);
        get_int_from_table_16(aux, values, v1, v2);
        sumi1 = __dp4a(v1, q8[l+0], sumi1);
        sumi2 = __dp4a(v2, q8[l+4], sumi2);
    }

#else
    const uint8_t * q4 = bq->qs + 4*iqs;
    const int8_t * q8 = bq8_1->qs + 4*iqs;

    int sumi1 = 0, sumi2 = 0;
    for (int l = 0; l < 4*VDR_Q4_0_Q8_1_MMVQ; ++l) {
        sumi1 += q8[l+ 0] * kvalues_iq4nl[q4[l] & 0xf];
        sumi2 += q8[l+16] * kvalues_iq4nl[q4[l] >> 4];
    }
#endif
    const float d = (float)bq->d * __low2float(bq8_1->ds);
    return d * (sumi1 + sumi2);
}

static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {

#if QK_K == 256
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq;
    const uint8_t * values = (const uint8_t *)kvalues_iq4nl;

    // iqs is 0...7
    const int ib32 = iqs;
    const int32_t * q8 = (const int *)bq8_1[ib32].qs;
    const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
    const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4);
    const float d = (float)bq4->d * (ls - 32) * __low2float(bq8_1[ib32].ds);
    int v1, v2;
    int sumi1 = 0, sumi2 = 0;
    for (int j = 0; j < 4; ++j) {
        get_int_from_table_16(q4[j], values, v1, v2);
        sumi1 = __dp4a(v1, q8[j+0], sumi1);
        sumi2 = __dp4a(v2, q8[j+4], sumi2);
    }
    return d * (sumi1 + sumi2);
#else
    NO_DEVICE_CODE;
#endif
#else
    // QK_K == 64: fall back to the iq4_nl kernel, which uses the same 4-bit non-linear table
    return vec_dot_iq4_nl_q8_1(vbq, bq8_1, iqs);
#endif
}