0014-qwen2vl-support.patch 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
  2. From: jmorganca <jmorganca@gmail.com>
  3. Date: Sun, 15 Dec 2024 23:56:24 -0800
  4. Subject: [PATCH] qwen2vl support
  5. ---
  6. ggml/src/ggml-metal/ggml-metal-impl.h | 1 +
  7. ggml/src/ggml-metal/ggml-metal.m | 54 +++++++---
  8. ggml/src/ggml-metal/ggml-metal.metal | 146 ++++++++++++++++++++++++++
  9. 3 files changed, 186 insertions(+), 15 deletions(-)
  10. diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
  11. index e3dc25f1..766a4999 100644
  12. --- a/ggml/src/ggml-metal/ggml-metal-impl.h
  13. +++ b/ggml/src/ggml-metal/ggml-metal-impl.h
  14. @@ -143,6 +143,7 @@ typedef struct {
  15. float attn_factor;
  16. float beta_fast;
  17. float beta_slow;
  18. + int32_t sections[4];
  19. } ggml_metal_kargs_rope;
  20. typedef struct {
  21. diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
  22. index 787fc713..806c9fd3 100644
  23. --- a/ggml/src/ggml-metal/ggml-metal.m
  24. +++ b/ggml/src/ggml-metal/ggml-metal.m
  25. @@ -302,6 +302,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
  26. GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
  27. GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
  28. GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
  29. + GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,
  30. + GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,
  31. + GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32,
  32. + GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16,
  33. GGML_METAL_KERNEL_TYPE_IM2COL_F16,
  34. GGML_METAL_KERNEL_TYPE_IM2COL_F32,
  35. GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
  36. @@ -902,6 +906,10 @@ @implementation GGMLMetalClass
  37. GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
  38. GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
  39. GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
  40. + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true);
  41. + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true);
  42. + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true);
  43. + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true);
  44. GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
  45. GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
  46. GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
  47. @@ -1129,16 +1137,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
  48. case GGML_OP_NORM:
  49. return true;
  50. case GGML_OP_ROPE:
  51. - {
  52. - const int mode = ((const int32_t *) op->op_params)[2];
  53. - if (mode & GGML_ROPE_TYPE_MROPE) {
  54. - return false;
  55. - }
  56. - if (mode & GGML_ROPE_TYPE_VISION) {
  57. - return false;
  58. - }
  59. - return true;
  60. - }
  61. + return true;
  62. case GGML_OP_IM2COL:
  63. return op->src[0]->type == GGML_TYPE_F16;
  64. case GGML_OP_POOL_1D:
  65. @@ -3057,6 +3056,7 @@ static void ggml_metal_encode_node(
  66. float attn_factor;
  67. float beta_fast;
  68. float beta_slow;
  69. + int32_t sections[4];
  70. memcpy(&freq_base, (const int32_t *) dst->op_params + 5, sizeof(float));
  71. memcpy(&freq_scale, (const int32_t *) dst->op_params + 6, sizeof(float));
  72. @@ -3064,21 +3064,44 @@ static void ggml_metal_encode_node(
  73. memcpy(&attn_factor, (const int32_t *) dst->op_params + 8, sizeof(float));
  74. memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float));
  75. memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float));
  76. + memcpy(&sections, (const int32_t *) dst->op_params + 11, sizeof(int32_t)*4);
  77. const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
  78. + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
  79. + const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
  80. +
  81. + if (is_mrope) {
  82. + GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
  83. + }
  84. +
  85. + if (is_vision) {
  86. + GGML_ASSERT(n_dims == ne00/2);
  87. + }
  88. id<MTLComputePipelineState> pipeline = nil;
  89. - if (!is_neox) {
  90. + if (is_neox) {
  91. switch (src0->type) {
  92. - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
  93. - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
  94. + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
  95. + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
  96. + default: GGML_ABORT("fatal error");
  97. + };
  98. + } else if (is_mrope && !is_vision) {
  99. + switch (src0->type) {
  100. + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break;
  101. + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break;
  102. + default: GGML_ABORT("fatal error");
  103. + };
  104. + } else if (is_vision) {
  105. + switch (src0->type) {
  106. + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break;
  107. + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break;
  108. default: GGML_ABORT("fatal error");
  109. };
  110. } else {
  111. switch (src0->type) {
  112. - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
  113. - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
  114. + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
  115. + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
  116. default: GGML_ABORT("fatal error");
  117. };
  118. }
  119. @@ -3109,6 +3132,7 @@ static void ggml_metal_encode_node(
  120. /*.attn_factor =*/ attn_factor,
  121. /*.beta_fast =*/ beta_fast,
  122. /*.beta_slow =*/ beta_slow,
  123. + /*.sections =*/ {sections[0], sections[1], sections[2], sections[3]}
  124. };
  125. [encoder setComputePipelineState:pipeline];
  126. diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
  127. index 204c93e6..67b3240f 100644
  128. --- a/ggml/src/ggml-metal/ggml-metal.metal
  129. +++ b/ggml/src/ggml-metal/ggml-metal.metal
  130. @@ -2568,8 +2568,148 @@ kernel void kernel_rope_neox(
  131. }
  132. }
  133. +
  134. +template<typename T> // T: element type used to read src0 and write dst (instantiated for float and half later in this hunk)
  135. +kernel void kernel_rope_multi( // rotary embedding in which each rotation pair picks one of up to four position streams based on args.sections
  136. + constant ggml_metal_kargs_rope & args,
  137. + device const char * src0, // input bytes, addressed with args.nb00..nb03
  138. + device const char * src1, // reinterpreted below as int32 positions
  139. + device const char * src2, // per-pair float frequency factors; ignored when src2 == src0
  140. + device char * dst, // output bytes, addressed with args.nb0..nb3
  141. + ushort tiitg[[thread_index_in_threadgroup]],
  142. + ushort3 tptg [[threads_per_threadgroup]],
  143. + uint3 tgpig[[threadgroup_position_in_grid]]) {
  144. + const int i3 = tgpig[2]; // the threadgroup grid position supplies the three outer indices
  145. + const int i2 = tgpig[1]; // i2 also indexes the position array below
  146. + const int i1 = tgpig[0];
  147. +
  148. + float corr_dims[2];
  149. + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); // corr_dims is forwarded to rope_yarn below
  150. +
  151. + device const int32_t * pos = (device const int32_t *) src1;
  152. +
  153. + int sect_dims = args.sections[0] + args.sections[1] + args.sections[2] + args.sections[3]; // combined width of all four sections, counted in rotation pairs; used as a modulus below
  154. + int sec_w = args.sections[1] + args.sections[0]; // pair index at which section 1 ends
  155. +
  156. + const float inv_ndims = -1.f/args.n_dims; // negated reciprocal, so inv_ndims*i0 below equals -i0/n_dims
  157. +
  158. + float cos_theta;
  159. + float sin_theta;
  160. +
  161. + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { // each thread handles two elements per step, striding by 2*tptg.x
  162. + if (i0 < args.n_dims) { // only i0 < n_dims is rotated; the else branch copies the remainder
  163. + const int ic = i0/2; // rotation pair index
  164. + const int sector = ic % sect_dims; // where this pair falls inside the repeating section layout
  165. +
  166. + float theta_base = (float) pos[i2]; // default stream, kept when sector < sections[0]
  167. + if (sector >= args.sections[0] && sector < sec_w) {
  168. + theta_base = (float) pos[i2 + args.ne2]; // section 1: second stream, ne2 entries further into pos
  169. + }
  170. + else if (sector >= sec_w && sector < sec_w + args.sections[2]) {
  171. + theta_base = (float) pos[i2 + args.ne2 * 2]; // section 2: third stream
  172. + }
  173. + else if (sector >= sec_w + args.sections[2]) {
  174. + theta_base = (float) pos[i2 + args.ne2 * 3]; // section 3: fourth stream
  175. + }
  176. +
  177. + float theta = theta_base*pow(args.freq_base, inv_ndims*i0); // position * freq_base^(-i0/n_dims)
  178. +
  179. + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; // 1.0 when no separate src2 buffer was bound
  180. +
  181. + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); // cos_theta/sin_theta are read right after this call
  182. +
  183. + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); // note: offset uses the pair index ic, not i0
  184. + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
  185. +
  186. + const float x0 = src[0];
  187. + const float x1 = src[args.n_dims/2]; // partner element sits n_dims/2 elements after x0
  188. +
  189. + dst_data[0] = x0*cos_theta - x1*sin_theta; // 2-D rotation of (x0, x1) by the angle from rope_yarn
  190. + dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
  191. + } else {
  192. + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); // here the offset is the element index i0
  193. + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
  194. +
  195. + dst_data[0] = src[0]; // elements i0 and i0+1 are copied through unrotated
  196. + dst_data[1] = src[1];
  197. + }
  198. + }
  199. +}
  200. +
  201. +template<typename T> // T: element type used to read src0 and write dst (instantiated for float and half later in this hunk)
  202. +kernel void kernel_rope_vision( // rotary embedding variant: section-relative frequency exponent, partner element n_dims away, no pass-through branch
  203. + constant ggml_metal_kargs_rope & args,
  204. + device const char * src0, // input bytes, addressed with args.nb00..nb03
  205. + device const char * src1, // reinterpreted below as int32 positions
  206. + device const char * src2, // per-pair float frequency factors; ignored when src2 == src0
  207. + device char * dst, // output bytes, addressed with args.nb0..nb3
  208. + ushort tiitg[[thread_index_in_threadgroup]],
  209. + ushort3 tptg [[threads_per_threadgroup]],
  210. + uint3 tgpig[[threadgroup_position_in_grid]]) {
  211. + const int i3 = tgpig[2]; // the threadgroup grid position supplies the three outer indices
  212. + const int i2 = tgpig[1]; // i2 also indexes the position array below
  213. + const int i1 = tgpig[0];
  214. +
  215. + float corr_dims[2];
  216. + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); // corr_dims is forwarded to rope_yarn below
  217. +
  218. + device const int32_t * pos = (device const int32_t *) src1;
  219. +
  220. + int sect_dims = args.sections[0] + args.sections[1]; // modulus for sector; NOTE(review): zero if both sections are 0, which the host assert in this patch does not exclude - confirm callers
  221. + int sec_w = args.sections[1] + args.sections[0]; // same value as sect_dims
  222. + int sec_e = args.sections[2] + sec_w; // NOTE(review): sec_e is not read anywhere in this kernel
  223. +
  224. + const float inv_ndims = -1.f/args.n_dims; // negated reciprocal of n_dims, used in the pow() exponent below
  225. +
  226. + float cos_theta;
  227. + float sin_theta;
  228. +
  229. + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { // every i0 below ne0 is rotated; the host side of this patch asserts n_dims == ne00/2 for vision mode
  230. + const int ic = i0/2; // rotation pair index
  231. + const int sector = ic % sect_dims; // always below sect_dims, hence below sec_w as well
  232. +
  233. + float theta_base = (float) pos[i2]; // default stream, kept when sector < sections[0]
  234. + if (sector >= args.sections[0] && sector < sec_w) {
  235. + theta_base = (float) pos[i2 + args.ne2]; // section 1: second stream, ne2 entries further into pos
  236. + }
  237. + else if (sector >= sec_w && sector < sec_w + args.sections[2]) { // NOTE(review): looks unreachable because sector < sec_w - confirm
  238. + theta_base = (float) pos[i2 + args.ne2 * 2];
  239. + }
  240. + else if (sector >= sec_w + args.sections[2]) { // NOTE(review): same reachability concern as the branch above
  241. + theta_base = (float) pos[i2 + args.ne2 * 3];
  242. + }
  243. +
  244. + int p = sector; // p becomes the pair's offset from the start of its own section
  245. + if (sector >= sec_w + args.sections[2]) {
  246. + p = sector - (sec_w + args.sections[2]);
  247. + } else if (sector >= sec_w) {
  248. + p = sector - sec_w;
  249. + } else if (sector >= args.sections[0]) {
  250. + p = sector - args.sections[0];
  251. + }
  252. +
  253. + const float theta = theta_base*pow(args.freq_base, inv_ndims*2.0f*p); // position * freq_base^(-2p/n_dims): exponent restarts in each section
  254. +
  255. + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; // 1.0 when no separate src2 buffer was bound
  256. +
  257. + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); // still receives i0, not p, as its index argument
  258. +
  259. + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); // offset uses the pair index ic
  260. + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
  261. +
  262. + const float x0 = src[0];
  263. + const float x1 = src[args.n_dims]; // partner element sits n_dims elements after x0 (n_dims/2 in kernel_rope_multi)
  264. +
  265. + dst_data[0] = x0*cos_theta - x1*sin_theta; // 2-D rotation of (x0, x1) by the angle from rope_yarn
  266. + dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta;
  267. + }
  268. +}
  269. +
  270. +
  271. typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
  272. typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
  273. +typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
  274. +typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
  275. template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
  276. template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
  277. @@ -2577,6 +2717,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_
  278. template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
  279. template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
  280. +template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>; // float instantiation, given the host-visible name kernel_rope_multi_f32
  281. +template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>; // half instantiation, given the host-visible name kernel_rope_multi_f16
  282. +
  283. +template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>; // float instantiation, given the host-visible name kernel_rope_vision_f32
  284. +template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>; // half instantiation, given the host-visible name kernel_rope_vision_f16
  285. +
  286. typedef void (im2col_t)(
  287. device const float * x,
  288. device char * dst,