quantize.cu 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. #include "quantize.cuh"
  2. #include <cstdint>
  // Quantize a set of float rows to block_q8_1, one thread per (padded) element.
  //
  // Expected launch: blockIdx.y selects the row; the x dimension of the grid spans at least
  // kx0_padded threads. Input row `row` is read as kx contiguous floats starting at x + row*kx.
  // Output rows are kx0_padded elements long, i.e. kx0_padded/QK8_1 consecutive block_q8_1;
  // columns in [kx, kx0_padded) are treated as 0.
  // NOTE(review): max/sum come from warp_reduce_*, so correctness presumably relies on every
  // warp covering exactly one QK8_1 block (QK8_1 == warp size, warp-aligned rows) — confirm
  // against the definitions in the project headers.
  static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) {
      const int64_t col = (int64_t)blockIdx.x*blockDim.x + threadIdx.x;

      // Threads beyond the padded row width have nothing to produce.
      if (col >= kx0_padded) {
          return;
      }

      const int64_t row = blockIdx.y;

      // Element handled by this thread; the padding tail contributes zero.
      float val = 0.0f;
      if (col < kx) {
          val = x[row*kx + col];
      }

      const float absmax = warp_reduce_max(fabsf(val));
      const float total  = warp_reduce_sum(val);

      // Symmetric int8 scale; an all-zero block quantizes to all zeros.
      const float scale = absmax / 127;
      int8_t quant = 0;
      if (absmax != 0.0f) {
          quant = roundf(val / scale);
      }

      // Position of this element in the flat sequence of padded rows.
      const int64_t flat = row*kx0_padded + col;
      const int64_t qi   = flat % QK8_1; // quant index inside its block

      block_q8_1 * blocks = (block_q8_1 *) vy;
      block_q8_1 & dst    = blocks[flat / QK8_1];

      dst.qs[qi] = quant;

      // The first element of each block also stores the block's scale and sum.
      if (qi == 0) {
          reinterpret_cast<half&>(dst.ds.x) = scale;
          reinterpret_cast<half&>(dst.ds.y) = total;
      }
  }
  // Quantize float rows to q8_1 in the block_q8_1_mmq layout, one thread per (padded) element.
  //
  // Expected launch: blockIdx.y = row within a channel (kx1 rows), blockIdx.z = channel,
  // and the x dimension of the grid spans at least kx0_padded threads.
  // Input: row (kx1*channel + row) is kx0 contiguous floats; columns in [kx0, kx0_padded)
  // are treated as 0.
  // Output per channel: kx1*(kx0_padded/(4*QK8_1)) block_q8_1_mmq. The block for column chunk c
  // of row r is at index c*kx1 + r, so the chunks of all rows are stored chunk-major.
  // Each block holds 4*QK8_1 quants and one scale entry per QK8_1 sub-block:
  //   - need_sum == true:  half2(d, sum of the sub-block's inputs) in ds[k];
  //   - need_sum == false: ds storage is reinterpreted as float and holds d only.
  // Precondition (asserted by the host launcher): kx0_padded % (4*QK8_1) == 0.
  template <bool need_sum>
  static __global__ void quantize_mmq_q8_1(
      const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {

      const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;

      if (ix0 >= kx0_padded) {
          return;
      }

      const int64_t ix1 = kx1*blockIdx.z + blockIdx.y; // row index into x

      block_q8_1_mmq * y = (block_q8_1_mmq *) vy;

      // Channel stride in blocks, derived from the arguments in 64-bit arithmetic. Deriving it from
      // gridDim/blockDim instead would use 32-bit unsigned math (which can wrap). It would also use
      // the launch width rounded up to whole thread blocks, which exceeds kx0_padded whenever
      // kx0_padded is not a multiple of blockDim.x.
      const int64_t blocks_per_channel = kx1*(kx0_padded / (4*QK8_1));
      const int64_t ib0 = blockIdx.z*blocks_per_channel;               // first block of channel
      const int64_t ib  = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y;    // block index in channel
      const int64_t iqs = ix0 % (4*QK8_1);                             // quant index in block

      const float xi = ix0 < kx0 ? x[ix1*kx0 + ix0] : 0.0f;
      float amax = fabsf(xi);

      amax = warp_reduce_max(amax);

      // Only reduced and stored when need_sum; initialized so it is never indeterminate.
      float sum = 0.0f;
      if (need_sum) {
          sum = warp_reduce_sum(xi);
      }

      const float d = amax / 127;
      const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);

      y[ib].qs[iqs] = q;

      // Only the first thread of each QK8_1 sub-block writes that sub-block's scale entry.
      if (iqs % QK8_1 != 0) {
          return;
      }

      if (need_sum) {
          y[ib].ds[iqs/QK8_1] = make_half2(d, sum);
      } else {
          ((float *) y[ib].ds)[iqs/QK8_1] = d;
      }
  }
  // Host launcher for quantize_q8_1 over kx1*channels rows of kx0 contiguous floats.
  //
  // Writes kx1*channels output rows of kx0_padded/QK8_1 block_q8_1 each to vy; columns in
  // [kx0, kx0_padded) quantize to 0. The kernel maps one grid row (blockIdx.y) to one input row,
  // and CUDA caps gridDim.y at 65535. Rows are therefore processed in chunks of at most that many,
  // offsetting x and vy per chunk. The kernel indexes relative to the pointers it receives.
  // Work is enqueued asynchronously on `stream`; launch errors are not checked here.
  // Empty inputs launch nothing.
  // type_x is unused here; the signature matches quantize_mmq_q8_1_cuda.
  void quantize_row_q8_1_cuda(
      const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
      const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
      GGML_UNUSED(type_x);
      GGML_ASSERT(kx0_padded % QK8_1 == 0);

      const int64_t nrows = kx1*channels;
      if (nrows <= 0 || kx0_padded <= 0) {
          return; // nothing to quantize; a zero-sized grid would be an invalid launch
      }

      constexpr int64_t max_rows_per_launch = 65535; // CUDA limit on gridDim.y

      const int64_t block_num_x    = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
      const int64_t blocks_per_row = kx0_padded / QK8_1; // block_q8_1 per output row (exact, asserted above)
      const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);

      block_q8_1 * y = (block_q8_1 *) vy;

      for (int64_t row0 = 0; row0 < nrows; row0 += max_rows_per_launch) {
          const int64_t rows_left    = nrows - row0;
          const int64_t nrows_launch = rows_left < max_rows_per_launch ? rows_left : max_rows_per_launch;
          const dim3 num_blocks(block_num_x, nrows_launch, 1);
          quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(
              x + row0*kx0, y + row0*blocks_per_row, kx0, kx0_padded);
      }
  }
  // Host launcher for quantize_mmq_q8_1.
  //
  // Grid: x spans kx0_padded threads (rounded up to whole CUDA_QUANTIZE_BLOCK_SIZE blocks),
  // y = kx1 rows, z = channels. The need_sum specialization is chosen by mmq_need_sum(type_x).
  // Work is enqueued asynchronously on `stream`; launch errors are not checked here.
  void quantize_mmq_q8_1_cuda(
      const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
      const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
      GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);

      const dim3 threads(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
      const int64_t nblocks_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
      const dim3 grid(nblocks_x, kx1, channels);

      // Select the specialization once, then issue a single launch.
      auto * const kernel = mmq_need_sum(type_x) ? &quantize_mmq_q8_1<true> : &quantize_mmq_q8_1<false>;
      kernel<<<grid, threads, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
  }