0008-add-unpad-operator.patch 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
  2. From: Michael Yang <mxyng@pm.me>
  3. Date: Thu, 17 Oct 2024 17:19:25 -0700
  4. Subject: [PATCH] add unpad operator
  5. ---
  6. ggml/include/ggml.h | 10 +++++
  7. ggml/src/ggml-cpu/ggml-cpu.c | 58 ++++++++++++++++++++++++++++
  8. ggml/src/ggml-cuda/ggml-cuda.cu | 4 ++
  9. ggml/src/ggml-cuda/pad.cu | 46 ++++++++++++++++++++++
  10. ggml/src/ggml-cuda/pad.cuh | 1 +
  11. ggml/src/ggml-metal/ggml-metal.m | 33 ++++++++++++++++
  12. ggml/src/ggml-metal/ggml-metal.metal | 45 +++++++++++++++++++++
  13. ggml/src/ggml.c | 25 +++++++++++-
  14. 8 files changed, 220 insertions(+), 2 deletions(-)
  15. diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
  16. index dd0c6a96..8d269a9c 100644
  17. --- a/ggml/include/ggml.h
  18. +++ b/ggml/include/ggml.h
  19. @@ -487,6 +487,7 @@ extern "C" {
  20. GGML_OP_UPSCALE, // nearest interpolate
  21. GGML_OP_PAD,
  22. GGML_OP_PAD_REFLECT_1D,
  23. + GGML_OP_UNPAD,
  24. GGML_OP_ARANGE,
  25. GGML_OP_TIMESTEP_EMBEDDING,
  26. GGML_OP_ARGSORT,
  27. @@ -1743,6 +1744,15 @@ extern "C" {
  28. int p0,
  29. int p1);
  30. + // unpad each dimension: [x, ..., x, y, ..., y] -> [x, ..., x]
  31. + GGML_API struct ggml_tensor * ggml_unpad(
  32. + struct ggml_context * ctx,
  33. + struct ggml_tensor * a,
  34. + int p0,
  35. + int p1,
  36. + int p2,
  37. + int p3);
  38. +
  39. // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
  40. // timesteps: [N,]
  41. // return: [N, dim]
  42. diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
  43. index 72325349..2f606d82 100644
  44. --- a/ggml/src/ggml-cpu/ggml-cpu.c
  45. +++ b/ggml/src/ggml-cpu/ggml-cpu.c
  46. @@ -10844,6 +10844,59 @@ static void ggml_compute_forward_pad_reflect_1d(
  47. }
  48. }
  49. +static void ggml_compute_forward_unpad_f32(
  50. + const struct ggml_compute_params *params,
  51. + struct ggml_tensor *dst) {
  52. +
  53. + const struct ggml_tensor * src0 = dst->src[0];
  54. +
  55. + GGML_ASSERT(src0->nb[0] == sizeof(float));
  56. + GGML_ASSERT( dst->nb[0] == sizeof(float));
  57. +
  58. + const int ith = params->ith;
  59. + const int nth = params->nth;
  60. +
  61. + GGML_TENSOR_UNARY_OP_LOCALS
  62. +
  63. + float * dst_ptr = (float *) dst->data;
  64. +
  65. + // TODO: optimize
  66. +
  67. + for (int64_t i2 = 0; i2 < ne2; ++i2) {
  68. + for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
  69. + for (int64_t i0 = 0; i0 < ne0; ++i0) {
  70. + for (int64_t i3 = 0; i3 < ne3; ++i3) {
  71. + const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
  72. +
  73. + const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  74. +
  75. + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
  76. + dst_ptr[dst_idx] = *src_ptr;
  77. + }
  78. + }
  79. + }
  80. + }
  81. + }
  82. +}
  83. +
  84. +static void ggml_compute_forward_unpad(
  85. + const struct ggml_compute_params * params,
  86. + struct ggml_tensor * dst) {
  87. +
  88. + const struct ggml_tensor * src0 = dst->src[0];
  89. +
  90. + switch (src0->type) {
  91. + case GGML_TYPE_F32:
  92. + {
  93. + ggml_compute_forward_unpad_f32(params, dst);
  94. + } break;
  95. + default:
  96. + {
  97. + GGML_ABORT("fatal error");
  98. + }
  99. + }
  100. +}
  101. +
  102. // ggml_compute_forward_arange
  103. static void ggml_compute_forward_arange_f32(
  104. @@ -13137,6 +13190,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
  105. {
  106. ggml_compute_forward_pad_reflect_1d(params, tensor);
  107. } break;
  108. + case GGML_OP_UNPAD:
  109. + {
  110. + ggml_compute_forward_unpad(params, tensor);
  111. + } break;
  112. case GGML_OP_ARANGE:
  113. {
  114. ggml_compute_forward_arange(params, tensor);
  115. @@ -13484,6 +13541,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
  116. case GGML_OP_UPSCALE:
  117. case GGML_OP_PAD:
  118. case GGML_OP_PAD_REFLECT_1D:
  119. + case GGML_OP_UNPAD:
  120. case GGML_OP_ARANGE:
  121. case GGML_OP_TIMESTEP_EMBEDDING:
  122. case GGML_OP_ARGSORT:
  123. diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
  124. index 36165840..1adf08fa 100644
  125. --- a/ggml/src/ggml-cuda/ggml-cuda.cu
  126. +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
  127. @@ -2198,6 +2198,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
  128. case GGML_OP_PAD:
  129. ggml_cuda_op_pad(ctx, dst);
  130. break;
  131. + case GGML_OP_UNPAD:
  132. + ggml_cuda_op_unpad(ctx, dst);
  133. + break;
  134. case GGML_OP_ARANGE:
  135. ggml_cuda_op_arange(ctx, dst);
  136. break;
  137. @@ -3197,6 +3200,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  138. return ggml_is_contiguous(op->src[0]);
  139. case GGML_OP_UPSCALE:
  140. case GGML_OP_PAD:
  141. + case GGML_OP_UNPAD:
  142. case GGML_OP_ARANGE:
  143. case GGML_OP_TIMESTEP_EMBEDDING:
  144. case GGML_OP_LEAKY_RELU:
  145. diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu
  146. index aba539e8..b4b87409 100644
  147. --- a/ggml/src/ggml-cuda/pad.cu
  148. +++ b/ggml/src/ggml-cuda/pad.cu
  149. @@ -47,3 +47,49 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
  150. src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
  151. dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
  152. }
  153. +
  154. +static __global__ void unpad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
  155. + // blockIdx.z: idx of ne2*ne3, aka ne02*ne03
  156. + // blockIdx.y: idx of ne1
  157. + // blockIDx.x: idx of ne0 / BLOCK_SIZE
  158. + int nidx = threadIdx.x + blockIdx.x * blockDim.x;
  159. + if (nidx >= ne0) {
  160. + return;
  161. + }
  162. +
  163. + // operation
  164. + int offset_dst =
  165. + nidx +
  166. + blockIdx.y * ne0 +
  167. + blockIdx.z * ne0 * gridDim.y;
  168. + if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
  169. + int offset_src =
  170. + nidx +
  171. + blockIdx.y * ne00 +
  172. + blockIdx.z * ne00 * ne01;
  173. + dst[offset_dst] = x[offset_src];
  174. + }
  175. +}
  176. +
  177. +static void unpad_f32_cuda(const float * x, float * dst,
  178. + const int ne00, const int ne01, const int ne02, const int ne03,
  179. + const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
  180. + int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
  181. + dim3 gridDim(num_blocks, ne1, ne2*ne3);
  182. + unpad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
  183. +}
  184. +
  185. +void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
  186. + const ggml_tensor * src0 = dst->src[0];
  187. + const float * src0_d = (const float *)src0->data;
  188. + float * dst_d = (float *)dst->data;
  189. + cudaStream_t stream = ctx.stream();
  190. +
  191. + GGML_ASSERT(src0->type == GGML_TYPE_F32);
  192. + GGML_ASSERT(dst->type == GGML_TYPE_F32);
  193. + GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
  194. +
  195. + unpad_f32_cuda(src0_d, dst_d,
  196. + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
  197. + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
  198. +}
  199. \ No newline at end of file
  200. diff --git a/ggml/src/ggml-cuda/pad.cuh b/ggml/src/ggml-cuda/pad.cuh
  201. index 8fd386b0..e2ededc3 100644
  202. --- a/ggml/src/ggml-cuda/pad.cuh
  203. +++ b/ggml/src/ggml-cuda/pad.cuh
  204. @@ -3,3 +3,4 @@
  205. #define CUDA_PAD_BLOCK_SIZE 256
  206. void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
  207. +void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
  208. diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
  209. index fd9a4e77..e4c093f9 100644
  210. --- a/ggml/src/ggml-metal/ggml-metal.m
  211. +++ b/ggml/src/ggml-metal/ggml-metal.m
  212. @@ -331,6 +331,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
  213. GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
  214. GGML_METAL_KERNEL_TYPE_PAD_F32,
  215. GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,
  216. + GGML_METAL_KERNEL_TYPE_UNPAD_F32,
  217. GGML_METAL_KERNEL_TYPE_ARANGE_F32,
  218. GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
  219. GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
  220. @@ -946,6 +947,7 @@ @implementation GGMLMetalClass
  221. GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
  222. GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
  223. GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true);
  224. + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UNPAD_F32, unpad_f32, true);
  225. GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
  226. GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
  227. GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
  228. @@ -1254,6 +1256,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
  229. case GGML_OP_UPSCALE:
  230. case GGML_OP_PAD:
  231. case GGML_OP_PAD_REFLECT_1D:
  232. + case GGML_OP_UNPAD:
  233. case GGML_OP_ARANGE:
  234. case GGML_OP_TIMESTEP_EMBEDDING:
  235. case GGML_OP_ARGSORT:
  236. @@ -3469,6 +3472,36 @@ static void ggml_metal_encode_node(
  237. const int nth = MIN(1024, ne0);
  238. + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  239. + } break;
  240. + case GGML_OP_UNPAD:
  241. + {
  242. + GGML_ASSERT(src0->type == GGML_TYPE_F32);
  243. +
  244. + id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UNPAD_F32].pipeline;
  245. +
  246. + [encoder setComputePipelineState:pipeline];
  247. + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  248. + [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
  249. + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
  250. + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
  251. + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
  252. + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
  253. + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
  254. + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
  255. + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
  256. + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
  257. + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
  258. + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
  259. + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
  260. + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
  261. + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
  262. + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
  263. + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
  264. + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
  265. +
  266. + const int nth = MIN(1024, ne0);
  267. +
  268. [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  269. } break;
  270. case GGML_OP_ARANGE:
  271. diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
  272. index d092a169..f38909d0 100644
  273. --- a/ggml/src/ggml-metal/ggml-metal.metal
  274. +++ b/ggml/src/ggml-metal/ggml-metal.metal
  275. @@ -2953,6 +2953,51 @@ kernel void kernel_pad_reflect_1d_f32(
  276. }
  277. }
  278. +kernel void kernel_unpad_f32(
  279. + device const char * src0,
  280. + device char * dst,
  281. + constant int64_t & ne00,
  282. + constant int64_t & ne01,
  283. + constant int64_t & ne02,
  284. + constant int64_t & ne03,
  285. + constant uint64_t & nb00,
  286. + constant uint64_t & nb01,
  287. + constant uint64_t & nb02,
  288. + constant uint64_t & nb03,
  289. + constant int64_t & ne0,
  290. + constant int64_t & ne1,
  291. + constant int64_t & ne2,
  292. + constant int64_t & ne3,
  293. + constant uint64_t & nb0,
  294. + constant uint64_t & nb1,
  295. + constant uint64_t & nb2,
  296. + constant uint64_t & nb3,
  297. + uint3 tgpig[[threadgroup_position_in_grid]],
  298. + uint3 tpitg[[thread_position_in_threadgroup]],
  299. + uint3 ntg[[threads_per_threadgroup]]) {
  300. +
  301. + const int64_t i3 = tgpig.z;
  302. + const int64_t i2 = tgpig.y;
  303. + const int64_t i1 = tgpig.x;
  304. +
  305. + const int64_t i03 = i3;
  306. + const int64_t i02 = i2;
  307. + const int64_t i01 = i1;
  308. +
  309. + device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
  310. + device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
  311. +
  312. + if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
  313. + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
  314. + if (i0 < ne00) {
  315. + dst_ptr[i0] = src0_ptr[i0];
  316. + }
  317. + }
  318. +
  319. + return;
  320. + }
  321. +}
  322. +
  323. kernel void kernel_arange_f32(
  324. device char * dst,
  325. constant int64_t & ne0,
  326. diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
  327. index 7fc06724..635aa299 100644
  328. --- a/ggml/src/ggml.c
  329. +++ b/ggml/src/ggml.c
  330. @@ -962,6 +962,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
  331. "UPSCALE",
  332. "PAD",
  333. "PAD_REFLECT_1D",
  334. + "UNPAD",
  335. "ARANGE",
  336. "TIMESTEP_EMBEDDING",
  337. "ARGSORT",
  338. @@ -996,7 +997,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
  339. "OPT_STEP_ADAMW",
  340. };
  341. -static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
  342. +static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
  343. static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  344. "none",
  345. @@ -1059,6 +1060,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  346. "upscale(x)",
  347. "pad(x)",
  348. "pad_reflect_1d(x)",
  349. + "unpad(x)",
  350. "arange(start, stop, step)",
  351. "timestep_embedding(timesteps, dim, max_period)",
  352. "argsort(x)",
  353. @@ -1093,7 +1095,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  354. "adamw(x)",
  355. };
  356. -static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
  357. +static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
  358. static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
  359. @@ -4225,6 +4227,25 @@ struct ggml_tensor * ggml_pad_reflect_1d(
  360. return result;
  361. }
  362. +// ggml_unpad
  363. +
  364. +struct ggml_tensor * ggml_unpad(
  365. + struct ggml_context * ctx,
  366. + struct ggml_tensor * a,
  367. + int p0, int p1, int p2, int p3) {
  368. +
  369. + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
  370. + a->ne[0] - p0,
  371. + a->ne[1] - p1,
  372. + a->ne[2] - p2,
  373. + a->ne[3] - p3);
  374. +
  375. + result->op = GGML_OP_UNPAD;
  376. + result->src[0] = a;
  377. +
  378. + return result;
  379. +}
  380. +
  381. // ggml_arange
  382. struct ggml_tensor * ggml_arange(