norm.cu

#include "norm.cuh"
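
// CUDA kernels and host launchers for the ggml normalization ops (NORM, GROUP_NORM,
// RMS_NORM) on flat F32 buffers. Each kernel is templated on the block size so the
// same code can run with one warp per row/group for small inputs or a full
// 1024-thread block for large ones.

// norm_f32: layer norm of one row per block. Each thread accumulates a partial sum
// and sum of squares, the partials are reduced within a warp (and across warps via
// shared memory when block_size > WARP_SIZE), and the row is then normalized to
// zero mean and unit variance, with eps added for numerical stability.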
template <int block_size>
static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;

    float2 mean_var = make_float2(0.f, 0.f);

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[row*ncols + col];
        mean_var.x += xi;
        mean_var.y += xi * xi;
    }

    // sum up partial sums
    mean_var = warp_reduce_sum(mean_var);
    if (block_size > WARP_SIZE) {
        __shared__ float2 s_sum[32];
        int warp_id = threadIdx.x / WARP_SIZE;
        int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = mean_var;
        }
        __syncthreads();
        mean_var = s_sum[lane_id];
        mean_var = warp_reduce_sum(mean_var);
    }

    const float mean = mean_var.x / ncols;
    const float var = mean_var.y / ncols - mean * mean;
    const float inv_std = rsqrtf(var + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
    }
}
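
// group_norm_f32: one block per group of group_size consecutive elements. The first
// pass accumulates the sum to obtain the mean, the second pass writes the centered
// values to dst while accumulating the variance, and the final pass rescales dst in
// place by 1/sqrt(variance + eps). Partial sums are reduced the same way as in
// norm_f32.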
template <int block_size>
static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
    // blockIdx.x: num_groups idx
    // threadIdx.x: block_size idx
    int start = blockIdx.x * group_size;
    int end = start + group_size;

    start += threadIdx.x;

    if (end >= ne_elements) {
        end = ne_elements;
    }

    float tmp = 0.0f; // partial sum for thread in warp

    for (int j = start; j < end; j += block_size) {
        tmp += x[j];
    }

    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        __shared__ float s_sum[32];
        int warp_id = threadIdx.x / WARP_SIZE;
        int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = s_sum[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    float mean = tmp / group_size;
    tmp = 0.0f;

    for (int j = start; j < end; j += block_size) {
        float xi = x[j] - mean;
        dst[j] = xi;
        tmp += xi * xi;
    }

    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        __shared__ float s_sum[32];
        int warp_id = threadIdx.x / WARP_SIZE;
        int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = s_sum[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    float variance = tmp / group_size;
    float scale = rsqrtf(variance + eps);
    for (int j = start; j < end; j += block_size) {
        dst[j] *= scale;
    }
}
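
// rms_norm_f32: RMS normalization of one row per block. No mean is subtracted; the
// row is simply scaled by 1/sqrt(mean(x^2) + eps).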
template <int block_size>
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;

    float tmp = 0.0f; // partial sum for thread in warp

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[row*ncols + col];
        tmp += xi * xi;
    }

    // sum up partial sums
    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        __shared__ float s_sum[32];
        int warp_id = threadIdx.x / WARP_SIZE;
        int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = s_sum[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    const float mean = tmp / ncols;
    const float scale = rsqrtf(mean + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[row*ncols + col] = scale * x[row*ncols + col];
    }
}
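
// Host-side launchers. Rows (or groups) with fewer than 1024 elements are handled by
// a single warp per block; larger ones use a full 1024-thread block, which enables
// the shared-memory cross-warp reduction path in the kernels above.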
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    if (ncols < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
    }
}
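
// Note: unlike norm and rms_norm, group norm uses a fixed eps of 1e-6f here instead
// of an eps supplied by the caller.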
static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const int group_size, const int ne_elements, cudaStream_t stream) {
    static const float eps = 1e-6f;
    if (group_size < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
    }
}

static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    if (ncols < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
    }
}
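
// ggml backend entry points: unpack the source and destination tensors, read the
// operator parameters, and dispatch to the launchers above on the context's CUDA
// stream. For norm and rms_norm, eps is stored as a float in op_params.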
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    const int64_t ne00 = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
}
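
// For group norm, op_params[0] holds the number of groups. The group size is the
// spatial size (ne[0]*ne[1]) times the number of channels per group (ne[2] divided
// by num_groups, rounded up), and one block is launched per group per batch element
// (num_groups * ne[3] blocks in total).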
void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    int num_groups = dst->op_params[0];
    int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
    group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], group_size, ggml_nelements(src0), stream);
}

void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    const int64_t ne00 = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
}