norm.cu

/**
 * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
 *
 * MIT License
 *
 * Copyright (c) 2023-2024 The ggml authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "norm.cuh"
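
// norm_f32: standard layer normalization over one row of ncols values.
// Each thread accumulates a partial sum and sum of squares in a single pass,
// the partials are combined with warp_reduce_sum (plus a shared-memory step
// across warps when block_size exceeds WARP_SIZE), and each element is then
// written as (x - mean) * rsqrtf(var + eps).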
template <int block_size>
static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;

    float2 mean_var = make_float2(0.f, 0.f);

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[row*ncols + col];
        mean_var.x += xi;
        mean_var.y += xi * xi;
    }

    // sum up partial sums
    mean_var = warp_reduce_sum(mean_var);
    if (block_size > WARP_SIZE) {
        __shared__ float2 s_sum[32];
        int warp_id = threadIdx.x / WARP_SIZE;
        int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = mean_var;
        }
        __syncthreads();
        mean_var = s_sum[lane_id];
        mean_var = warp_reduce_sum(mean_var);
    }

    const float mean = mean_var.x / ncols;
    const float var = mean_var.y / ncols - mean * mean;
    const float inv_std = rsqrtf(var + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
    }
}
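
// group_norm_f32: normalizes one group of group_size contiguous values per block.
// A first pass computes the group mean, a second pass writes the centered values
// to dst while accumulating the sum of squares, and a final pass scales dst by
// rsqrtf(variance + eps).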
template <int block_size>
static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
    // blockIdx.x: num_groups idx
    // threadIdx.x: block_size idx
    int start = blockIdx.x * group_size;
    int end = start + group_size;

    start += threadIdx.x;

    if (end >= ne_elements) {
        end = ne_elements;
    }

    float tmp = 0.0f; // partial sum for thread in warp

    for (int j = start; j < end; j += block_size) {
        tmp += x[j];
    }

    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        __shared__ float s_sum[32];
        int warp_id = threadIdx.x / WARP_SIZE;
        int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = s_sum[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    float mean = tmp / group_size;
    tmp = 0.0f;

    for (int j = start; j < end; j += block_size) {
        float xi = x[j] - mean;
        dst[j] = xi;
        tmp += xi * xi;
    }

    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        __shared__ float s_sum[32];
        int warp_id = threadIdx.x / WARP_SIZE;
        int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = s_sum[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    float variance = tmp / group_size;
    float scale = rsqrtf(variance + eps);
    for (int j = start; j < end; j += block_size) {
        dst[j] *= scale;
    }
}
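
// rms_norm_f32: RMS normalization over one row. Unlike norm_f32 the mean is not
// subtracted; each element is scaled by rsqrtf(mean(x^2) + eps).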
template <int block_size>
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;

    float tmp = 0.0f; // partial sum for thread in warp

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[row*ncols + col];
        tmp += xi * xi;
    }

    // sum up partial sums
    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        __shared__ float s_sum[32];
        int warp_id = threadIdx.x / WARP_SIZE;
        int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = s_sum[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    const float mean = tmp / ncols;
    const float scale = rsqrtf(mean + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[row*ncols + col] = scale * x[row*ncols + col];
    }
}
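
// Host-side launchers: one row (or one group) is processed per block. For small
// inputs a single warp is used (block_size = WARP_SIZE); at 1024 columns or more
// a full 1024-thread block is launched and the kernels take the shared-memory
// cross-warp reduction path.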
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    if (ncols < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
    }
}

static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
    if (group_size < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
    }
}

static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    if (ncols < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
    }
}
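
// ggml backend entry points: unpack the operator parameters from dst->op_params
// (eps, and additionally the number of groups for group norm) and dispatch to
// the launchers above.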
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    const int64_t ne00 = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
}

void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    int num_groups = dst->op_params[0];

    float eps;
    memcpy(&eps, dst->op_params + 1, sizeof(float));

    int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
    group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
}

void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    const int64_t ne00 = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
}