rope.cu 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. /**
  2. * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
  3. *
  4. * MIT License
  5. *
  6. * Copyright (c) 2023-2024 The ggml authors
  7. *
  8. * Permission is hereby granted, free of charge, to any person obtaining a copy
  9. * of this software and associated documentation files (the "Software"), to deal
  10. * in the Software without restriction, including without limitation the rights
  11. * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  12. * copies of the Software, and to permit persons to whom the Software is
  13. * furnished to do so, subject to the following conditions:
  14. *
  15. * The above copyright notice and this permission notice shall be included in all
  16. * copies or substantial portions of the Software.
  17. *
  18. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  19. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  20. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  21. * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  22. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  23. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  24. * SOFTWARE.
  25. */
  26. #include "rope.cuh"
  // Correction-dimension bounds for the YaRN ramp, passed by value into the kernels.
  // rope_yarn reads v[0] as the ramp's lower bound and v[1] as its upper bound;
  // ggml_cuda_op_rope fills the array through ggml_rope_yarn_corr_dims.
  struct rope_corr_dims { float v[2]; };
  // YaRN blend weight for the rotation pair starting at element i0.
  // Clamps (i0/2 - low) / max(0.001, high - low) to [0, 1] and returns one minus that value:
  // 1 for pair indices at or below `low`, decreasing linearly toward 0 as the pair index grows.
  // The 0.001 floor keeps a degenerate (high <= low) range from dividing by zero.
  static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
      const float pair_idx = static_cast<float>(i0 / 2);   // integer pair index, as in the original
      const float span     = max(0.001f, high - low);
      const float t        = min(1.0f, max(0.0f, (pair_idx - low) / span));
      return 1.0f - t;
  }
  // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
  // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
  //
  // Produces the scaled rotation for one element pair:
  //   *cos_theta = cos(theta) * m,   *sin_theta = sin(theta) * m
  // theta starts as the interpolated angle freq_scale * theta_extrap and m starts as mscale.
  // When ext_factor is non-zero, theta is blended toward theta_extrap by
  // rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor, and m is further
  // multiplied by 1 + 0.1 * ln(1 / freq_scale).
  static __device__ void rope_yarn(
      float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
      float * cos_theta, float * sin_theta) {
      const float theta_interp = freq_scale * theta_extrap;

      float theta     = theta_interp;   // pure interpolation unless the YaRN blend is enabled
      float magnitude = mscale;

      if (ext_factor != 0.0f) {
          const float ramp     = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0);
          const float ramp_mix = ramp * ext_factor;

          // n-d rotational scaling corrected for extrapolation
          theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;

          // n-d magnitude scaling corrected for interpolation
          magnitude *= 1.0f + 0.1f * logf(1.0f / freq_scale);
      }

      *cos_theta = cosf(theta) * magnitude;
      *sin_theta = sinf(theta) * magnitude;
  }
  // Non-NeoX ("normal") RoPE kernel: rotates adjacent element pairs (i0, i0 + 1) of each row.
  //
  // Launch layout expected by the indexing below:
  //   row = blockDim.x*blockIdx.x + threadIdx.x          (one row per x index)
  //   i0  = 2*(blockDim.y*blockIdx.y + threadIdx.y)      (first element of the thread's pair)
  // Threads with i0 >= ne0 exit immediately. Pairs with n_dims <= i0 < ne0 are copied unchanged.
  // For i0 < n_dims the angle is pos[row/p_delta_rows] * theta_scale^(i0/2), divided by
  // freq_factors[i0/2] when has_ff, then passed through rope_yarn (attn_factor is its mscale).
  // ne0 is assumed even so that i0 + 1 stays inside the row.
  // Each thread loads both of its elements before storing them and touches nothing else, so
  // running with x == dst does not create cross-thread races. No shared memory is used.
  template<typename T, bool has_ff>
  static __global__ void rope_norm(
      const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
      float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
      const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

      if (i0 >= ne0) {
          return;
      }

      const int row = blockDim.x*blockIdx.x + threadIdx.x;

      // Flat offset of this thread's pair. Computed in 64 bits because row*ne0 overflows a
      // 32-bit int for tensors with more than INT_MAX elements.
      const int64_t i = (int64_t) row*ne0 + i0;

      if (i0 >= n_dims) {
          // Outside the rotated prefix: pass the pair through unchanged.
          dst[i + 0] = x[i + 0];
          dst[i + 1] = x[i + 1];
          return;
      }

      const int i2 = row/p_delta_rows; // index into pos for this row

      const float theta_base  = pos[i2]*powf(theta_scale, i0/2.0f);
      const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

      float cos_theta;
      float sin_theta;
      rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

      const float x0 = x[i + 0];
      const float x1 = x[i + 1];

      dst[i + 0] = x0*cos_theta - x1*sin_theta;
      dst[i + 1] = x0*sin_theta + x1*cos_theta;
  }
  // NeoX-style RoPE kernel: the rotated prefix [0, n_dims) of each row is split into two halves,
  // and element j is paired with element j + n_dims/2 (instead of adjacent pairs).
  //
  // Launch layout expected by the indexing below:
  //   row = blockDim.x*blockIdx.x + threadIdx.x
  //   i0  = 2*(blockDim.y*blockIdx.y + threadIdx.y)
  // Threads with i0 >= ne0 exit immediately. A thread with n_dims <= i0 < ne0 copies elements
  // i0 and i0 + 1 unchanged. A thread with i0 < n_dims rotates elements i0/2 and i0/2 + n_dims/2
  // by the angle pos[row/p_delta_rows] * theta_scale^(i0/2), divided by freq_factors[i0/2] when
  // has_ff, then passed through rope_yarn.
  // Assumes ne0 and n_dims are even and n_dims <= ne0. Otherwise two threads would own the same
  // element, or i0/2 + n_dims/2 would index past the row.
  // Each thread loads its two elements before storing them, so x == dst is race-free under those
  // assumptions. No shared memory is used.
  template<typename T, bool has_ff>
  static __global__ void rope_neox(
      const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
      float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
      const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

      if (i0 >= ne0) {
          return;
      }

      const int row = blockDim.x*blockIdx.x + threadIdx.x;

      // Start of this thread's row. Computed in 64 bits because row*ne0 overflows a 32-bit int
      // for tensors with more than INT_MAX elements.
      const int64_t row_start = (int64_t) row*ne0;

      if (i0 >= n_dims) {
          // Outside the rotated prefix: pass the pair through unchanged.
          const int64_t i = row_start + i0;
          dst[i + 0] = x[i + 0];
          dst[i + 1] = x[i + 1];
          return;
      }

      const int64_t i  = row_start + i0/2;   // element in the first half; partner is i + n_dims/2
      const int     i2 = row/p_delta_rows;   // index into pos for this row

      const float theta_base  = pos[i2]*powf(theta_scale, i0/2.0f);
      const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

      float cos_theta;
      float sin_theta;
      rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

      const float x0 = x[i + 0];
      const float x1 = x[i + n_dims/2];

      dst[i + 0]        = x0*cos_theta - x1*sin_theta;
      dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
  }
  // Launches rope_norm over nr rows of ne0 contiguous elements, rotating the first n_dims
  // elements of every row. The work is enqueued asynchronously on `stream`; launch errors are not
  // checked here.
  //
  // Grid layout: x = one row per block (block_dims.x == 1); y = enough blocks of
  // CUDA_ROPE_BLOCK_SIZE threads to give every element pair of a row its own thread.
  template<typename T>
  static void rope_norm_cuda(
      const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
      float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
      // The kernel works on whole (even, odd) pairs, so the row length and the rotated prefix must
      // both be even, and the prefix must fit inside the row.
      GGML_ASSERT(ne0 % 2 == 0);
      GGML_ASSERT(n_dims % 2 == 0);
      GGML_ASSERT(n_dims <= ne0);

      // A grid dimension of 0 is an invalid launch configuration; there is no work anyway.
      if (nr == 0 || ne0 == 0) {
          return;
      }

      const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
      const int n_blocks_y = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
      const dim3 block_nums(nr, n_blocks_y, 1);

      // Pair k is rotated by pos * theta_scale^k, i.e. pos * freq_base^(-2k/n_dims).
      const float theta_scale = powf(freq_base, -2.0f/n_dims);

      if (freq_factors == nullptr) {
          rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>(
              x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
              theta_scale, freq_factors
          );
      } else {
          rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>(
              x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
              theta_scale, freq_factors
          );
      }
  }
  // Launches rope_neox over nr rows of ne0 contiguous elements, rotating the first n_dims
  // elements of every row NeoX-style (element j paired with j + n_dims/2). The work is enqueued
  // asynchronously on `stream`; launch errors are not checked here.
  //
  // Grid layout: x = one row per block (block_dims.x == 1); y = enough blocks of
  // CUDA_ROPE_BLOCK_SIZE threads to give every element pair of a row its own thread.
  template<typename T>
  static void rope_neox_cuda(
      const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
      float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
      // Memory-safety preconditions of the NeoX kernel:
      //  - an odd n_dims would make two threads write the same element (data race);
      //  - n_dims > ne0 would make x[i + n_dims/2] read (and write) past the end of the row.
      GGML_ASSERT(ne0 % 2 == 0);
      GGML_ASSERT(n_dims % 2 == 0);
      GGML_ASSERT(n_dims <= ne0);

      // A grid dimension of 0 is an invalid launch configuration; there is no work anyway.
      if (nr == 0 || ne0 == 0) {
          return;
      }

      const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
      const int n_blocks_y = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
      const dim3 block_nums(nr, n_blocks_y, 1);

      // Pair k is rotated by pos * theta_scale^k, i.e. pos * freq_base^(-2k/n_dims).
      const float theta_scale = powf(freq_base, -2.0f/n_dims);

      if (freq_factors == nullptr) {
          rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
              x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
              theta_scale, freq_factors
          );
      } else {
          rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
              x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
              theta_scale, freq_factors
          );
      }
  }
  // F16 entry point for the non-NeoX rotation: forwards every argument unchanged to
  // rope_norm_cuda instantiated for half.
  static void rope_norm_cuda_f16(
      const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
      float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
      // T is deduced as half from x and dst.
      rope_norm_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows,
                     freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
  }
  // F32 entry point for the non-NeoX rotation: forwards every argument unchanged to
  // rope_norm_cuda instantiated for float.
  static void rope_norm_cuda_f32(
      const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
      float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
      // T is deduced as float from x and dst.
      rope_norm_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows,
                     freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
  }
  // F16 entry point for the NeoX rotation: forwards every argument unchanged to
  // rope_neox_cuda instantiated for half.
  static void rope_neox_cuda_f16(
      const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
      float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
      // T is deduced as half from x and dst.
      rope_neox_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows,
                     freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
  }
  // F32 entry point for the NeoX rotation: forwards every argument unchanged to
  // rope_neox_cuda instantiated for float.
  static void rope_neox_cuda_f32(
      const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
      float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
      // T is deduced as float from x and dst.
      rope_neox_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows,
                     freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
  }
  // CUDA implementation of the rope op for `dst`.
  //
  // Reads from dst:
  //   src[0]                      - contiguous F32 or F16 tensor to rotate (same type as dst)
  //   src[1]                      - positions, read as int32_t; row r uses pos[r / ne01]
  //   src[2]                      - optional per-pair frequency factors (float); may be null
  //   op_params[1], [2], [4]      - n_dims, mode, n_ctx_orig
  //   op_params[5..10] (as float) - freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
  // Uses the NeoX kernel when mode has GGML_ROPE_TYPE_NEOX set, otherwise the normal kernel.
  // Work is enqueued on ctx.stream(); nothing is synchronized here.
  void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
      const ggml_tensor * src0 = dst->src[0];
      const ggml_tensor * src1 = dst->src[1];
      const ggml_tensor * src2 = dst->src[2];

      const float   * src0_d = (const float *) src0->data;   // reinterpreted as half for F16 below
      const int32_t * pos    = (const int32_t *) src1->data;
      float         * dst_d  = (float *) dst->data;

      cudaStream_t stream = ctx.stream();

      GGML_ASSERT(ggml_is_contiguous(src0));
      GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
      GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
      GGML_ASSERT(src0->type == dst->type);

      const int64_t ne00 = src0->ne[0];
      const int64_t ne01 = src0->ne[1];
      const int64_t nr   = ggml_nrows(src0);

      const int32_t * params = (const int32_t *) dst->op_params;

      // params[0] (n_past) and params[3] (n_ctx) are not used here.
      const int n_dims     = params[1];
      const int mode       = params[2];
      const int n_ctx_orig = params[4];

      // RoPE alteration for extended context; these floats are stored bit-wise in op_params.
      float freq_base;
      float freq_scale;
      float ext_factor;
      float attn_factor;
      float beta_fast;
      float beta_slow;
      memcpy(&freq_base,   params +  5, sizeof(float));
      memcpy(&freq_scale,  params +  6, sizeof(float));
      memcpy(&ext_factor,  params +  7, sizeof(float));
      memcpy(&attn_factor, params +  8, sizeof(float));
      memcpy(&beta_fast,   params +  9, sizeof(float));
      memcpy(&beta_slow,   params + 10, sizeof(float));

      const float * freq_factors = (src2 != nullptr) ? (const float *) src2->data : nullptr;

      rope_corr_dims corr_dims;
      ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);

      const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;

      // compute
      switch (src0->type) {
          case GGML_TYPE_F32:
              if (is_neox) {
                  rope_neox_cuda_f32(
                      src0_d, dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                      attn_factor, corr_dims, freq_factors, stream);
              } else {
                  rope_norm_cuda_f32(
                      src0_d, dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                      attn_factor, corr_dims, freq_factors, stream);
              }
              break;
          case GGML_TYPE_F16: {
              const half * src0_h = (const half *) src0_d;
              half       * dst_h  = (half *) dst_d;
              if (is_neox) {
                  rope_neox_cuda_f16(
                      src0_h, dst_h, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                      attn_factor, corr_dims, freq_factors, stream);
              } else {
                  rope_norm_cuda_f16(
                      src0_h, dst_h, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                      attn_factor, corr_dims, freq_factors, stream);
              }
              break;
          }
          default:
              GGML_ABORT("fatal error");
      }
  }