dmmv.cu 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709
  1. /**
  2. * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
  3. *
  4. * MIT License
  5. *
  6. * Copyright (c) 2023-2024 The ggml authors
  7. *
  8. * Permission is hereby granted, free of charge, to any person obtaining a copy
  9. * of this software and associated documentation files (the "Software"), to deal
  10. * in the Software without restriction, including without limitation the rights
  11. * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  12. * copies of the Software, and to permit persons to whom the Software is
  13. * furnished to do so, subject to the following conditions:
  14. *
  15. * The above copyright notice and this permission notice shall be included in all
  16. * copies or substantial portions of the Software.
  17. *
  18. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  19. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  20. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  21. * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  22. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  23. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  24. * SOFTWARE.
  25. */
  26. #include "dmmv.cuh"
  27. #include "dequantize.cuh"
  28. #include "convert.cuh"
  // Controls how the K-quant kernels split a warp: lanes are divided into K_QUANTS_PER_ITERATION groups that
  // take alternate super-blocks. Defaults to 2; it may be predefined at build time, but only 1 or 2 is accepted.
  #ifndef K_QUANTS_PER_ITERATION
  #define K_QUANTS_PER_ITERATION 2
  #endif
  static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
  // Q2_K dequantize + mat-vec kernel: for output row `row`,
  //   dst[row] = sum over the row's ncols/QK_K super-blocks of dall * sum(y * d * q) - dmin * sum(y * m),
  // where q are 2-bit quants from qs, d/m are the low/high nibbles of the packed scales bytes.
  //
  // Expected launch (dequantize_mul_mat_vec_q2_K_cuda): blockDim = (32, ny), gridDim.x = ceil(nrows / ny).
  // Each fixed threadIdx.y is one 32-lane row worker whose partial sums are combined with warp_reduce_sum.
  // Lanes with the same ix = threadIdx.x % K_QUANTS_PER_ITERATION visit super-blocks ix, ix + K_QUANTS_PER_ITERATION, ...
  // Precondition: ncols is a multiple of QK_K (asserted by the launcher).
  static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
      static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");

      const int row = blockIdx.x*blockDim.y + threadIdx.y;
      // The grid is rounded up to a multiple of blockDim.y, so the last block can map past the final row.
      // Must be `>=`: with `row > nrows`, row == nrows read past the end of vx and wrote dst[nrows].
      if (row >= nrows) return;

      const int num_blocks_per_row = ncols / QK_K;
      const int64_t ib0 = (int64_t)row*num_blocks_per_row; // 64-bit, like the generic dmmv kernel's block index

      const block_q2_K * x = (const block_q2_K *)vx + ib0;

      float tmp = 0; // partial sum for thread in warp

      const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
      const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1

      const int step = 16/K_QUANTS_PER_ITERATION;

      const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
      const int in = tid - step*im;                        // 0...15 or 0...7

      const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15 or 0...14 in steps of 2
      const int q_offset = 32*im + l0;
      const int s_offset = 8*im;
      const int y_offset = 128*im + l0;

      // aux[0..1] receive the 8 scales (low nibbles), aux[2..3] the 8 mins (high nibbles)
      uint32_t aux[4];
      const uint8_t * d = (const uint8_t *)aux;
      const uint8_t * m = (const uint8_t *)(aux + 2);

      for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {

          const float   * y = yy + i * QK_K + y_offset;
          const uint8_t * q = x[i].qs + q_offset;

          const float dall = __low2half(x[i].dm);
          const float dmin = __high2half(x[i].dm);

          const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
          aux[0] = a[0] & 0x0f0f0f0f;
          aux[1] = a[1] & 0x0f0f0f0f;
          aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
          aux[3] = (a[1] >> 4) & 0x0f0f0f0f;

          float sum1 = 0, sum2 = 0;
          for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
              sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
                    + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
                    + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
                    + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
                    + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
                    + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
                    + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
                    +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
              sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
                    + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
          }
          tmp += dall * sum1 - dmin * sum2;
      }

      // sum up partial sums and write back result
      tmp = warp_reduce_sum(tmp);

      if (threadIdx.x == 0) {
          dst[row] = tmp;
      }
  }
  // Q3_K dequantize + mat-vec kernel: for output row `row`,
  //   dst[row] = sum over super-blocks of d * sum(y * (scale - 32) * (q - (hmask bit clear ? 4 : 0))),
  // where q are the 2-bit quants in qs and each scale is 6 bits (4 from scales[0..7], 2 from scales[8..11]).
  //
  // Expected launch (dequantize_mul_mat_vec_q3_K_cuda): blockDim = (32, ny), gridDim.x = ceil(nrows / ny).
  // Each fixed threadIdx.y is one 32-lane row worker; lanes with the same threadIdx.x % K_QUANTS_PER_ITERATION
  // visit alternate super-blocks. Precondition: ncols is a multiple of QK_K (asserted by the launcher).
  static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

      const int row = blockIdx.x*blockDim.y + threadIdx.y;
      // The grid is rounded up to a multiple of blockDim.y (ny == 2 when K_QUANTS_PER_ITERATION == 1), so the
      // tail can map past the last row. Must be `>=`: `row > nrows` let row == nrows read/write out of bounds.
      if (row >= nrows) return;

      const int num_blocks_per_row = ncols / QK_K;
      const int64_t ib0 = (int64_t)row*num_blocks_per_row; // 64-bit, like the generic dmmv kernel's block index

      const block_q3_K * x = (const block_q3_K *)vx + ib0;

      float tmp = 0; // partial sum for thread in warp

      const uint16_t kmask1 = 0x0303;
      const uint16_t kmask2 = 0x0f0f;

      const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
      const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1

      const int n  = K_QUANTS_PER_ITERATION;               // iterations in the inner loop
      const int step = 16/K_QUANTS_PER_ITERATION;
      const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
      const int in = tid - step*im;                        // 0....15 or 0...7

      const uint8_t m = 1 << (4*im);                       // first hmask bit used by this half of the super-block

      const int l0 = n*in;                                 // 0...15 or 0...14 in steps of 2
      const int q_offset =  32*im + l0;
      const int y_offset = 128*im + l0;

      // utmp receives the eight 6-bit scales (stored biased by 32) for this half of the super-block
      uint16_t utmp[4];
      const int8_t * s = (const int8_t *)utmp;

      const uint16_t s_shift = 4*im;

      for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {

          const float   * y  = yy + i * QK_K + y_offset;
          const uint8_t * q = x[i].qs + q_offset;
          const uint8_t * h = x[i].hmask + l0;

          const uint16_t * a = (const uint16_t *)x[i].scales;
          utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
          utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
          utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
          utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);

          const float d = x[i].d;

          float sum = 0;
          for (int l = 0; l < n; ++l) {
              sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
                   + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
                   + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
                   + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
              sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
                   + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
                   + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
                  + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
          }
          tmp += d * sum;
      }

      // sum up partial sums and write back result
      tmp = warp_reduce_sum(tmp);

      if (threadIdx.x == 0) {
          dst[row] = tmp;
      }
  }
  // Q4_K dequantize + mat-vec kernel: for output row `row`,
  //   dst[row] = sum over super-blocks of dall * sum(y * scale * q) - dmin * sum(y * min),
  // with 4-bit quants from qs and 6-bit scales/mins unpacked from x[i].scales.
  //
  // Expected launch (dequantize_mul_mat_vec_q4_K_cuda): blockDim = (32, ny), gridDim.x = ceil(nrows / ny).
  // Each fixed threadIdx.y is one 32-lane row worker; lanes with the same threadIdx.x % K_QUANTS_PER_ITERATION
  // visit alternate super-blocks. Precondition: ncols is a multiple of QK_K (asserted by the launcher).
  static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

      const int row = blockIdx.x*blockDim.y + threadIdx.y;
      // The grid is rounded up to a multiple of blockDim.y, so the tail can map past the last row.
      // Must be `>=`: `row > nrows` let row == nrows read past vx and write dst[nrows].
      if (row >= nrows) return;

      const int num_blocks_per_row = ncols / QK_K;
      const int64_t ib0 = (int64_t)row*num_blocks_per_row; // 64-bit, like the generic dmmv kernel's block index

      const block_q4_K * x = (const block_q4_K *)vx + ib0;

      const uint16_t kmask1 = 0x3f3f;
      const uint16_t kmask2 = 0x0f0f;
      const uint16_t kmask3 = 0xc0c0;

      const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
      const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1

      const int step = 8/K_QUANTS_PER_ITERATION;           // 8 or 4

      const int il  = tid/step;                            // 0...3
      const int ir  = tid - step*il;                       // 0...7 or 0...3
      const int n   = 2 * K_QUANTS_PER_ITERATION;          // 2 or 4

      const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
      const int in = il%2;

      const int l0 = n*(2*ir + in);
      const int q_offset = 32*im + l0;
      const int y_offset = 64*im + l0;

      // unpacked 6-bit scales (sc[0,1,4,5]) and mins (sc[2,3,6,7]) for the sub-blocks this lane touches
      uint16_t aux[4];
      const uint8_t * sc = (const uint8_t *)aux;

  #if K_QUANTS_PER_ITERATION == 2
      uint32_t q32[4];
      const uint8_t * q4 = (const uint8_t *)q32;
  #else
      uint16_t q16[4];
      const uint8_t * q4 = (const uint8_t *)q16;
  #endif

      float tmp = 0; // partial sum for thread in warp

      for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {

          const float   * y1 = yy + i*QK_K + y_offset;
          const float   * y2 = y1 + 128;

          const float dall = __low2half(x[i].dm);
          const float dmin = __high2half(x[i].dm);

          const uint16_t * a = (const uint16_t *)x[i].scales;
          aux[0] = a[im+0] & kmask1;
          aux[1] = a[im+2] & kmask1;
          aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
          aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);

          // High nibbles are only masked (not shifted); the 1.f/16.f factors below compensate.
  #if K_QUANTS_PER_ITERATION == 2
          const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
          const uint32_t * q2 = q1 + 16;

          q32[0] = q1[0] & 0x0f0f0f0f;
          q32[1] = q1[0] & 0xf0f0f0f0;
          q32[2] = q2[0] & 0x0f0f0f0f;
          q32[3] = q2[0] & 0xf0f0f0f0;

          float4 s = {0.f, 0.f, 0.f, 0.f};
          float smin = 0;
          for (int l = 0; l < 4; ++l) {
              s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4];
              s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12];
              smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
          }
          tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
  #else
          const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
          const uint16_t * q2 = q1 + 32;

          q16[0] = q1[0] & 0x0f0f;
          q16[1] = q1[0] & 0xf0f0;
          q16[2] = q2[0] & 0x0f0f;
          q16[3] = q2[0] & 0xf0f0;

          float4 s = {0.f, 0.f, 0.f, 0.f};
          float smin = 0;
          for (int l = 0; l < 2; ++l) {
              s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
              s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
              smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
          }
          tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
  #endif
      }

      // sum up partial sums and write back result
      tmp = warp_reduce_sum(tmp);

      // Only lane 0 stores, as in the q2_k/q3_k/q5_k kernels; the previous `tid == 0` test also matched
      // lane 1 when K_QUANTS_PER_ITERATION == 2 and issued a redundant store.
      if (threadIdx.x == 0) {
          dst[row] = tmp;
      }
  }
  // Q5_K dequantize + mat-vec kernel: dst[blockIdx.x] = dot(dequantized row blockIdx.x of vx, yy).
  //
  // Expected launch (dequantize_mul_mat_vec_q5_K_cuda): gridDim.x == nrows and blockDim = (32, 1, 1), so every
  // block maps to a valid row and no bounds check is needed. Even and odd lanes visit alternate super-blocks.
  // Each quant is a nibble from qs plus a fifth bit (worth 16) from qh; the row result is accumulated as
  // dall * sum(y * scale * q) - dmin * sum(y * min) with 6-bit scales/mins unpacked from the scales bytes.
  static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) {
      const uint16_t kmask1 = 0x3f3f;
      const uint16_t kmask2 = 0x0f0f;
      const uint16_t kmask3 = 0xc0c0;

      const int row            = blockIdx.x;
      const int blocks_per_row = ncols / QK_K;
      const block_q5_K * xrow  = (const block_q5_K *)vx + row*blocks_per_row;

      const int pair   = threadIdx.x/2;   // 0...15
      const int parity = threadIdx.x%2;   // selects even or odd super-blocks

      const int il = pair/4;              // 0...3
      const int ir = pair%4;              // 0...3
      const int im = il/2;                // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
      const int in = il%2;

      const int ncol     = 2;             // consecutive columns handled per offset
      const int l0       = ncol*(2*ir + in);
      const int q_offset = 32*im + l0;
      const int y_offset = 64*im + l0;

      // qh bit masks for the sub-blocks this lane touches
      const uint8_t hm1 = 1 << (2*im);
      const uint8_t hm2 = hm1 << 4;

      // unpacked 6-bit scales (sc[0,1,4,5]) and mins (sc[2,3,6,7])
      uint16_t aux[4];
      const uint8_t * sc = (const uint8_t *)aux;

      // the 4-bit parts of this lane's quants, one per byte
      uint16_t q16[8];
      const uint8_t * q4 = (const uint8_t *)q16;

      float acc = 0; // this lane's partial sum

      for (int i = parity; i < blocks_per_row; i += 2) {
          const block_q5_K & blk = xrow[i];

          const uint8_t * qh = blk.qh + l0;
          const float   * y1 = yy + i*QK_K + y_offset;
          const float   * y2 = y1 + 128;

          const float dall = __low2half(blk.dm);
          const float dmin = __high2half(blk.dm);

          const uint16_t * a = (const uint16_t *)blk.scales;
          aux[0] = a[im+0] & kmask1;
          aux[1] = a[im+2] & kmask1;
          aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
          aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);

          const uint16_t * q1 = (const uint16_t *)(blk.qs + q_offset);
          const uint16_t * q2 = q1 + 32;
          q16[0] = q1[0] & 0x0f0f;
          q16[1] = q1[8] & 0x0f0f;
          q16[2] = (q1[0] >> 4) & 0x0f0f;
          q16[3] = (q1[8] >> 4) & 0x0f0f;
          q16[4] = q2[0] & 0x0f0f;
          q16[5] = q2[8] & 0x0f0f;
          q16[6] = (q2[0] >> 4) & 0x0f0f;
          q16[7] = (q2[8] >> 4) & 0x0f0f;

          float4 sum = {0.f, 0.f, 0.f, 0.f};
          float smin = 0;
          for (int l = 0; l < ncol; ++l) {
              const uint8_t h_lo = qh[l+ 0];
              const uint8_t h_hi = qh[l+16];
              sum.x += y1[l+ 0] * (q4[l +0] + (h_lo & (hm1 << 0) ? 16 : 0))
                     + y1[l+16] * (q4[l +2] + (h_hi & (hm1 << 0) ? 16 : 0));
              sum.y += y1[l+32] * (q4[l +4] + (h_lo & (hm1 << 1) ? 16 : 0))
                     + y1[l+48] * (q4[l +6] + (h_hi & (hm1 << 1) ? 16 : 0));
              sum.z += y2[l+ 0] * (q4[l +8] + (h_lo & (hm2 << 0) ? 16 : 0))
                     + y2[l+16] * (q4[l+10] + (h_hi & (hm2 << 0) ? 16 : 0));
              sum.w += y2[l+32] * (q4[l+12] + (h_lo & (hm2 << 1) ? 16 : 0))
                     + y2[l+48] * (q4[l+14] + (h_hi & (hm2 << 1) ? 16 : 0));
              smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
                    + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
          }
          acc += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
      }

      // combine the lanes' partial sums; lane 0 stores the row result
      acc = warp_reduce_sum(acc);

      if (threadIdx.x == 0) {
          dst[row] = acc;
      }
  }
  // Q6_K dequantize + mat-vec kernel: for output row `row`,
  //   dst[row] = sum over super-blocks of sum(y * scale * d * (q - 32)),
  // where each 6-bit q combines 4 bits from ql with 2 bits from qh and scale is an int8 from x[i].scales.
  //
  // Expected launch (dequantize_mul_mat_vec_q6_K_cuda): blockDim = (32, ny), gridDim.x = ceil(nrows / ny).
  // Each fixed threadIdx.y is one 32-lane row worker; lanes with the same threadIdx.x % K_QUANTS_PER_ITERATION
  // visit alternate super-blocks. Precondition: ncols is a multiple of QK_K (asserted by the launcher).
  static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

      static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");

      const int row = blockIdx.x*blockDim.y + threadIdx.y;
      // The grid is rounded up to a multiple of blockDim.y, so the tail can map past the last row.
      // Must be `>=`: `row > nrows` let row == nrows read past vx and write dst[nrows].
      if (row >= nrows) return;

      const int num_blocks_per_row = ncols / QK_K;
      const int64_t ib0 = (int64_t)row*num_blocks_per_row; // 64-bit, like the generic dmmv kernel's block index

      const block_q6_K * x = (const block_q6_K *)vx + ib0;

      const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
      const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1

      const int step = 16/K_QUANTS_PER_ITERATION;          // 16 or 8

      const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
      const int in = tid - step*im;                        // 0...15 or 0...7

  #if K_QUANTS_PER_ITERATION == 1
      const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15
      const int is = 0;
  #else
      const int l0 = 4 * in;                               // 0, 4, 8, ..., 28
      const int is = in / 4;
  #endif
      const int ql_offset = 64*im + l0;
      const int qh_offset = 32*im + l0;
      const int s_offset  =  8*im + is;
      const int y_offset = 128*im + l0;

      float tmp = 0; // partial sum for thread in warp

      for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {

          const float   * y  = yy + i * QK_K + y_offset;
          const uint8_t * ql = x[i].ql + ql_offset;
          const uint8_t * qh = x[i].qh + qh_offset;
          const int8_t  * s  = x[i].scales + s_offset;

          const float d = x[i].d;

  #if K_QUANTS_PER_ITERATION == 1
          float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
                    + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
                    + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
                    + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
                    + y[64] * s[4] * d * ((int8_t)((ql[ 0]  >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
                    + y[80] * s[5] * d * ((int8_t)((ql[16]  >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
                    + y[96] * s[6] * d * ((int8_t)((ql[32]  >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
                    +y[112] * s[7] * d * ((int8_t)((ql[48]  >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
          tmp += sum;
  #else
          float sum = 0;
          for (int l = 0; l < 4; ++l) {
              sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
                   + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
                   + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
                   + y[l+96] * s[6] * d * ((int8_t)((ql[l+32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
          }
          tmp += sum;
  #endif
      }

      // sum up partial sums and write back result
      tmp = warp_reduce_sum(tmp);

      // Only lane 0 stores, as in the q2_k/q3_k/q5_k kernels; the previous `tid == 0` test also matched
      // lane 1 when K_QUANTS_PER_ITERATION == 2 and issued a redundant store.
      if (threadIdx.x == 0) {
          dst[row] = tmp;
      }
  }
  // Dequantize callback for F16 weights: loads the two adjacent halves at vx[ib + iqs] and vx[ib + iqs + 1]
  // into v.x / v.y. When dfloat is float the half -> float conversion happens implicitly on assignment.
  static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
      const half * pair = (const half *) vx + (ib + iqs);
      v.x = pair[0];
      v.y = pair[1];
  }
  // Compile-time mapping from a ggml_type to the per-pair dequantize callback used by dequantize_mul_mat_vec.
  // Returns nullptr for types that have no callback here (e.g. the K-quants, which use dedicated kernels).
  static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) {
      switch (type) {
          case GGML_TYPE_Q4_0: return dequantize_q4_0;
          case GGML_TYPE_Q4_1: return dequantize_q4_1;
          case GGML_TYPE_Q5_0: return dequantize_q5_0;
          case GGML_TYPE_Q5_1: return dequantize_q5_1;
          case GGML_TYPE_Q8_0: return dequantize_q8_0;
          case GGML_TYPE_F16:  return convert_f16;
          default:             return nullptr;
      }
  }
  // Generic dequantize + mat-vec kernel for the block types handled by get_dequantize_kernel:
  // dst[row] = dot(dequantized row `row` of vx, y).
  //
  // Expected launch: blockDim = (WARP_SIZE, GGML_CUDA_MMV_Y), gridDim.x = ceil(nrows / GGML_CUDA_MMV_Y); each
  // threadIdx.y selects one row and rows >= nrows exit early. Precondition: ncols is a multiple of
  // 2*GGML_CUDA_DMMV_X (asserted by the launchers).
  // Per outer step of 2*GGML_CUDA_DMMV_X columns, lane t handles vals_per_iter columns starting at
  // i + vals_per_iter*t, two at a time. One dequantize call yields a pair (v.x, v.y) whose y partners lie
  // y_offset apart: adjacent when qr == 1, qk/2 apart otherwise.
  template <ggml_type type>
  static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
      constexpr int qk = ggml_cuda_type_traits<type>::qk; // quantized weights per x block
      constexpr int qr = ggml_cuda_type_traits<type>::qr; // quantized weights per data value in an x block
      constexpr dequantize_kernel_t dequantize_kernel = get_dequantize_kernel(type);

      const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y;
      if (row >= nrows) {
          return;
      }

      const int lane          = threadIdx.x;
      const int iter_stride   = 2*GGML_CUDA_DMMV_X;
      const int vals_per_iter = iter_stride / WARP_SIZE; // columns handled per lane per outer step
      const int y_offset      = qr == 1 ? 1 : qk/2;      // distance in y between the partners of v.x and v.y

      // Per-lane partial sum; with GGML_CUDA_F16 two halves are accumulated so half2 intrinsics can be used.
  #ifdef GGML_CUDA_F16
      half2 acc = {0.0f, 0.0f};
  #else
      float acc = 0.0f;
  #endif // GGML_CUDA_F16

      for (int i = 0; i < ncols; i += iter_stride) {
          const int col          = i + vals_per_iter*lane;
          const int col_in_block = col%qk;
          const int64_t ib       = ((int64_t)row*ncols + col)/qk; // x block index
          const int iqs          = col_in_block/qr;              // quant index inside that block
          const int iybs         = col - col_in_block;           // first y index covered by that block

          // processing >2 values per outer step is faster for fast GPUs
  #pragma unroll
          for (int j = 0; j < vals_per_iter; j += 2) {
              // for qr == 2 the quant index (and with it the y index) advances by 1 per pair
              const int qi = iqs + j/qr;
              const int yi = iybs + qi;

              dfloat2 v;
              dequantize_kernel(vx, ib, qi, v);

  #ifdef GGML_CUDA_F16
              acc += __hmul2(v, {y[yi + 0], y[yi + y_offset]});
  #else
              acc += v.x * y[yi + 0];
              acc += v.y * y[yi + y_offset];
  #endif // GGML_CUDA_F16
          }
      }

      // combine the per-lane partial sums; lane 0 stores the row result
      acc = warp_reduce_sum(acc);

      if (lane == 0) {
  #ifdef GGML_CUDA_F16
          dst[row] = acc.x + acc.y;
  #else
          dst[row] = acc;
  #endif // GGML_CUDA_F16
      }
  }
  // Launches the generic DMMV kernel for Q4_0 weights: WARP_SIZE x GGML_CUDA_MMV_Y threads per block, one row per
  // threadIdx.y. Rows are spread over the x grid dimension because nrows may exceed the y/z grid limits.
  static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0); // the kernel steps through columns 2*GGML_CUDA_DMMV_X at a time
      const int  rows_per_block = GGML_CUDA_MMV_Y;
      const dim3 block_dims(WARP_SIZE, rows_per_block, 1);
      const dim3 block_nums((nrows + rows_per_block - 1) / rows_per_block, 1, 1);
      dequantize_mul_mat_vec<GGML_TYPE_Q4_0><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches the generic DMMV kernel for Q4_1 weights: WARP_SIZE x GGML_CUDA_MMV_Y threads per block, one row per
  // threadIdx.y, ceil(nrows / GGML_CUDA_MMV_Y) blocks along x.
  static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0); // the kernel steps through columns 2*GGML_CUDA_DMMV_X at a time
      const int  rows_per_block = GGML_CUDA_MMV_Y;
      const dim3 block_dims(WARP_SIZE, rows_per_block, 1);
      const dim3 block_nums((nrows + rows_per_block - 1) / rows_per_block, 1, 1);
      dequantize_mul_mat_vec<GGML_TYPE_Q4_1><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches the generic DMMV kernel for Q5_0 weights: WARP_SIZE x GGML_CUDA_MMV_Y threads per block, one row per
  // threadIdx.y, ceil(nrows / GGML_CUDA_MMV_Y) blocks along x.
  static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0); // the kernel steps through columns 2*GGML_CUDA_DMMV_X at a time
      const int  rows_per_block = GGML_CUDA_MMV_Y;
      const dim3 block_dims(WARP_SIZE, rows_per_block, 1);
      const dim3 block_nums((nrows + rows_per_block - 1) / rows_per_block, 1, 1);
      dequantize_mul_mat_vec<GGML_TYPE_Q5_0><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches the generic DMMV kernel for Q5_1 weights: WARP_SIZE x GGML_CUDA_MMV_Y threads per block, one row per
  // threadIdx.y, ceil(nrows / GGML_CUDA_MMV_Y) blocks along x.
  static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0); // the kernel steps through columns 2*GGML_CUDA_DMMV_X at a time
      const int  rows_per_block = GGML_CUDA_MMV_Y;
      const dim3 block_dims(WARP_SIZE, rows_per_block, 1);
      const dim3 block_nums((nrows + rows_per_block - 1) / rows_per_block, 1, 1);
      dequantize_mul_mat_vec<GGML_TYPE_Q5_1><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches the generic DMMV kernel for Q8_0 weights: WARP_SIZE x GGML_CUDA_MMV_Y threads per block, one row per
  // threadIdx.y, ceil(nrows / GGML_CUDA_MMV_Y) blocks along x.
  static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0); // the kernel steps through columns 2*GGML_CUDA_DMMV_X at a time
      const int  rows_per_block = GGML_CUDA_MMV_Y;
      const dim3 block_dims(WARP_SIZE, rows_per_block, 1);
      const dim3 block_nums((nrows + rows_per_block - 1) / rows_per_block, 1, 1);
      dequantize_mul_mat_vec<GGML_TYPE_Q8_0><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches the Q2_K kernel: 32 x ny threads per block (one row per threadIdx.y), ceil(nrows / ny) blocks along x.
  static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % QK_K == 0); // the kernel walks whole QK_K super-blocks
      constexpr int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
      const dim3 block_dims(32, ny, 1);
      const dim3 block_nums((nrows + ny - 1) / ny, 1, 1);
      dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches the Q3_K kernel: 32 x ny threads per block (one row per threadIdx.y), ceil(nrows / ny) blocks along x.
  // ny is 1 when K_QUANTS_PER_ITERATION == 2 and 2 when it is 1.
  static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % QK_K == 0); // the kernel walks whole QK_K super-blocks
      constexpr int ny = 2 / K_QUANTS_PER_ITERATION;
      const dim3 block_dims(32, ny, 1);
      const dim3 block_nums((nrows + ny - 1) / ny, 1, 1);
      dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches the Q4_K kernel: 32 x ny threads per block (one row per threadIdx.y), ceil(nrows / ny) blocks along x.
  // ny is 1 when K_QUANTS_PER_ITERATION == 2 and 2 when it is 1.
  static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % QK_K == 0); // the kernel walks whole QK_K super-blocks
      constexpr int ny = 2 / K_QUANTS_PER_ITERATION;
      const dim3 block_dims(32, ny, 1);
      const dim3 block_nums((nrows + ny - 1) / ny, 1, 1);
      dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches the Q5_K kernel with one 32-thread block per row (gridDim.x == nrows).
  static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % QK_K == 0); // the kernel walks whole QK_K super-blocks
      const dim3 block_nums(nrows, 1, 1);
      const dim3 block_dims(32, 1, 1);
      dequantize_mul_mat_vec_q5_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols);
  }
  // Launches the Q6_K kernel: 32 x ny threads per block (one row per threadIdx.y), ceil(nrows / ny) blocks along x.
  // ny is 1 when K_QUANTS_PER_ITERATION == 2 and 2 when it is 1.
  static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % QK_K == 0); // the kernel walks whole QK_K super-blocks
      constexpr int ny = 2 / K_QUANTS_PER_ITERATION;
      const dim3 block_dims(32, ny, 1);
      const dim3 block_nums((nrows + ny - 1) / ny, 1, 1);
      dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches the generic DMMV kernel for F16 weights: WARP_SIZE x GGML_CUDA_MMV_Y threads per block, one row per
  // threadIdx.y, ceil(nrows / GGML_CUDA_MMV_Y) blocks along x.
  static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0); // the kernel steps through columns 2*GGML_CUDA_DMMV_X at a time
      const int  rows_per_block = GGML_CUDA_MMV_Y;
      const dim3 block_dims(WARP_SIZE, rows_per_block, 1);
      const dim3 block_nums((nrows + rows_per_block - 1) / rows_per_block, 1, 1);
      dequantize_mul_mat_vec<GGML_TYPE_F16><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Host entry point of the DMMV path. Computes row_high - row_low outputs into dst_dd_i, each the dot product
  // of one src0 row (read from src0_dd_i, presumably already offset to row_low by the caller — confirm) with the
  // first ne[0] values of src1_ddf_i. Dispatches on src0->type and aborts for unsupported types.
  void ggml_cuda_op_dequantize_mul_mat_vec(
      ggml_backend_cuda_context & ctx,
      const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
      const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
      const int64_t src1_padded_row_size, cudaStream_t stream) {
      // Part of the shared op signature but not needed here (ctx is used only when GGML_CUDA_F16 is defined).
      GGML_UNUSED(ctx);
      GGML_UNUSED(dst);
      GGML_UNUSED(src1_ddq_i);
      GGML_UNUSED(src1_ncols);
      GGML_UNUSED(src1_padded_row_size);

      const int64_t ncols = src0->ne[0];
      const int64_t nrows = row_high - row_low;

      GGML_ASSERT(src1->type == GGML_TYPE_F32);

      // The generic kernels read y as dfloat. On some GPUs it is faster to convert src1 to half and use half2
      // intrinsics; the K-quant kernels always consume the float32 src1 directly.
  #ifdef GGML_CUDA_F16
      ggml_cuda_pool_alloc<half> src1_dfloat_a(ctx.pool());
      half * src1_dfloat = nullptr; // dfloat == half

      const bool y_as_half =
          src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
          src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
          src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;

      if (y_as_half) {
          src1_dfloat = src1_dfloat_a.alloc(ncols);
          const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
          GGML_ASSERT(to_fp16_cuda != nullptr);
          to_fp16_cuda(src1_ddf_i, src1_dfloat, ncols, stream);
      }
  #else
      const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
  #endif // GGML_CUDA_F16

      switch (src0->type) {
          // generic kernels: y in dfloat
          case GGML_TYPE_Q4_0: dequantize_mul_mat_vec_q4_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ncols, nrows, stream); break;
          case GGML_TYPE_Q4_1: dequantize_mul_mat_vec_q4_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ncols, nrows, stream); break;
          case GGML_TYPE_Q5_0: dequantize_mul_mat_vec_q5_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ncols, nrows, stream); break;
          case GGML_TYPE_Q5_1: dequantize_mul_mat_vec_q5_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ncols, nrows, stream); break;
          case GGML_TYPE_Q8_0: dequantize_mul_mat_vec_q8_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ncols, nrows, stream); break;
          case GGML_TYPE_F16:  convert_mul_mat_vec_f16_cuda    (src0_dd_i, src1_dfloat, dst_dd_i, ncols, nrows, stream); break;

          // K-quant kernels: y in float32
          case GGML_TYPE_Q2_K: dequantize_mul_mat_vec_q2_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ncols, nrows, stream); break;
          case GGML_TYPE_Q3_K: dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ncols, nrows, stream); break;
          case GGML_TYPE_Q4_K: dequantize_mul_mat_vec_q4_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ncols, nrows, stream); break;
          case GGML_TYPE_Q5_K: dequantize_mul_mat_vec_q5_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ncols, nrows, stream); break;
          case GGML_TYPE_Q6_K: dequantize_mul_mat_vec_q6_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ncols, nrows, stream); break;

          default:
              GGML_ABORT("fatal error");
              break;
      }
  }
  // Returns true for exactly the src0 types that ggml_cuda_op_dequantize_mul_mat_vec can dispatch.
  bool ggml_cuda_dmmv_type_supported(ggml_type src0_type) {
      switch (src0_type) {
          case GGML_TYPE_Q4_0:
          case GGML_TYPE_Q4_1:
          case GGML_TYPE_Q5_0:
          case GGML_TYPE_Q5_1:
          case GGML_TYPE_Q8_0:
          case GGML_TYPE_Q2_K:
          case GGML_TYPE_Q3_K:
          case GGML_TYPE_Q4_K:
          case GGML_TYPE_Q5_K:
          case GGML_TYPE_Q6_K:
          case GGML_TYPE_F16:
              return true;
          default:
              return false;
      }
  }