// convert.cu (26 KB)
  1. /**
  2. * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
  3. *
  4. * MIT License
  5. *
  6. * Copyright (c) 2023-2024 The ggml authors
  7. *
  8. * Permission is hereby granted, free of charge, to any person obtaining a copy
  9. * of this software and associated documentation files (the "Software"), to deal
  10. * in the Software without restriction, including without limitation the rights
  11. * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  12. * copies of the Software, and to permit persons to whom the Software is
  13. * furnished to do so, subject to the following conditions:
  14. *
  15. * The above copyright notice and this permission notice shall be included in all
  16. * copies or substantial portions of the Software.
  17. *
  18. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  19. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  20. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  21. * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  22. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  23. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  24. * SOFTWARE.
  25. */
  26. #include "convert.cuh"
  27. #include "dequantize.cuh"
  // Number of output values produced by each thread block of dequantize_block_q8_0_f16.
  // It also sizes that kernel's shared-memory staging buffer, and the launcher picks the
  // bounds-checked kernel variant when k is not a multiple of it.
  static constexpr int CUDA_Q8_0_NE_ALIGN = 2048;
  // Generic dequantization kernel: each thread decodes one pair of values.
  //
  // Template parameters:
  //   qk                - number of values per quantized block
  //   qr                - when 1, the pair is written to adjacent outputs; otherwise the
  //                       second value lands qk/2 positions after the first
  //   dequantize_kernel - device helper called as dequantize_kernel(vx, ib, iqs, v) that
  //                       decodes quant index iqs of block ib into v.x / v.y
  // vx: quantized input, y: output buffer of k values, k: total number of values.
  // Launch: 1D grid where each thread covers 2 outputs (see dequantize_block_cuda);
  // threads whose first output index is >= k exit without writing.
  template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
  static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
      // Widen before multiplying. blockDim.x*blockIdx.x is unsigned 32-bit arithmetic and
      // would wrap once the flat thread index passes 2^32 (k beyond roughly 2^33 values).
      const int64_t i = 2*((int64_t)blockDim.x*blockIdx.x + threadIdx.x);

      if (i >= k) {
          return;
      }

      const int64_t ib = i/qk;        // block index
      const int64_t iqs = (i%qk)/qr;  // quant index
      const int64_t iybs = i - i%qk;  // y block start index
      const int64_t y_offset = qr == 1 ? 1 : qk/2;

      // dequantize
      dfloat2 v;
      dequantize_kernel(vx, ib, iqs, v);

      y[iybs + iqs + 0] = v.x;
      y[iybs + iqs + y_offset] = v.y;
  }
  // Dequantizes Q8_0 data directly to half precision, CUDA_Q8_0_NE_ALIGN values per thread block.
  //
  // Expected launch: WARP_SIZE threads per block and ceil(k / CUDA_Q8_0_NE_ALIGN) blocks
  // (see dequantize_block_q8_0_f16_cuda); the strides below assume blockDim.x == WARP_SIZE.
  // Phase 1 copies the block's raw block_q8_0 bytes (read as ints) into shared memory.
  // Phase 2 has each thread emit one half2 per iteration: two int8 quants times the block's
  // half scale d.
  // need_check: true when k is not a multiple of CUDA_Q8_0_NE_ALIGN. In that case the last
  // block stops reading at the end of the input and stops writing at output index k.
  // Only compiled for __CUDA_ARCH__ >= CC_PASCAL; other architectures get NO_DEVICE_CODE.
  template <bool need_check>
  static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) {
  #if __CUDA_ARCH__ >= CC_PASCAL
      // Ints staged per block. NOTE(review): presumably this is exactly the block_q8_0 bytes
      // for CUDA_Q8_0_NE_ALIGN values (assumes 34-byte blocks of 32 values and WARP_SIZE == 32).
      constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;

      // Compute offsets in 64-bit: the 32-bit unsigned products would wrap for very large tensors.
      const int64_t i0 = (int64_t)CUDA_Q8_0_NE_ALIGN*blockIdx.x;
      const int * x0 = ((const int *) vx) + (int64_t)blockIdx.x*nint;
      half2 * y2 = (half2 *) (y + i0);

      __shared__ int vals[nint];

  #pragma unroll
      for (int ix0 = 0; ix0 < nint; ix0 += WARP_SIZE) {
          // Compare byte offsets: stop before reading past the quantized data for k values.
          if (need_check && i0*sizeof(block_q8_0)/QK8_0 + sizeof(int)*(ix0 + threadIdx.x) >= k*sizeof(block_q8_0)/QK8_0) {
              break;
          }

          const int ix = ix0 + threadIdx.x;
          vals[ix] = x0[ix];
      }

      // Every thread reaches this barrier: the loop above exits with break, never return.
      __syncthreads();

  #pragma unroll
      for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
          if (need_check && i0 + iy + 2*threadIdx.x >= k) {
              return;
          }

          // b0 points at the start of the staged block_q8_0 holding outputs iy + 2*threadIdx.x.
          const half * b0 = ((const half *) vals) + (sizeof(block_q8_0)/sizeof(half)) * ((iy + 2*threadIdx.x)/QK8_0);
          const half d = *b0;
          const char2 qs = ((const char2 *) (b0 + 1))[threadIdx.x % (QK8_0/2)];

          y2[iy/2 + threadIdx.x] = __hmul2(make_half2(qs.x, qs.y), __half2half2(d));
      }
  #else
      GGML_UNUSED(vx);
      GGML_UNUSED(y);
      GGML_UNUSED(k);
      NO_DEVICE_CODE;
  #endif // __CUDA_ARCH__ >= CC_PASCAL
  }
  // Dequantizes Q4_0. Each 32-thread CUDA block covers eight consecutive block_q4_0
  // (256 outputs). Thread t works on quant block 8*blockIdx.x + t%8 and on its 4 packed
  // bytes starting at 4*(t/8). Low nibbles go to 4 consecutive outputs, high nibbles to the
  // outputs 16 positions later, each computed as d*nibble - 8*d.
  // nb32: total number of block_q4_0 in vx; threads whose block index reaches it write nothing.
  template<typename dst_t>
  static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
      // expects blockDim.x == 32
      const int64_t sub  = threadIdx.x % 8;  // which of the 8 quant blocks
      const int64_t part = threadIdx.x / 8;  // which 4-byte slice of qs
      const int64_t blk  = 8*(int64_t)blockIdx.x + sub;

      if (blk < nb32) {
          const block_q4_0 * src = (const block_q4_0 *)vx + blk;
          const float scale  = __half2float(src->d);
          const float offset = -8*scale;
          const uint8_t * packed = src->qs + 4*part;
          dst_t * out = yy + 256*(int64_t)blockIdx.x + 32*sub + 4*part;

          for (int l = 0; l < 4; ++l) {
              out[l +  0] = scale * (packed[l] & 0xF) + offset;
              out[l + 16] = scale * (packed[l] >> 4) + offset;
          }
      }
  }
  // Dequantizes Q4_1 with the same thread layout as dequantize_block_q4_0: 32 threads cover
  // eight consecutive block_q4_1. Each value is dm.x*nibble + dm.y, where dm is the block's
  // (scale, min) pair converted from half2.
  // nb32: total number of block_q4_1 in vx; threads whose block index reaches it write nothing.
  template<typename dst_t>
  static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
      // expects blockDim.x == 32
      const int64_t sub  = threadIdx.x % 8;  // which of the 8 quant blocks
      const int64_t part = threadIdx.x / 8;  // which 4-byte slice of qs
      const int64_t blk  = 8*(int64_t)blockIdx.x + sub;

      if (blk < nb32) {
          const block_q4_1 * src = (const block_q4_1 *)vx + blk;
          const float2 dm = __half22float2(src->dm);
          const uint8_t * packed = src->qs + 4*part;
          dst_t * out = yy + 256*(int64_t)blockIdx.x + 32*sub + 4*part;

          for (int l = 0; l < 4; ++l) {
              out[l +  0] = dm.x * (packed[l] & 0xF) + dm.y;
              out[l + 16] = dm.x * (packed[l] >> 4) + dm.y;
          }
      }
  }
  //================================== k-quants
  // Dequantizes Q2_K: one 64-thread CUDA block per block_q2_K super-block of QK_K values.
  // Thread t reads one qs byte and writes four outputs, 32 apart, inside 128-value half t/32.
  // Output s uses bits 2s..2s+1 of the byte and scale byte is+2s: its low nibble scales the
  // quant (times dall) and its high nibble is the min (times dmin).
  template<typename dst_t>
  static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
      const int64_t blk  = blockIdx.x;
      const int64_t half_idx = threadIdx.x / 32;  // 0 or 1
      const int64_t lane = threadIdx.x % 32;      // 0..31
      const int64_t is   = 8*half_idx + lane/16;

      const block_q2_K & b = ((const block_q2_K *) vx)[blk];
      const uint8_t q = b.qs[32*half_idx + lane];

      const float dall = __low2half(b.dm);
      const float dmin = __high2half(b.dm);

      dst_t * y = yy + blk*QK_K + 128*half_idx + lane;
  #pragma unroll
      for (int s = 0; s < 4; ++s) {
          const int sc = b.scales[is + 2*s];
          y[32*s] = dall * (sc & 0xF) * ((q >> 2*s) & 3) - dmin * (sc >> 4);
      }
  }
  // Dequantizes Q3_K: one 64-thread CUDA block per block_q3_K super-block of QK_K values.
  // Each thread writes 4 consecutive outputs. A value is its 2-bit field from qs (at `shift`),
  // minus 4 when the matching hmask bit is clear, times d * (6-bit scale - 32). The 6-bit scale
  // combines a nibble and a 2-bit field taken from the packed `scales` bytes.
  template<typename dst_t>
  static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
      const int64_t blk = blockIdx.x;
      const block_q3_K & b = ((const block_q3_K *) vx)[blk];

      // Decompose the thread index (expects 64 threads).
      const int64_t grp = threadIdx.x / 8;        // 0..7
      const int64_t is0 = (threadIdx.x / 4) % 2;  // 0..1
      const int64_t l0  = 16*is0 + 4*(threadIdx.x % 4);
      const int64_t n   = grp / 4;                // 128-value half
      const int64_t j   = grp % 4;                // 32-value slice within the half

      const uint8_t mask  = 1 << (4*n + j);
      const int64_t is    = 8*n + 2*j + is0;
      const int     shift = 2*j;

      // Assemble the 6-bit sub-block scale from the packed scale bytes.
      int8_t us;
      if (is < 4) {
          us = (b.scales[is - 0] & 0xF) | (((b.scales[is + 8] >> 0) & 3) << 4);
      } else if (is < 8) {
          us = (b.scales[is - 0] & 0xF) | (((b.scales[is + 4] >> 2) & 3) << 4);
      } else if (is < 12) {
          us = (b.scales[is - 8] >> 4) | (((b.scales[is + 0] >> 4) & 3) << 4);
      } else {
          us = (b.scales[is - 8] >> 4) | (((b.scales[is - 4] >> 6) & 3) << 4);
      }

      const float d_all = b.d;
      const float dl = d_all * (us - 32);

      dst_t * y = yy + blk*QK_K + 128*n + 32*j;
      const uint8_t * q  = b.qs + 32*n;
      const uint8_t * hm = b.hmask;

      for (int l = l0; l < l0 + 4; ++l) {
          const int8_t low2 = (int8_t)((q[l] >> shift) & 3);
          const int    sub4 = (hm[l] & mask) ? 0 : 4;
          y[l] = dl * (low2 - sub4);
      }
  }
  // Unpacks the 6-bit scale `d` and 6-bit min `m` of sub-block j (0..7) from the packed
  // scale bytes q (reads q[0..11]). Used by the Q4_K and Q5_K kernels.
  static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
      if (j >= 4) {
          // Sub-blocks 4..7: low 4 bits come from q[j+4], the top 2 bits from bits 6..7 of
          // q[j-4] (scale) and q[j] (min).
          d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
          m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
          return;
      }
      // Sub-blocks 0..3: plain 6-bit fields.
      d = q[j]     & 63;
      m = q[j + 4] & 63;
  }
  // Dequantizes Q4_K: one 32-thread CUDA block per block_q4_K super-block of QK_K values.
  // Thread t handles 64-value chunk t/8. It writes 4 low-nibble outputs at offset 4*(t%8)
  // and 4 high-nibble outputs 32 positions later. Each half uses its own (scale, min) pair
  // from get_scale_min_k4: value = dall*scale*nibble - dmin*min.
  template<typename dst_t>
  static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
      constexpr int per_thread = 4;
      const block_q4_K & b = ((const block_q4_K *) vx)[blockIdx.x];

      // expects blockDim.x == 32
      const int64_t chunk = threadIdx.x / 8;  // 0..3
      const int64_t lane  = threadIdx.x % 8;  // 0..7

      const float dall = __low2half(b.dm);
      const float dmin = __high2half(b.dm);

      uint8_t sc, mn;
      get_scale_min_k4(2*chunk + 0, b.scales, sc, mn);
      const float d_lo = dall * sc; const float m_lo = dmin * mn;
      get_scale_min_k4(2*chunk + 1, b.scales, sc, mn);
      const float d_hi = dall * sc; const float m_hi = dmin * mn;

      const uint8_t * q = b.qs + 32*chunk + per_thread*lane;
      dst_t * y = yy + (int64_t)blockIdx.x*QK_K + 64*chunk + per_thread*lane;
      for (int l = 0; l < per_thread; ++l) {
          y[l]      = d_lo * (q[l] & 0xF) - m_lo;
          y[l + 32] = d_hi * (q[l] >> 4)  - m_hi;
      }
  }
  // Dequantizes Q5_K: one 64-thread CUDA block per block_q5_K super-block of QK_K values.
  // Thread t handles 64-value chunk t/16 and writes outputs 2*(t%16) and 2*(t%16)+1 (low
  // nibbles), plus the same positions +32 (high nibbles). The fifth bit comes from qh: mask
  // 1 << 2*chunk for the low half and the next bit for the high half, adding 16 when set.
  template<typename dst_t>
  static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
      const block_q5_K & b = ((const block_q5_K *) vx)[blockIdx.x];

      // expects blockDim.x == 64
      const int64_t chunk = threadIdx.x / 16;  // 0..3
      const int64_t pair  = threadIdx.x % 16;  // 0..15

      const float dall = __low2half(b.dm);
      const float dmin = __high2half(b.dm);

      uint8_t sc, mn;
      get_scale_min_k4(2*chunk + 0, b.scales, sc, mn);
      const float d1 = dall * sc; const float m1 = dmin * mn;
      get_scale_min_k4(2*chunk + 1, b.scales, sc, mn);
      const float d2 = dall * sc; const float m2 = dmin * mn;

      const uint8_t * ql = b.qs + 32*chunk + 2*pair;
      const uint8_t * qh = b.qh + 2*pair;
      const uint8_t hm_lo = 1 << (2*chunk);
      const uint8_t hm_hi = hm_lo << 1;

      dst_t * y = yy + (int64_t)blockIdx.x*QK_K + 64*chunk + 2*pair;
      for (int t = 0; t < 2; ++t) {
          y[t]      = d1 * ((ql[t] & 0xF) + (qh[t] & hm_lo ? 16 : 0)) - m1;
          y[t + 32] = d2 * ((ql[t] >>  4) + (qh[t] & hm_hi ? 16 : 0)) - m2;
      }
  }
  // Dequantizes Q6_K: one 64-thread CUDA block per block_q6_K super-block of QK_K values.
  // Each thread writes 4 outputs, 32 apart. Output k takes 4 bits from ql (the low nibble
  // for k < 2, the high nibble otherwise, from byte 0 or 32) and 2 bits from qh, subtracts 32,
  // and multiplies by d * scales[is + 2k].
  template<typename dst_t>
  static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
      const block_q6_K & b = ((const block_q6_K *) vx)[blockIdx.x];

      // expects blockDim.x == 64
      const int64_t half_idx = threadIdx.x / 32;  // 0 or 1
      const int64_t lane     = threadIdx.x % 32;  // 0..31
      const int64_t is       = 8*half_idx + lane/16;

      const float d = b.d;
      const uint8_t * ql = b.ql + 64*half_idx + lane;
      const uint8_t   qh = b.qh[32*half_idx + lane];
      const int8_t  * sc = b.scales + is;

      dst_t * y = yy + (int64_t)blockIdx.x*QK_K + 128*half_idx + lane;
  #pragma unroll
      for (int k = 0; k < 4; ++k) {
          const uint8_t lbyte = ql[32*(k & 1)];
          const int nib = (k < 2) ? (lbyte & 0xF) : (lbyte >> 4);
          const int q6 = (int8_t)(nib | (((qh >> 2*k) & 3) << 4));
          y[32*k] = d * sc[2*k] * (q6 - 32);
      }
  }
  // Dequantizes IQ2_XXS: one 32-thread CUDA block per block_iq2_xxs (QK_K values).
  // Thread t covers group ib = t%8 and sub-group il = t/8 and writes 8 consecutive outputs.
  // Each group uses four uint16 words. The bytes of the first two words index iq2xxs_grid.
  // The last two words form aux32: bits 7*il..7*il+6 select a sign pattern and the top 4
  // bits the scale.
  template<typename dst_t>
  static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
      const int64_t i = blockIdx.x;
      const block_iq2_xxs * x = (const block_iq2_xxs *) vx;

      const int64_t tid = threadIdx.x;
      const int64_t il = tid/8; // 0...3
      const int64_t ib = tid%8; // 0...7
      dst_t * y = yy + i*QK_K + 32*ib + 8*il;
      const uint16_t * q2 = x[i].qs + 4*ib;
      const uint8_t  * aux8 = (const uint8_t *)q2;
      const uint8_t  * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]);
      // Widen before shifting: a uint16_t is promoted to signed int, so q2[3] << 16 overflows
      // int whenever the top bit of q2[3] is set (undefined behaviour before C++20).
      const uint32_t aux32 = (uint32_t)q2[2] | ((uint32_t)q2[3] << 16);
      const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
      const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
      for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
  }
  // Dequantizes IQ2_XS: one 32-thread CUDA block per block_iq2_xs (QK_K values).
  // Thread t reads the 16-bit code qs[4*(t%8) + t/8]. Its low 9 bits index iq2xs_grid and
  // its top bits select a sign pattern. The thread writes 8 consecutive outputs scaled by a
  // 4-bit per-group scale.
  template<typename dst_t>
  static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
      const block_iq2_xs & b = ((const block_iq2_xs *) vx)[blockIdx.x];
      const int64_t sub  = threadIdx.x % 8; // 0...7
      const int64_t quad = threadIdx.x / 8; // 0...3

      const uint16_t code = b.qs[4*sub + quad];
      const uint8_t * grid  = (const uint8_t *)(iq2xs_grid + (code & 511));
      const uint8_t   signs = ksigns_iq2xs[code >> 9];
      const float d = (float)b.d * (0.5f + ((b.scales[sub] >> 4*(quad/2)) & 0xf)) * 0.25f;

      dst_t * y = yy + (int64_t)blockIdx.x*QK_K + 32*sub + 8*quad;
      for (int j = 0; j < 8; ++j) {
          const float sgn = (signs & kmask_iq2xs[j]) ? -1.f : 1.f;
          y[j] = d * grid[j] * sgn;
      }
  }
  // Dequantizes IQ2_S: one 32-thread CUDA block per block_iq2_s (QK_K values).
  // Thread t writes 8 consecutive outputs. The 10-bit grid index combines a qs byte with
  // 2 bits of qh[t%8]. The sign byte comes from the qs region that starts at QK_K/8.
  template<typename dst_t>
  static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
      const block_iq2_s & b = ((const block_iq2_s *) vx)[blockIdx.x];
      const int64_t sub  = threadIdx.x % 8; // 0...7
      const int64_t quad = threadIdx.x / 8; // 0...3

      const auto hbits = b.qh[sub];
      const uint8_t * grid = (const uint8_t *)(iq2s_grid + (b.qs[4*sub + quad] | ((hbits << (8 - 2*quad)) & 0x300)));
      const float d = (float)b.d * (0.5f + ((b.scales[sub] >> 4*(quad/2)) & 0xf)) * 0.25f;
      const uint8_t signs = b.qs[QK_K/8 + 4*sub + quad];

      dst_t * y = yy + (int64_t)blockIdx.x*QK_K + 32*sub + 8*quad;
      for (int j = 0; j < 8; ++j) {
          y[j] = d * grid[j] * ((signs & kmask_iq2xs[j]) ? -1.f : 1.f);
      }
  }
  // Dequantizes IQ3_XXS: one 32-thread CUDA block per block_iq3_xxs (QK_K values).
  // Thread t covers group ib = t%8 and sub-group il = t/8 and writes 8 consecutive outputs
  // from two iq3xxs_grid entries (4 values each). The grid indices are qs bytes. Two uint16
  // words stored after the first QK_K/4 bytes of qs form aux32: bits 7*il..7*il+6 select a
  // sign pattern and the top 4 bits the scale.
  template<typename dst_t>
  static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
      const int64_t i = blockIdx.x;
      const block_iq3_xxs * x = (const block_iq3_xxs *) vx;

      const int64_t tid = threadIdx.x;
      const int64_t il = tid/8; // 0...3
      const int64_t ib = tid%8; // 0...7
      dst_t * y = yy + i*QK_K + 32*ib + 8*il;
      const uint8_t  * q3 = x[i].qs + 8*ib;
      const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
      const uint8_t  * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
      const uint8_t  * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);
      // Widen before shifting: a uint16_t is promoted to signed int, so gas[1] << 16 overflows
      // int whenever the top bit of gas[1] is set (undefined behaviour before C++20).
      const uint32_t aux32 = (uint32_t)gas[0] | ((uint32_t)gas[1] << 16);
      const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f;
      const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
      for (int j = 0; j < 4; ++j) {
          y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
          y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
      }
  }
  // Dequantizes IQ3_S: one 32-thread CUDA block per block_iq3_s (QK_K values).
  // Thread t writes 8 consecutive outputs from two iq3s_grid entries. Each 9-bit index is a
  // qs byte plus one bit of qh[t%8]. The scale is 1 + 2*(4-bit field) times d, and the sign
  // byte comes from `signs`.
  template<typename dst_t>
  static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
      const block_iq3_s & b = ((const block_iq3_s *) vx)[blockIdx.x];
      const int64_t sub  = threadIdx.x % 8; // 0...7
      const int64_t quad = threadIdx.x / 8; // 0...3

      const uint8_t * qs = b.qs + 8*sub;
      const auto hbits = b.qh[sub];
      const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*quad+0] | ((hbits << (8-2*quad)) & 256)));
      const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*quad+1] | ((hbits << (7-2*quad)) & 256)));
      const float d = (float)b.d * (1 + 2*((b.scales[sub/2] >> 4*(sub%2)) & 0xf));
      const uint8_t signs = b.signs[4*sub + quad];

      dst_t * y = yy + (int64_t)blockIdx.x*QK_K + 32*sub + 8*quad;
  #pragma unroll
      for (int j = 0; j < 8; ++j) {
          const uint8_t g = j < 4 ? grid1[j] : grid2[j - 4];
          y[j] = d * g * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
      }
  }
  // Dequantizes IQ1_S: one 32-thread CUDA block per block_iq1_s (QK_K values).
  // Thread t writes 8 consecutive outputs. qh[t%8] supplies the sign of the delta (bit 15),
  // a 3-bit scale (bits 12..14) and 3 extra grid-index bits per sub-group. The grid word
  // holds eight 4-bit values; output = d * (value + delta).
  template<typename dst_t>
  static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
      const block_iq1_s & b = ((const block_iq1_s *) vx)[blockIdx.x];
      const int64_t sub  = threadIdx.x % 8; // 0...7
      const int64_t quad = threadIdx.x / 8; // 0...3
      const auto h = b.qh[sub];

      const float delta = h & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
      const float d = (float)b.d * (2*((h >> 12) & 7) + 1);

      // Split the grid word into low and high nibbles, one byte per value.
      uint32_t words[2];
      words[0] = iq1s_grid_gpu[b.qs[4*sub + quad] | (((h >> 3*quad) & 7) << 8)];
      words[1] = (words[0] >> 4) & 0x0f0f0f0f;
      words[0] &= 0x0f0f0f0f;
      const int8_t * vals = (const int8_t *) words;

      dst_t * y = yy + (int64_t)blockIdx.x*QK_K + 32*sub + 8*quad;
      for (int j = 0; j < 8; ++j) {
          y[j] = d * (vals[j] + delta);
      }
  }
  // Dequantizes IQ1_M: one 32-thread CUDA block per block_iq1_m (QK_K values).
  // The fp16 super-block scale is spread over the top nibble of each of the first four
  // uint16 scale words. The low 12 bits of each word hold four 3-bit sub-block scales.
  // Each thread writes 8 consecutive outputs as d * (4-bit grid value + delta).
  template<typename dst_t>
  static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {
      const block_iq1_m & b = ((const block_iq1_m *) vx)[blockIdx.x];
      const int64_t sub  = threadIdx.x % 8; // 0...7
      const int64_t quad = threadIdx.x / 8; // 0...3

      const uint16_t * sc = (const uint16_t *) b.scales;
      iq1m_scale_t scale;
      scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);

      const int64_t ib16 = 2*sub + quad/2;  // 16-value sub-block index
      const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);

      // Each qh byte serves two sub-groups, one nibble each.
      const auto qh_byte = b.qh[ib16];
      const int nshift = 4*(quad%2);
      const float delta = qh_byte & (0x08 << nshift) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;

      uint32_t words[2];
      words[0] = iq1s_grid_gpu[b.qs[4*sub + quad] | (((qh_byte >> nshift) & 7) << 8)];
      words[1] = (words[0] >> 4) & 0x0f0f0f0f;
      words[0] &= 0x0f0f0f0f;
      const int8_t * vals = (const int8_t *) words;

      dst_t * y = yy + (int64_t)blockIdx.x*QK_K + 32*sub + 8*quad;
      for (int j = 0; j < 8; ++j) {
          y[j] = d * (vals[j] + delta);
      }
  }
  // Dequantizes IQ4_NL: each 32-thread CUDA block covers QK_K/QK4_NL consecutive
  // block_iq4_nl (the t%8 indexing assumes that ratio is 8). Thread t reads 4 packed bytes
  // of block t%8 and maps each nibble through kvalues_iq4nl, scaled by the block's d. Low
  // nibbles go to 4 consecutive outputs and high nibbles to the outputs 16 positions later.
  // NOTE(review): there is no bounds check here, but dequantize_row_iq4_nl_cuda rounds the
  // grid up to whole QK_K groups. A k that is not a multiple of QK_K would therefore read
  // and write past the end — TODO confirm that callers guarantee this.
  template<typename dst_t>
  static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
      const int64_t sub  = threadIdx.x % 8; // 0...7
      const int64_t quad = threadIdx.x / 8; // 0...3

      const block_iq4_nl & b = ((const block_iq4_nl *) vx)[(int64_t)blockIdx.x*(QK_K/QK4_NL) + sub];
      const float d = (float)b.d;
      const uint8_t * q4 = b.qs + 4*quad;

      dst_t * y = yy + (int64_t)blockIdx.x*QK_K + 32*sub + 4*quad;
      for (int j = 0; j < 4; ++j) {
          const uint8_t byte = q4[j];
          y[j]      = d * kvalues_iq4nl[byte & 0xf];
          y[j + 16] = d * kvalues_iq4nl[byte >> 4];
      }
  }
  // Dequantizes IQ4_XS: one 32-thread CUDA block per block_iq4_xs (QK_K values).
  // Sub-block t%8 has a 6-bit scale (low 4 bits from scales_l, high 2 bits from scales_h),
  // offset by -32. Nibbles are mapped through kvalues_iq4nl, with low nibbles at 4
  // consecutive outputs and high nibbles 16 positions later.
  template<typename dst_t>
  static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
      const block_iq4_xs & b = ((const block_iq4_xs *) vx)[blockIdx.x];
      const int64_t sub  = threadIdx.x % 8; // 0...7
      const int64_t quad = threadIdx.x / 8; // 0...3

      const int ls = ((b.scales_l[sub/2] >> 4*(sub%2)) & 0xf) | (((b.scales_h >> 2*sub) & 3) << 4);
      const float d = (float)b.d * (ls - 32);
      const uint8_t * q4 = b.qs + 16*sub + 4*quad;

      dst_t * y = yy + (int64_t)blockIdx.x*QK_K + 32*sub + 4*quad;
      for (int j = 0; j < 4; ++j) {
          const uint8_t byte = q4[j];
          y[j]      = d * kvalues_iq4nl[byte & 0xf];
          y[j + 16] = d * kvalues_iq4nl[byte >> 4];
      }
  }
  // Host launcher for dequantize_block. It uses CUDA_DEQUANTIZE_BLOCK_SIZE threads per block,
  // each producing two outputs, so the grid has ceil(k / (2*CUDA_DEQUANTIZE_BLOCK_SIZE)) blocks.
  template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
  static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
      const int64_t values_per_block = 2*CUDA_DEQUANTIZE_BLOCK_SIZE;
      const int num_blocks = (k + values_per_block - 1) / values_per_block;
      dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  }
  // Host launcher for dequantize_block_q8_0_f16: ceil(k / CUDA_Q8_0_NE_ALIGN) blocks of
  // WARP_SIZE threads. The bounds-checked instantiation is used only when the last block is partial.
  static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
      const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
      if (k % CUDA_Q8_0_NE_ALIGN != 0) {
          dequantize_block_q8_0_f16<true><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
      } else {
          dequantize_block_q8_0_f16<false><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
      }
  }
  // Host launcher: one 64-thread block per QK_K super-block. k / QK_K truncates, so any
  // remainder of k beyond a multiple of QK_K is not converted.
  template<typename dst_t>
  static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
      constexpr int threads = 64;
      const int blocks = k / QK_K;
      dequantize_block_q2_K<<<blocks, threads, 0, stream>>>(vx, y);
  }
  // Host launcher: one 64-thread block per QK_K super-block. k / QK_K truncates, so any
  // remainder of k beyond a multiple of QK_K is not converted.
  template<typename dst_t>
  static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
      constexpr int threads = 64;
      const int blocks = k / QK_K;
      dequantize_block_q3_K<<<blocks, threads, 0, stream>>>(vx, y);
  }
  // Host launcher: 32-thread blocks, each covering 8 quant blocks (256 values), with the grid
  // rounded up. The kernel skips quant blocks at or past k/32.
  template<typename dst_t>
  static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
      const int n_quant_blocks = k / 32;
      const int n_groups       = (k + 255) / 256;
      dequantize_block_q4_0<<<n_groups, 32, 0, stream>>>(vx, y, n_quant_blocks);
  }
  // Host launcher: 32-thread blocks, each covering 8 quant blocks (256 values), with the grid
  // rounded up. The kernel skips quant blocks at or past k/32.
  template<typename dst_t>
  static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
      const int n_quant_blocks = k / 32;
      const int n_groups       = (k + 255) / 256;
      dequantize_block_q4_1<<<n_groups, 32, 0, stream>>>(vx, y, n_quant_blocks);
  }
  // Host launcher: one 32-thread block per QK_K super-block (k / QK_K, truncated).
  template<typename dst_t>
  static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
      constexpr int threads = 32;
      const int blocks = k / QK_K;
      dequantize_block_q4_K<<<blocks, threads, 0, stream>>>(vx, y);
  }
  // Host launcher: one 64-thread block per QK_K super-block (k / QK_K, truncated).
  template<typename dst_t>
  static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
      constexpr int threads = 64;
      const int blocks = k / QK_K;
      dequantize_block_q5_K<<<blocks, threads, 0, stream>>>(vx, y);
  }
  // Host launcher: one 64-thread block per QK_K super-block (k / QK_K, truncated).
  template<typename dst_t>
  static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
      constexpr int threads = 64;
      const int blocks = k / QK_K;
      dequantize_block_q6_K<<<blocks, threads, 0, stream>>>(vx, y);
  }
  // Host launcher: one 32-thread block per QK_K super-block (k / QK_K, truncated).
  template<typename dst_t>
  static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
      constexpr int threads = 32;
      const int blocks = k / QK_K;
      dequantize_block_iq2_xxs<<<blocks, threads, 0, stream>>>(vx, y);
  }
  // Host launcher: one 32-thread block per QK_K super-block (k / QK_K, truncated).
  template<typename dst_t>
  static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
      constexpr int threads = 32;
      const int blocks = k / QK_K;
      dequantize_block_iq2_xs<<<blocks, threads, 0, stream>>>(vx, y);
  }
  // Host launcher: one 32-thread block per QK_K super-block (k / QK_K, truncated).
  template<typename dst_t>
  static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
      constexpr int threads = 32;
      const int blocks = k / QK_K;
      dequantize_block_iq2_s<<<blocks, threads, 0, stream>>>(vx, y);
  }
  // Host launcher: one 32-thread block per QK_K super-block (k / QK_K, truncated).
  template<typename dst_t>
  static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
      constexpr int threads = 32;
      const int blocks = k / QK_K;
      dequantize_block_iq3_xxs<<<blocks, threads, 0, stream>>>(vx, y);
  }
  // Host launcher: one 32-thread block per QK_K super-block (k / QK_K, truncated).
  template<typename dst_t>
  static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
      constexpr int threads = 32;
      const int blocks = k / QK_K;
      dequantize_block_iq3_s<<<blocks, threads, 0, stream>>>(vx, y);
  }
  // Host launcher: one 32-thread block per QK_K super-block (k / QK_K, truncated).
  template<typename dst_t>
  static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
      constexpr int threads = 32;
      const int blocks = k / QK_K;
      dequantize_block_iq1_s<<<blocks, threads, 0, stream>>>(vx, y);
  }
  // Host launcher: 32-thread blocks with the grid rounded up to whole QK_K groups.
  // NOTE(review): dequantize_block_iq4_nl has no bounds check, so this relies on k being a
  // multiple of QK_K — TODO confirm with callers.
  template<typename dst_t>
  static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
      constexpr int threads = 32;
      const int blocks = (k + QK_K - 1) / QK_K;
      dequantize_block_iq4_nl<<<blocks, threads, 0, stream>>>(vx, y);
  }
  // Host launcher: one 32-thread block per QK_K super-block (k / QK_K, truncated).
  template<typename dst_t>
  static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
      constexpr int threads = 32;
      const int blocks = k / QK_K;
      dequantize_block_iq1_m<<<blocks, threads, 0, stream>>>(vx, y);
  }
  // Host launcher: 32-thread blocks, one per QK_K super-block, with the count rounded up.
  template<typename dst_t>
  static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
      constexpr int threads = 32;
      const int blocks = (k + QK_K - 1) / QK_K;
      dequantize_block_iq4_xs<<<blocks, threads, 0, stream>>>(vx, y);
  }
  // Element-wise type conversion: y[i] = ((const src_t *) vx)[i] for every i in [0, k).
  // Each thread of the 1D grid handles one element; threads past k do nothing.
  template <typename src_t, typename dst_t>
  static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
      const int64_t idx = (int64_t)blockIdx.x*blockDim.x + threadIdx.x;
      if (idx < k) {
          const src_t * src = (const src_t *) vx;
          y[idx] = src[idx];
      }
  }
  // Host launcher for convert_unary: ceil(k / CUDA_DEQUANTIZE_BLOCK_SIZE) blocks of
  // CUDA_DEQUANTIZE_BLOCK_SIZE threads, one element per thread.
  template <typename src_t, typename dst_t>
  static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
      const int threads = CUDA_DEQUANTIZE_BLOCK_SIZE;
      const int num_blocks = (k + threads - 1) / threads;
      convert_unary<src_t><<<num_blocks, threads, 0, stream>>>(vx, y, k);
  }
  // Returns the host launcher that converts a tensor of `type` to half precision, or nullptr
  // when no conversion is implemented here. For Q8_0, the dedicated half kernel is chosen only
  // when the current device's compute capability is at least CC_PASCAL, because that kernel
  // is compiled only for such architectures.
  to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
      switch (type) {
          // legacy block formats
          case GGML_TYPE_Q4_0: return dequantize_row_q4_0_cuda;
          case GGML_TYPE_Q4_1: return dequantize_row_q4_1_cuda;
          case GGML_TYPE_Q5_0: return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
          case GGML_TYPE_Q5_1: return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
          case GGML_TYPE_Q8_0:
              if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc < CC_PASCAL) {
                  return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
              }
              return dequantize_block_q8_0_f16_cuda;
          // k-quants
          case GGML_TYPE_Q2_K: return dequantize_row_q2_K_cuda;
          case GGML_TYPE_Q3_K: return dequantize_row_q3_K_cuda;
          case GGML_TYPE_Q4_K: return dequantize_row_q4_K_cuda;
          case GGML_TYPE_Q5_K: return dequantize_row_q5_K_cuda;
          case GGML_TYPE_Q6_K: return dequantize_row_q6_K_cuda;
          // i-quants
          case GGML_TYPE_IQ2_XXS: return dequantize_row_iq2_xxs_cuda;
          case GGML_TYPE_IQ2_XS:  return dequantize_row_iq2_xs_cuda;
          case GGML_TYPE_IQ2_S:   return dequantize_row_iq2_s_cuda;
          case GGML_TYPE_IQ3_XXS: return dequantize_row_iq3_xxs_cuda;
          case GGML_TYPE_IQ3_S:   return dequantize_row_iq3_s_cuda;
          case GGML_TYPE_IQ1_S:   return dequantize_row_iq1_s_cuda;
          case GGML_TYPE_IQ1_M:   return dequantize_row_iq1_m_cuda;
          case GGML_TYPE_IQ4_NL:  return dequantize_row_iq4_nl_cuda;
          case GGML_TYPE_IQ4_XS:  return dequantize_row_iq4_xs_cuda;
          // plain floating point
          case GGML_TYPE_F32: return convert_unary_cuda<float>;
          default: return nullptr;
      }
  }
  // Returns the host launcher that converts a tensor of `type` to float, or nullptr when no
  // conversion is implemented here. Q8_0 always uses the generic dequantize_block path.
  to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
      switch (type) {
          // legacy block formats
          case GGML_TYPE_Q4_0: return dequantize_row_q4_0_cuda;
          case GGML_TYPE_Q4_1: return dequantize_row_q4_1_cuda;
          case GGML_TYPE_Q5_0: return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
          case GGML_TYPE_Q5_1: return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
          case GGML_TYPE_Q8_0: return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
          // k-quants
          case GGML_TYPE_Q2_K: return dequantize_row_q2_K_cuda;
          case GGML_TYPE_Q3_K: return dequantize_row_q3_K_cuda;
          case GGML_TYPE_Q4_K: return dequantize_row_q4_K_cuda;
          case GGML_TYPE_Q5_K: return dequantize_row_q5_K_cuda;
          case GGML_TYPE_Q6_K: return dequantize_row_q6_K_cuda;
          // i-quants
          case GGML_TYPE_IQ2_XXS: return dequantize_row_iq2_xxs_cuda;
          case GGML_TYPE_IQ2_XS:  return dequantize_row_iq2_xs_cuda;
          case GGML_TYPE_IQ2_S:   return dequantize_row_iq2_s_cuda;
          case GGML_TYPE_IQ3_XXS: return dequantize_row_iq3_xxs_cuda;
          case GGML_TYPE_IQ3_S:   return dequantize_row_iq3_s_cuda;
          case GGML_TYPE_IQ1_S:   return dequantize_row_iq1_s_cuda;
          case GGML_TYPE_IQ1_M:   return dequantize_row_iq1_m_cuda;
          case GGML_TYPE_IQ4_NL:  return dequantize_row_iq4_nl_cuda;
          case GGML_TYPE_IQ4_XS:  return dequantize_row_iq4_xs_cuda;
          // plain floating point
          case GGML_TYPE_F16: return convert_unary_cuda<half>;
          default: return nullptr;
      }
  }