#include "convert.cuh"
#include "dequantize.cuh"

#define CUDA_Q8_0_NE_ALIGN 2048

template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
    const int64_t i = (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x);

    if (i >= k) {
        return;
    }

    const int64_t ib = i/qk;       // block index
    const int64_t iqs = (i%qk)/qr; // quant index
    const int64_t iybs = i - i%qk; // y block start index
    const int64_t y_offset = qr == 1 ? 1 : qk/2;

    // dequantize
    dfloat2 v;
    dequantize_kernel(vx, ib, iqs, v);

    y[iybs + iqs + 0]        = v.x;
    y[iybs + iqs + y_offset] = v.y;
}
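// Worked example of the indexing above: for q4_0 (qk = 32, qr = 2), element
// index i = 70 gives block ib = 70/32 = 2, quant index iqs = (70%32)/2 = 3,
// block start iybs = 64 and y_offset = qk/2 = 16, so one thread writes y[67]
// and y[83], the values packed in the low and high nibbles of the same byte.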
template <bool need_check>
static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) {
#if __CUDA_ARCH__ >= CC_PASCAL
    constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;

    const int64_t i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
    const int * x0 = ((int *) vx) + blockIdx.x * nint;
    half2 * y2 = (half2 *) (y + i0);

    __shared__ int vals[nint];

#pragma unroll
    for (int ix0 = 0; ix0 < nint; ix0 += WARP_SIZE) {
        if (need_check && i0*sizeof(block_q8_0)/QK8_0 + sizeof(int)*(ix0 + threadIdx.x) >= k*sizeof(block_q8_0)/QK8_0) {
            break;
        }

        const int ix = ix0 + threadIdx.x;
        vals[ix] = x0[ix];
    }

    __syncthreads();

#pragma unroll
    for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
        if (need_check && i0 + iy + 2*threadIdx.x >= k) {
            return;
        }

        const half * b0 = ((const half *) vals) + (sizeof(block_q8_0)/sizeof(half)) * ((iy + 2*threadIdx.x)/QK8_0);
        const half   d = *b0;
        const char2 qs = ((const char2 *) (b0 + 1))[threadIdx.x % (QK8_0/2)];

        y2[iy/2 + threadIdx.x] = __hmul2(make_half2(qs.x, qs.y), __half2half2(d));
    }
#else
    GGML_UNUSED(vx);
    GGML_UNUSED(y);
    GGML_UNUSED(k);
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= CC_PASCAL
}
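// Note on the kernel above: it stages the raw block_q8_0 data into shared
// memory as plain ints (coalesced loads by a single warp), then re-reads the
// staged bytes as an fp16 scale followed by int8 quants and emits half2 pairs.
// Processing CUDA_Q8_0_NE_ALIGN elements per block keeps each block's starting
// byte offset int-aligned; the need_check instantiation guards the tail when
// k is not a multiple of that alignment.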
template<typename dst_t>
static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
    const int64_t i = blockIdx.x;

    // assume 32 threads
    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/8;
    const int64_t ir  = tid%8;
    const int64_t ib  = 8*i + ir;
    if (ib >= nb32) {
        return;
    }

    dst_t * y = yy + 256*i + 32*ir + 4*il;

    const block_q4_0 * x = (const block_q4_0 *)vx + ib;
    const float d  = __half2float(x->d);
    const float dm = -8*d;

    const uint8_t * q = x->qs + 4*il;

    for (int l = 0; l < 4; ++l) {
        y[l+ 0] = d * (q[l] & 0xF) + dm;
        y[l+16] = d * (q[l] >>  4) + dm;
    }
}
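// Worked note: q4_0 packs unsigned 4-bit quants q in [0, 15] around an
// implicit zero point of 8, i.e. value = d*(q - 8). Folding the constant into
// dm = -8*d turns each output into a single fused multiply-add, d*q + dm.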
template<typename dst_t>
static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
    const int64_t i = blockIdx.x;

    // assume 32 threads
    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/8;
    const int64_t ir  = tid%8;
    const int64_t ib  = 8*i + ir;
    if (ib >= nb32) {
        return;
    }

    dst_t * y = yy + 256*i + 32*ir + 4*il;

    const block_q4_1 * x = (const block_q4_1 *)vx + ib;
    const float2 d = __half22float2(x->dm);

    const uint8_t * q = x->qs + 4*il;

    for (int l = 0; l < 4; ++l) {
        y[l+ 0] = d.x * (q[l] & 0xF) + d.y;
        y[l+16] = d.x * (q[l] >>  4) + d.y;
    }
}
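// Note: q4_1 differs from q4_0 only in storing an explicit fp16 (scale, min)
// pair in x->dm instead of the implicit zero point, so the reconstruction is
// value = d.x*q + d.y with q in [0, 15].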
//================================== k-quants

template<typename dst_t>
static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_q2_K * x = (const block_q2_K *) vx;

    const int64_t tid = threadIdx.x;
#if QK_K == 256
    const int64_t n  = tid/32;
    const int64_t l  = tid - 32*n;
    const int64_t is = 8*n + l/16;

    const uint8_t q = x[i].qs[32*n + l];
    dst_t * y = yy + i*QK_K + 128*n;

    float dall = __low2half(x[i].dm);
    float dmin = __high2half(x[i].dm);
    y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
    y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
    y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
    y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
#else
    const int64_t is = tid/16; // 0 or 1
    const int64_t il = tid%16; // 0...15
    const uint8_t q = x[i].qs[il] >> (2*is);
    dst_t * y = yy + i*QK_K + 16*is + il;
    float dall = __low2half(x[i].dm);
    float dmin = __high2half(x[i].dm);
    y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
    y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
#endif
}
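// Note on q2_K: each scales[] byte packs a 4-bit scale (low nibble) and a
// 4-bit min (high nibble) for one group, and the block-wide super-scales in
// x[i].dm expand them as value = dall*scale*q - dmin*min, with 2-bit quants
// q in [0, 3].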
template<typename dst_t>
static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_q3_K * x = (const block_q3_K *) vx;

#if QK_K == 256
    const int64_t r   = threadIdx.x/4;
    const int64_t tid = r/2;
    const int64_t is0 = r%2;
    const int64_t l0  = 16*is0 + 4*(threadIdx.x%4);
    const int64_t n   = tid / 4;
    const int64_t j   = tid - 4*n;

    uint8_t m = 1 << (4*n + j);
    int64_t is = 8*n + 2*j + is0;
    int shift = 2*j;

    int8_t us = is <  4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
                is <  8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
                is < 12 ? (x[i].scales[is-8] >>  4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
                          (x[i].scales[is-8] >>  4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
    float d_all = x[i].d;
    float dl = d_all * (us - 32);

    dst_t * y = yy + i*QK_K + 128*n + 32*j;
    const uint8_t * q  = x[i].qs + 32*n;
    const uint8_t * hm = x[i].hmask;

    for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
#else
    const int64_t tid = threadIdx.x;
    const int64_t is  = tid/16; // 0 or 1
    const int64_t il  = tid%16; // 0...15
    const int64_t im  = il/8;   // 0...1
    const int64_t in  = il%8;   // 0...7

    dst_t * y = yy + i*QK_K + 16*is + il;

    const uint8_t q = x[i].qs[il] >> (2*is);
    const uint8_t h = x[i].hmask[in] >> (2*is + im);
    const float   d = (float)x[i].d;

    if (is == 0) {
        y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
        y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
    } else {
        y[ 0] = d * ((x[i].scales[0] >>  4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
        y[32] = d * ((x[i].scales[1] >>  4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
    }
#endif
}

#if QK_K == 256
static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
    if (j < 4) {
        d = q[j] & 63; m = q[j + 4] & 63;
    } else {
        d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
        m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
    }
}
#endif
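// Note: get_scale_min_k4 unpacks the 12-byte scales[] array used by q4_K and
// q5_K, which holds eight 6-bit (scale, min) pairs: the first four pairs live
// in the low 6 bits of bytes 0..7, and the last four combine the nibbles of
// bytes 8..11 with the spare top 2 bits of bytes 0..7.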
template<typename dst_t>
static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const block_q4_K * x = (const block_q4_K *) vx;

    const int64_t i = blockIdx.x;

#if QK_K == 256
    // assume 32 threads
    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/8;
    const int64_t ir  = tid%8;
    const int64_t is  = 2*il;
    const int64_t n   = 4;

    dst_t * y = yy + i*QK_K + 64*il + n*ir;

    const float dall = __low2half(x[i].dm);
    const float dmin = __high2half(x[i].dm);

    const uint8_t * q = x[i].qs + 32*il + n*ir;

    uint8_t sc, m;
    get_scale_min_k4(is + 0, x[i].scales, sc, m);
    const float d1 = dall * sc; const float m1 = dmin * m;
    get_scale_min_k4(is + 1, x[i].scales, sc, m);
    const float d2 = dall * sc; const float m2 = dmin * m;
    for (int l = 0; l < n; ++l) {
        y[l + 0] = d1 * (q[l] & 0xF) - m1;
        y[l +32] = d2 * (q[l] >>  4) - m2;
    }
#else
    const int64_t tid = threadIdx.x;
    const uint8_t * q = x[i].qs;
    dst_t * y = yy + i*QK_K;
    const float d = (float)x[i].dm[0];
    const float m = (float)x[i].dm[1];
    y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
    y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >>  4) - m * (x[i].scales[1] >> 4);
#endif
}

template<typename dst_t>
static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const block_q5_K * x = (const block_q5_K *) vx;

    const int64_t i = blockIdx.x;

#if QK_K == 256
    // assume 64 threads - this is very slightly better than the one below
    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/16; // il is in 0...3
    const int64_t ir  = tid%16; // ir is in 0...15
    const int64_t is  = 2*il;   // is is in 0...6

    dst_t * y = yy + i*QK_K + 64*il + 2*ir;

    const float dall = __low2half(x[i].dm);
    const float dmin = __high2half(x[i].dm);

    const uint8_t * ql = x[i].qs + 32*il + 2*ir;
    const uint8_t * qh = x[i].qh + 2*ir;

    uint8_t sc, m;
    get_scale_min_k4(is + 0, x[i].scales, sc, m);
    const float d1 = dall * sc; const float m1 = dmin * m;
    get_scale_min_k4(is + 1, x[i].scales, sc, m);
    const float d2 = dall * sc; const float m2 = dmin * m;

    uint8_t hm = 1 << (2*il);
    y[ 0] = d1 * ((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0)) - m1;
    y[ 1] = d1 * ((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0)) - m1;
    hm <<= 1;
    y[32] = d2 * ((ql[0] >>  4) + (qh[0] & hm ? 16 : 0)) - m2;
    y[33] = d2 * ((ql[1] >>  4) + (qh[1] & hm ? 16 : 0)) - m2;
#else
    const int64_t tid = threadIdx.x;
    const uint8_t q = x[i].qs[tid];
    const int64_t im = tid/8;  // 0...3
    const int64_t in = tid%8;  // 0...7
    const int64_t is = tid/16; // 0 or 1
    const uint8_t h = x[i].qh[in] >> im;
    const float   d = x[i].d;
    dst_t * y = yy + i*QK_K + tid;
    y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
    y[32] = d * x[i].scales[is+2] * ((q >>  4) - ((h >> 4) & 1 ? 0 : 16));
#endif
}
template<typename dst_t>
static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const block_q6_K * x = (const block_q6_K *) vx;

    const int64_t i = blockIdx.x;
#if QK_K == 256
    // assume 64 threads - this is very slightly better than the one below
    const int64_t tid = threadIdx.x;
    const int64_t ip  = tid/32;      // ip is 0 or 1
    const int64_t il  = tid - 32*ip; // 0...31
    const int64_t is  = 8*ip + il/16;

    dst_t * y = yy + i*QK_K + 128*ip + il;

    const float d = x[i].d;

    const uint8_t * ql = x[i].ql + 64*ip + il;
    const uint8_t   qh = x[i].qh[32*ip + il];
    const int8_t  * sc = x[i].scales + is;

    y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
    y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
    y[64] = d * sc[4] * ((int8_t)((ql[ 0] >>  4) | (((qh >> 4) & 3) << 4)) - 32);
    y[96] = d * sc[6] * ((int8_t)((ql[32] >>  4) | (((qh >> 6) & 3) << 4)) - 32);
#else
    // assume 32 threads
    const int64_t tid = threadIdx.x;
    const int64_t ip  = tid/16;      // 0 or 1
    const int64_t il  = tid - 16*ip; // 0...15

    dst_t * y = yy + i*QK_K + 16*ip + il;

    const float d = x[i].d;

    const uint8_t  ql = x[i].ql[16*ip + il];
    const uint8_t  qh = x[i].qh[il] >> (2*ip);
    const int8_t * sc = x[i].scales;

    y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
    y[32] = d * sc[ip+2] * ((int8_t)((ql >>  4) | (((qh >> 4) & 3) << 4)) - 32);
#endif
}
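// Note on q6_K: a 6-bit quant is rebuilt from a 4-bit low part in ql[] and a
// 2-bit high part in qh[], q = (low4 | (high2 << 4)) - 32, then scaled by a
// per-group int8 scale and the block-wide fp16 d.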
template<typename dst_t>
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_iq2_xxs * x = (const block_iq2_xxs *) vx;

    const int64_t tid = threadIdx.x;
#if QK_K == 256
    const int64_t il = tid/8; // 0...3
    const int64_t ib = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint16_t * q2 = x[i].qs + 4*ib;
    const uint8_t * aux8 = (const uint8_t *)q2;
    const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]);
    const uint32_t aux32 = q2[2] | (q2[3] << 16);
    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
#else
    NO_DEVICE_CODE;
#endif
}
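// Note: the iq2/iq3 kernels in this section are codebook quantizations. A
// small per-entry index (8 to 10 bits, depending on the variant) selects a
// vector of magnitudes from a fixed grid (iq2xxs_grid above, then
// iq2xs_grid/iq2s_grid/iq3xxs_grid/iq3s_grid below), a packed sign word flips
// individual lanes via ksigns_iq2xs and kmask_iq2xs, and a per-group scale
// together with the block-wide fp16 d sets the magnitude.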
template<typename dst_t>
static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_iq2_xs * x = (const block_iq2_xs *) vx;

    const int64_t tid = threadIdx.x;
#if QK_K == 256
    const int64_t il = tid/8; // 0...3
    const int64_t ib = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint16_t * q2 = x[i].qs + 4*ib;
    const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
    const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
    const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
#else
    NO_DEVICE_CODE;
#endif
}

template<typename dst_t>
static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_iq2_s * x = (const block_iq2_s *) vx;

    const int64_t tid = threadIdx.x;
#if QK_K == 256
    const int64_t il = tid/8; // 0...3
    const int64_t ib = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
    const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
    const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
#else
    NO_DEVICE_CODE;
#endif
}

template<typename dst_t>
static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_iq3_xxs * x = (const block_iq3_xxs *) vx;

    const int64_t tid = threadIdx.x;
#if QK_K == 256
    const int64_t il = tid/8; // 0...3
    const int64_t ib = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint8_t  * q3 = x[i].qs + 8*ib;
    const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
    const uint8_t  * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
    const uint8_t  * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);
    const uint32_t aux32 = gas[0] | (gas[1] << 16);
    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f;
    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
    for (int j = 0; j < 4; ++j) {
        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
    }
#else
    NO_DEVICE_CODE;
#endif
}

template<typename dst_t>
static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_iq3_s * x = (const block_iq3_s *) vx;

    const int64_t tid = threadIdx.x;
#if QK_K == 256
    const int64_t il = tid/8; // 0...3
    const int64_t ib = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint8_t * qs = x[i].qs + 8*ib;
    const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
    const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
    const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
    const uint8_t signs = x[i].signs[4*ib + il];
    for (int j = 0; j < 4; ++j) {
        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
    }
#else
    NO_DEVICE_CODE;
#endif
}

template<typename dst_t>
static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_iq1_s * x = (const block_iq1_s *) vx;

    const int64_t tid = threadIdx.x;
#if QK_K == 256
    const int64_t il = tid/8; // 0...3
    const int64_t ib = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
    const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
    grid32[0] &= 0x0f0f0f0f;
    for (int j = 0; j < 8; ++j) {
        y[j] = d * (q[j] + delta);
    }
#else
    NO_DEVICE_CODE;
#endif
}

template<typename dst_t>
static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_iq1_m * x = (const block_iq1_m *) vx;

    const int64_t tid = threadIdx.x;
#if QK_K == 256
    const int64_t il = tid/8; // 0...3
    const int64_t ib = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint16_t * sc = (const uint16_t *)x[i].scales;
    iq1m_scale_t scale;
    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
    const int64_t ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
    const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
    const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
    grid32[0] &= 0x0f0f0f0f;
    for (int j = 0; j < 8; ++j) {
        y[j] = d * (q[j] + delta);
    }
#else
    NO_DEVICE_CODE;
#endif
}

template<typename dst_t>
static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);

    const int64_t tid = threadIdx.x;
    const int64_t il = tid/8; // 0...3
    const int64_t ib = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
    const uint8_t * q4 = x[ib].qs + 4*il;
    const float d = (float)x[ib].d;
    for (int j = 0; j < 4; ++j) {
        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];
    }
}
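// Note on iq4_nl: the format is non-linear 4-bit. Rather than scaling the
// nibble directly, it is used as an index into the 16-entry signed table
// kvalues_iq4nl, so value = d * kvalues_iq4nl[q]; iq4_xs below layers 6-bit
// per-group scales on top of the same table.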
#if QK_K != 64
template<typename dst_t>
static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_iq4_xs * x = (const block_iq4_xs *)vx;

    const int64_t tid = threadIdx.x;
    const int64_t il = tid/8; // 0...3
    const int64_t ib = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
    const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
    const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
    for (int j = 0; j < 4; ++j) {
        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];
    }
}
#endif

template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
    const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
    dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
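// Worked note: dequantize_block handles 2 elements per thread, so the grid
// rounds k up to a multiple of 2*CUDA_DEQUANTIZE_BLOCK_SIZE. Assuming the
// usual block size of 256, k = 4096 gives num_blocks = (4096 + 511)/512 = 8,
// and 8 blocks * 256 threads * 2 elements covers the buffer exactly.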
static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
    if (k % CUDA_Q8_0_NE_ALIGN == 0) {
        const bool need_check = false;
        dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
    } else {
        const bool need_check = true;
        dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
    }
}

template<typename dst_t>
static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = k / QK_K;
#if QK_K == 256
    dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
#else
    dequantize_block_q2_K<<<nb, 32, 0, stream>>>(vx, y);
#endif
}

template<typename dst_t>
static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = k / QK_K;
#if QK_K == 256
    dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
#else
    dequantize_block_q3_K<<<nb, 32, 0, stream>>>(vx, y);
#endif
}

template<typename dst_t>
static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb32 = k / 32;
    const int nb = (k + 255) / 256;
    dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
}

template<typename dst_t>
static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb32 = k / 32;
    const int nb = (k + 255) / 256;
    dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
}

template<typename dst_t>
static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = k / QK_K;
    dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = k / QK_K;
#if QK_K == 256
    dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
#else
    dequantize_block_q5_K<<<nb, 32, 0, stream>>>(vx, y);
#endif
}

template<typename dst_t>
static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = k / QK_K;
#if QK_K == 256
    dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
#else
    dequantize_block_q6_K<<<nb, 32, 0, stream>>>(vx, y);
#endif
}

template<typename dst_t>
static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = k / QK_K;
    dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = k / QK_K;
    dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = k / QK_K;
    dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = k / QK_K;
    dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = k / QK_K;
    dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = k / QK_K;
    dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = (k + QK_K - 1) / QK_K;
    dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = k / QK_K;
    dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = (k + QK_K - 1) / QK_K;
#if QK_K == 64
    dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
#else
    dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
#endif
}

template <typename src_t, typename dst_t>
static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    const src_t * x = (src_t *) vx;

    y[i] = x[i];
}

template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
    convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}

to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0:
            return dequantize_row_q4_0_cuda;
        case GGML_TYPE_Q4_1:
            return dequantize_row_q4_1_cuda;
        case GGML_TYPE_Q5_0:
            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
        case GGML_TYPE_Q5_1:
            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
        case GGML_TYPE_Q8_0:
            if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= CC_PASCAL) {
                return dequantize_block_q8_0_f16_cuda;
            }
            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
        case GGML_TYPE_Q2_K:
            return dequantize_row_q2_K_cuda;
        case GGML_TYPE_Q3_K:
            return dequantize_row_q3_K_cuda;
        case GGML_TYPE_Q4_K:
            return dequantize_row_q4_K_cuda;
        case GGML_TYPE_Q5_K:
            return dequantize_row_q5_K_cuda;
        case GGML_TYPE_Q6_K:
            return dequantize_row_q6_K_cuda;
        case GGML_TYPE_IQ2_XXS:
            return dequantize_row_iq2_xxs_cuda;
        case GGML_TYPE_IQ2_XS:
            return dequantize_row_iq2_xs_cuda;
        case GGML_TYPE_IQ2_S:
            return dequantize_row_iq2_s_cuda;
        case GGML_TYPE_IQ3_XXS:
            return dequantize_row_iq3_xxs_cuda;
        case GGML_TYPE_IQ1_S:
            return dequantize_row_iq1_s_cuda;
        case GGML_TYPE_IQ1_M:
            return dequantize_row_iq1_m_cuda;
        case GGML_TYPE_IQ4_NL:
            return dequantize_row_iq4_nl_cuda;
        case GGML_TYPE_IQ4_XS:
            return dequantize_row_iq4_xs_cuda;
        case GGML_TYPE_IQ3_S:
            return dequantize_row_iq3_s_cuda;
        case GGML_TYPE_F32:
            return convert_unary_cuda<float>;
        default:
            return nullptr;
    }
}
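// Usage sketch (illustrative; src_data, dst_f16, nelements and stream are
// caller-provided placeholders, not names from this file):
//
//   const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(src->type);
//   if (to_fp16 != nullptr) {
//       to_fp16(src_data, dst_f16, nelements, stream);
//   }
//
// A nullptr return means the type has no fp16 conversion path.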
to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0:
            return dequantize_row_q4_0_cuda;
        case GGML_TYPE_Q4_1:
            return dequantize_row_q4_1_cuda;
        case GGML_TYPE_Q5_0:
            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
        case GGML_TYPE_Q5_1:
            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
        case GGML_TYPE_Q8_0:
            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
        case GGML_TYPE_Q2_K:
            return dequantize_row_q2_K_cuda;
        case GGML_TYPE_Q3_K:
            return dequantize_row_q3_K_cuda;
        case GGML_TYPE_Q4_K:
            return dequantize_row_q4_K_cuda;
        case GGML_TYPE_Q5_K:
            return dequantize_row_q5_K_cuda;
        case GGML_TYPE_Q6_K:
            return dequantize_row_q6_K_cuda;
        case GGML_TYPE_IQ2_XXS:
            return dequantize_row_iq2_xxs_cuda;
        case GGML_TYPE_IQ2_XS:
            return dequantize_row_iq2_xs_cuda;
        case GGML_TYPE_IQ2_S:
            return dequantize_row_iq2_s_cuda;
        case GGML_TYPE_IQ3_XXS:
            return dequantize_row_iq3_xxs_cuda;
        case GGML_TYPE_IQ1_S:
            return dequantize_row_iq1_s_cuda;
        case GGML_TYPE_IQ1_M:
            return dequantize_row_iq1_m_cuda;
        case GGML_TYPE_IQ4_NL:
            return dequantize_row_iq4_nl_cuda;
        case GGML_TYPE_IQ4_XS:
            return dequantize_row_iq4_xs_cuda;
        case GGML_TYPE_IQ3_S:
            return dequantize_row_iq3_s_cuda;
        case GGML_TYPE_F16:
            return convert_unary_cuda<half>;
        default:
            return nullptr;
    }
}