rope.cu 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524
  1. /**
  2. * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
  3. *
  4. * MIT License
  5. *
  6. * Copyright (c) 2023-2024 The ggml authors
  7. *
  8. * Permission is hereby granted, free of charge, to any person obtaining a copy
  9. * of this software and associated documentation files (the "Software"), to deal
  10. * in the Software without restriction, including without limitation the rights
  11. * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  12. * copies of the Software, and to permit persons to whom the Software is
  13. * furnished to do so, subject to the following conditions:
  14. *
  15. * The above copyright notice and this permission notice shall be included in all
  16. * copies or substantial portions of the Software.
  17. *
  18. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  19. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  20. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  21. * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  22. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  23. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  24. * SOFTWARE.
  25. */
  26. #include "rope.cuh"
// YaRN correction range: v[0]/v[1] are the start/end rotary-dim indices of the
// extrapolation->interpolation ramp. Filled by ggml_rope_yarn_corr_dims() in
// ggml_cuda_op_rope and consumed by rope_yarn_ramp().
struct rope_corr_dims {
    float v[2];
};
// Sizes of up to four M-RoPE position sections, measured in rotary pairs
// (kernels take `(i0/2) % sum(v)` to pick a sector). Each section selects a
// different row of the `pos` tensor in rope_multi / rope_vision.
struct mrope_sections {
    int v[4];
};
  33. static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
  34. const float y = (i0 / 2 - low) / max(0.001f, high - low);
  35. return 1.0f - min(1.0f, max(0.0f, y));
  36. }
  37. // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
  38. // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
  39. static __device__ void rope_yarn(
  40. float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
  41. float * cos_theta, float * sin_theta) {
  42. // Get n-d rotational scaling corrected for extrapolation
  43. float theta_interp = freq_scale * theta_extrap;
  44. float theta = theta_interp;
  45. if (ext_factor != 0.0f) {
  46. float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
  47. theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
  48. // Get n-d magnitude scaling corrected for interpolation
  49. mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
  50. }
  51. *cos_theta = cosf(theta) * mscale;
  52. *sin_theta = sinf(theta) * mscale;
  53. }
  54. template<typename T, bool has_ff>
  55. static __global__ void rope_norm(
  56. const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
  57. float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
  58. const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
  59. if (i0 >= ne0) {
  60. return;
  61. }
  62. const int row = blockDim.x*blockIdx.x + threadIdx.x;
  63. if (i0 >= n_dims) {
  64. const int i = row*ne0 + i0;
  65. dst[i + 0] = x[i + 0];
  66. dst[i + 1] = x[i + 1];
  67. return;
  68. }
  69. const int i = row*ne0 + i0;
  70. const int i2 = row/p_delta_rows;
  71. const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
  72. const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
  73. float cos_theta;
  74. float sin_theta;
  75. rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
  76. const float x0 = x[i + 0];
  77. const float x1 = x[i + 1];
  78. dst[i + 0] = x0*cos_theta - x1*sin_theta;
  79. dst[i + 1] = x0*sin_theta + x1*cos_theta;
  80. }
  81. template<typename T, bool has_ff>
  82. static __global__ void rope_neox(
  83. const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
  84. float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
  85. const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
  86. if (i0 >= ne0) {
  87. return;
  88. }
  89. const int row = blockDim.x*blockIdx.x + threadIdx.x;
  90. if (i0 >= n_dims) {
  91. const int i = row*ne0 + i0;
  92. dst[i + 0] = x[i + 0];
  93. dst[i + 1] = x[i + 1];
  94. return;
  95. }
  96. const int i = row*ne0 + i0/2;
  97. const int i2 = row/p_delta_rows;
  98. const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
  99. const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
  100. float cos_theta;
  101. float sin_theta;
  102. rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
  103. const float x0 = x[i + 0];
  104. const float x1 = x[i + n_dims/2];
  105. dst[i + 0] = x0*cos_theta - x1*sin_theta;
  106. dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
  107. }
  108. template<typename T, bool has_ff>
  109. static __global__ void rope_multi(
  110. const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
  111. float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
  112. const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
  113. if (i0 >= ne0) {
  114. return;
  115. }
  116. const int row = blockDim.x*blockIdx.x + threadIdx.x;
  117. if (i0 >= n_dims) {
  118. const int i = row*ne0 + i0;
  119. dst[i + 0] = x[i + 0];
  120. dst[i + 1] = x[i + 1];
  121. return;
  122. }
  123. const int i = row*ne0 + i0/2;
  124. const int i2 = row/p_delta_rows;
  125. int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
  126. int sec_w = sections.v[1] + sections.v[0];
  127. int sector = (i0 / 2) % sect_dims;
  128. float theta_base = 0.0;
  129. if (sector < sections.v[0]) {
  130. theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
  131. }
  132. else if (sector >= sections.v[0] && sector < sec_w) {
  133. theta_base = pos[i2 + ne2 * 1]*powf(theta_scale, i0/2.0f);
  134. }
  135. else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
  136. theta_base = pos[i2 + ne2 * 2]*powf(theta_scale, i0/2.0f);
  137. }
  138. else if (sector >= sec_w + sections.v[2]) {
  139. theta_base = pos[i2 + ne2 * 3]*powf(theta_scale, i0/2.0f);
  140. }
  141. const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
  142. float cos_theta;
  143. float sin_theta;
  144. rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
  145. const float x0 = x[i + 0];
  146. const float x1 = x[i + n_dims/2];
  147. dst[i + 0] = x0*cos_theta - x1*sin_theta;
  148. dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
  149. }
  150. template<typename T, bool has_ff>
  151. static __global__ void rope_vision(
  152. const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
  153. float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
  154. const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
  155. if (i0 >= ne0) {
  156. return;
  157. }
  158. const int row = blockDim.x*blockIdx.x + threadIdx.x;
  159. const int i = row*ne0 + i0/2;
  160. const int i2 = row/p_delta_rows; // i2-th tokens
  161. int sect_dims = sections.v[0] + sections.v[1];
  162. int sec_w = sections.v[1] + sections.v[0];
  163. int sector = (i0 / 2) % sect_dims;
  164. float theta_base = 0.0;
  165. if (sector < sections.v[0]) {
  166. const int p = sector;
  167. theta_base = pos[i2]*powf(theta_scale, p);
  168. }
  169. else if (sector >= sections.v[0] && sector < sec_w) {
  170. const int p = sector - sections.v[0];
  171. theta_base = pos[i2 + ne2]*powf(theta_scale, p);
  172. }
  173. const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
  174. float cos_theta;
  175. float sin_theta;
  176. rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
  177. const float x0 = x[i + 0];
  178. const float x1 = x[i + n_dims];
  179. dst[i + 0] = x0*cos_theta - x1*sin_theta;
  180. dst[i + n_dims] = x0*sin_theta + x1*cos_theta;
  181. }
  182. template<typename T>
  183. static void rope_norm_cuda(
  184. const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
  185. float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
  186. GGML_ASSERT(ne0 % 2 == 0);
  187. const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
  188. const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
  189. const dim3 block_nums(nr, n_blocks_x, 1);
  190. const float theta_scale = powf(freq_base, -2.0f/n_dims);
  191. if (freq_factors == nullptr) {
  192. rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>(
  193. x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
  194. theta_scale, freq_factors
  195. );
  196. } else {
  197. rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>(
  198. x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
  199. theta_scale, freq_factors
  200. );
  201. }
  202. }
  203. template<typename T>
  204. static void rope_neox_cuda(
  205. const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
  206. float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
  207. GGML_ASSERT(ne0 % 2 == 0);
  208. const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
  209. const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
  210. const dim3 block_nums(nr, n_blocks_x, 1);
  211. const float theta_scale = powf(freq_base, -2.0f/n_dims);
  212. if (freq_factors == nullptr) {
  213. rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
  214. x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
  215. theta_scale, freq_factors
  216. );
  217. } else {
  218. rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
  219. x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
  220. theta_scale, freq_factors
  221. );
  222. }
  223. }
  224. template<typename T>
  225. static void rope_multi_cuda(
  226. const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
  227. float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
  228. GGML_ASSERT(ne0 % 2 == 0);
  229. const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
  230. const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
  231. const dim3 block_nums(nr, n_blocks_x, 1);
  232. const float theta_scale = powf(freq_base, -2.0f/n_dims);
  233. if (freq_factors == nullptr) {
  234. rope_multi<T, false><<<block_nums, block_dims, 0, stream>>>(
  235. x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
  236. theta_scale, freq_factors, sections
  237. );
  238. } else {
  239. rope_multi<T, true><<<block_nums, block_dims, 0, stream>>>(
  240. x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
  241. theta_scale, freq_factors, sections
  242. );
  243. }
  244. }
  245. template<typename T>
  246. static void rope_vision_cuda(
  247. const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
  248. float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
  249. GGML_ASSERT(ne0 % 2 == 0);
  250. const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
  251. const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
  252. const dim3 block_nums(nr, n_blocks_x, 1);
  253. // break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq)
  254. // where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE);
  255. const float theta_scale = powf(freq_base, -2.0f/n_dims);
  256. if (freq_factors == nullptr) {
  257. rope_vision<T, false><<<block_nums, block_dims, 0, stream>>>(
  258. x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
  259. theta_scale, freq_factors, sections
  260. );
  261. } else {
  262. rope_vision<T, true><<<block_nums, block_dims, 0, stream>>>(
  263. x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
  264. theta_scale, freq_factors, sections
  265. );
  266. }
  267. }
  268. static void rope_norm_cuda_f16(
  269. const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
  270. float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
  271. rope_norm_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
  272. }
  273. static void rope_norm_cuda_f32(
  274. const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
  275. float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
  276. rope_norm_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
  277. }
  278. static void rope_neox_cuda_f16(
  279. const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
  280. float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
  281. rope_neox_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
  282. }
  283. static void rope_neox_cuda_f32(
  284. const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
  285. float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
  286. ) {
  287. rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
  288. }
  289. static void rope_multi_cuda_f16(
  290. const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
  291. float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
  292. ) {
  293. rope_multi_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
  294. }
  295. static void rope_multi_cuda_f32(
  296. const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
  297. float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
  298. ) {
  299. rope_multi_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
  300. }
  301. static void rope_vision_cuda_f16(
  302. const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
  303. float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
  304. ) {
  305. rope_vision_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
  306. }
  307. static void rope_vision_cuda_f32(
  308. const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
  309. float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
  310. ) {
  311. rope_vision_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
  312. }
// Backend entry point: decodes the RoPE parameters from dst->op_params,
// selects the RoPE variant from the mode flags, and dispatches to the
// matching f32/f16 launcher on the context's CUDA stream.
//
// src[0] = activations (F32 or F16, contiguous), src[1] = positions (int32),
// src[2] = optional per-pair frequency factors (float, may be null).
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const ggml_tensor * src2 = dst->src[2];

    const float * src0_d = (const float *)src0->data;
    const float * src1_d = (const float *)src1->data;

    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
    GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
    GGML_ASSERT(src0->type == dst->type); // kernels read/write the same element type

    const int64_t ne00 = src0->ne[0]; // head dims
    const int64_t ne01 = src0->ne[1]; // num heads
    const int64_t ne02 = src0->ne[2]; // NOTE(review): comment said "num heads" like ne01; it is passed as the
                                      // `ne2` stride between pos rows in the mrope kernels — presumably the
                                      // token dim; confirm against the CPU reference
    const int64_t nr = ggml_nrows(src0); // NOTE(review): narrowed to int at the call sites below

    // op_params is an array of int32 slots; floats are bit-copied out of slots 5..10
    //const int n_past = ((int32_t *) dst->op_params)[0];
    const int n_dims = ((int32_t *) dst->op_params)[1];
    const int mode = ((int32_t *) dst->op_params)[2];
    //const int n_ctx = ((int32_t *) dst->op_params)[3];
    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];

    mrope_sections sections;

    // RoPE alteration for extended context
    float freq_base;
    float freq_scale;
    float ext_factor;
    float attn_factor;
    float beta_fast;
    float beta_slow;

    memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
    memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
    memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
    memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
    memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
    memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
    memcpy(&sections.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);

    // NEOX and MROPE are bit flags; VISION is compared as a full mode value
    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

    if (is_mrope) {
        GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
    }

    if (is_vision) {
        // vision kernel rotates every element as half-split pairs, so exactly
        // half of the head dims must be rotary
        GGML_ASSERT(n_dims == ne00/2);
    }

    const int32_t * pos = (const int32_t *) src1_d;

    const float * freq_factors = nullptr;
    if (src2 != nullptr) {
        freq_factors = (const float *) src2->data;
    }

    rope_corr_dims corr_dims;
    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);

    // compute — ne01 is passed as p_delta_rows: the number of rows sharing one position entry
    if (is_neox) {
        if (src0->type == GGML_TYPE_F32) {
            rope_neox_cuda_f32(
                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, freq_factors, stream
            );
        } else if (src0->type == GGML_TYPE_F16) {
            rope_neox_cuda_f16(
                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, freq_factors, stream
            );
        } else {
            GGML_ABORT("fatal error");
        }
    } else if (is_mrope && !is_vision) {
        if (src0->type == GGML_TYPE_F32) {
            rope_multi_cuda_f32(
                (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, freq_factors, sections, stream
            );
        } else if (src0->type == GGML_TYPE_F16) {
            rope_multi_cuda_f16(
                (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, freq_factors, sections, stream
            );
        } else {
            GGML_ABORT("fatal error");
        }
    } else if (is_vision) {
        if (src0->type == GGML_TYPE_F32) {
            rope_vision_cuda_f32(
                (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, freq_factors, sections, stream
            );
        } else if (src0->type == GGML_TYPE_F16) {
            rope_vision_cuda_f16(
                (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, freq_factors, sections, stream
            );
        } else {
            GGML_ABORT("fatal error");
        }
    } else {
        // default: normal (interleaved-pair) RoPE
        if (src0->type == GGML_TYPE_F32) {
            rope_norm_cuda_f32(
                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, freq_factors, stream
            );
        } else if (src0->type == GGML_TYPE_F16) {
            rope_norm_cuda_f16(
                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, freq_factors, stream
            );
        } else {
            GGML_ABORT("fatal error");
        }
    }
}