/**
 * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
 *
 * MIT License
 *
 * Copyright (c) 2023-2024 The ggml authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "norm.cuh"
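
// norm.cuh is expected to pull in the common CUDA helpers used below, in
// particular WARP_SIZE and the warp_reduce_sum overloads for float and
// float2 (an assumption based on usage; this file does not declare them).

// Layer normalization over each row of a [nrows, ncols] matrix:
//
//     dst[row][col] = (x[row][col] - mean(row)) * rsqrt(var(row) + eps)
//
// Each (blockIdx.x, threadIdx.y) pair handles one row. Threads stride over
// the columns accumulating sum(x) and sum(x^2) in a single pass, so the
// variance comes out as E[x^2] - E[x]^2.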
template <int block_size>
static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;

    float2 mean_var = make_float2(0.f, 0.f);

    // per-thread partial sums of x (in .x) and x^2 (in .y)
    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[row*ncols + col];
        mean_var.x += xi;
        mean_var.y += xi * xi;
    }

    // sum up partial sums
    mean_var = warp_reduce_sum(mean_var);
    if (block_size > WARP_SIZE) {
        // cross-warp reduction: lane 0 of each warp publishes its warp's sum,
        // then a single warp-width of values is reduced once more
        __shared__ float2 s_sum[32];
        int warp_id = threadIdx.x / WARP_SIZE;
        int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = mean_var;
        }
        __syncthreads();
        mean_var = s_sum[lane_id];
        mean_var = warp_reduce_sum(mean_var);
    }

    const float mean = mean_var.x / ncols;
    const float var = mean_var.y / ncols - mean * mean; // E[x^2] - E[x]^2
    const float inv_std = rsqrtf(var + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
    }
}
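
// Group normalization: one block per group of group_size consecutive
// elements (the grid is num_groups * ne[3] blocks, see the host code below).
// Three strided passes over [start, end): sum for the mean, centered values
// plus sum of squares for the variance, then in-place scaling. Note that
// mean and variance divide by group_size even when the last group is
// clamped short at ne_elements.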
template <int block_size>
static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
    // blockIdx.x: num_groups idx
    // threadIdx.x: block_size idx
    int start = blockIdx.x * group_size;
    int end = start + group_size;

    start += threadIdx.x;

    // clamp the last group to the end of the tensor
    if (end >= ne_elements) {
        end = ne_elements;
    }

    float tmp = 0.0f; // partial sum for thread in warp

    // first pass: sum of the group's elements
    for (int j = start; j < end; j += block_size) {
        tmp += x[j];
    }

    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        __shared__ float s_sum[32];
        int warp_id = threadIdx.x / WARP_SIZE;
        int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = s_sum[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    float mean = tmp / group_size;
    tmp = 0.0f;

    // second pass: write the centered values and accumulate the sum of squares
    for (int j = start; j < end; j += block_size) {
        float xi = x[j] - mean;
        dst[j] = xi;
        tmp += xi * xi;
    }

    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        __shared__ float s_sum[32];
        int warp_id = threadIdx.x / WARP_SIZE;
        int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = s_sum[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    float variance = tmp / group_size;
    float scale = rsqrtf(variance + eps);
    // third pass: scale the centered values in place
    for (int j = start; j < end; j += block_size) {
        dst[j] *= scale;
    }
}
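
// RMS normalization over each row:
//
//     dst[row][col] = x[row][col] * rsqrt(mean(x[row]^2) + eps)
//
// Unlike norm_f32 there is no mean subtraction; only the root mean square
// of the row is used for scaling.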
template <int block_size>
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;

    float tmp = 0.0f; // partial sum for thread in warp

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[row*ncols + col];
        tmp += xi * xi;
    }

    // sum up partial sums
    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        __shared__ float s_sum[32];
        int warp_id = threadIdx.x / WARP_SIZE;
        int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = s_sum[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    const float mean = tmp / ncols; // mean of the squared elements
    const float scale = rsqrtf(mean + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[row*ncols + col] = scale * x[row*ncols + col];
    }
}
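
// Launch heuristic shared by the launchers below: rows (or groups) narrower
// than 1024 elements get a single warp per block, wider ones get a full
// 1024-thread block, which enables the cross-warp shared-memory reduction.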
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    if (ncols < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
    }
}
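
// The group-norm launcher dispatches on group_size with the same two-tier
// heuristic; group_size need not be a multiple of WARP_SIZE because the
// kernel's strided loops handle the remainder.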
static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
    if (group_size < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
    }
}
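
// Same dispatch as norm_f32_cuda, selecting the rms_norm_f32 instantiations.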
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    if (ncols < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
    }
}
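
// ggml backend entry points. Each reads its source tensor from dst->src[0],
// requires contiguous F32 data, and unpacks eps (plus, for group norm, the
// group count) from dst->op_params before calling the launcher above.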
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    const int64_t ne00 = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
}
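
// op_params layout for group norm in this commit: op_params[0] holds
// num_groups as an int32 and the bits of op_params[1] hold eps as a float,
// hence the memcpy. Each group covers whole ne[0] x ne[1] planes, with
// ne[2] rounded up so every group gets at least one plane; one block is
// launched per group and per ne[3] slice.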
void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    int num_groups = dst->op_params[0];

    float eps;
    memcpy(&eps, dst->op_params + 1, sizeof(float));

    int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
    group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
}
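
// RMS norm entry point: identical plumbing to ggml_cuda_op_norm, with the
// bits of op_params[0] holding eps.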
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    const int64_t ne00 = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
}
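
// Minimal usage sketch (not part of the upstream file). The launchers above
// are static, so from outside this translation unit the ggml_cuda_op_*
// entry points are the only interface; within it, a raw [nrows, ncols]
// device buffer can be normalized directly, e.g.:
//
//     const int ncols = 4096, nrows = 2;   // ncols % WARP_SIZE == 0
//     float * d_x;
//     float * d_dst;
//     cudaMalloc((void **) &d_x,   ncols*nrows*sizeof(float));
//     cudaMalloc((void **) &d_dst, ncols*nrows*sizeof(float));
//     // ... upload input into d_x ...
//     rms_norm_f32_cuda(d_x, d_dst, ncols, nrows, 1e-6f, 0);
//     cudaStreamSynchronize(0);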