// unary.cu

#include "unary.cuh"

#include <cstring> // for memcpy, used below to read op parameters
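
// Element-wise unary kernels. Each thread processes one element of a flat
// array of k floats and returns early if its index is past the end.
//
// gelu_f32 implements the tanh-based GELU approximation:
//   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))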
static __global__ void gelu_f32(const float * x, float * dst, const int k) {
    const float GELU_COEF_A    = 0.044715f;
    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    float xi = x[i];
    dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
}
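
// "Quick" GELU: x * sigmoid(1.702 * x), a cheaper sigmoid-based approximation of GELU.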
static __global__ void gelu_quick_f32(const float * x, float * dst, int k) {
    const float GELU_QUICK_COEF = -1.702f;
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
}
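
// SiLU (a.k.a. swish): x * sigmoid(x).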
static __global__ void silu_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = x[i] / (1.0f + expf(-x[i]));
}

static __global__ void tanh_f32(const float * x, float * dst, int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = tanhf(x[i]);
}

static __global__ void relu_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = fmaxf(x[i], 0);
}

static __global__ void sigmoid_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = 1.0f / (1.0f + expf(-x[i]));
}
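
// Hard sigmoid: piecewise-linear approximation of the sigmoid, clamp((x + 3) / 6, 0, 1).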
static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
}
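
// Hard swish: x * hardsigmoid(x).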
static __global__ void hardswish_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
}
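
// Leaky ReLU: positive values pass through unchanged, negative values are scaled by negative_slope.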
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = fmaxf(x[i], 0) + fminf(x[i], 0.0f) * negative_slope;
}

static __global__ void sqr_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = x[i] * x[i];
}
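
// Host-side launch helpers. The grid size is the element count rounded up to a
// whole number of blocks (ceil division); the kernel runs on the given stream.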
static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
    gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void gelu_quick_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
    gelu_quick_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
    silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
    tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
    relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void sigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_SIGMOID_BLOCK_SIZE - 1) / CUDA_SIGMOID_BLOCK_SIZE;
    sigmoid_f32<<<num_blocks, CUDA_SIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
    hardsigmoid_f32<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_HARDSWISH_BLOCK_SIZE - 1) / CUDA_HARDSWISH_BLOCK_SIZE;
    hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
    leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
}

static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
    sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
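
// ggml CUDA backend entry points. Each op takes its single F32 source tensor,
// asserts the input and output types, and applies the kernel to all
// ggml_nelements(src0) elements as one flat array.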
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    gelu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    gelu_quick_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    tanh_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    sigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    hardsigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
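
// The slope for leaky ReLU is stored in the op's parameter block; it is read
// back with memcpy rather than a pointer cast, which sidesteps alignment and
// strict-aliasing concerns.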
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    float negative_slope;
    memcpy(&negative_slope, dst->op_params, sizeof(float));

    leaky_relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), negative_slope, stream);
}

void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    sqr_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}