// mmvq.cu

#include "mmvq.cuh"
#include "vecdotq.cuh"

// Signature shared by the per-quantization-type dot product implementations: each one computes the
// dot product of one quantized x block against the q8_1 y block(s) it aligns with, starting at quant index iqs.
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);

template <int ncols_y, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
// tell the compiler to use as many registers as it wants, see nwarps definition below
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static __global__ void mul_mat_vec_q(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {

#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
    constexpr int nwarps              = 1;
    constexpr int rows_per_cuda_block = 1;
#else
    constexpr int nwarps              = ncols_y <= 4 ? 4 : 2;
    constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))

    const     int tid              = WARP_SIZE*threadIdx.y + threadIdx.x;
    const     int row0             = rows_per_cuda_block*blockIdx.x;
    const     int blocks_per_row_x = ncols_x / qk;
    const     int blocks_per_col_y = nrows_y / QK8_1;
    constexpr int blocks_per_iter  = vdr * nwarps*WARP_SIZE / qi;

    // partial sum for each thread
    float tmp[ncols_y][rows_per_cuda_block] = {0.0f};

    const block_q_t  * x = (const block_q_t  *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

    // each thread strides over the quantized blocks of its rows and accumulates partial dot products
    for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
        const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx

        // x block quant index when casting the quants to int
        const int kqs = vdr * (tid % (qi/vdr));

#pragma unroll
        for (int j = 0; j < ncols_y; ++j) {
#pragma unroll
            for (int i = 0; i < rows_per_cuda_block; ++i) {
                tmp[j][i] += vec_dot_q_cuda(
                    &x[kbx + (row0 + i)*blocks_per_row_x], &y[j*blocks_per_col_y + kby], kqs);
            }
        }
    }

    // warps 1..nwarps-1 stage their partial sums in shared memory; warp 0 performs the final reduction
    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
    if (threadIdx.y > 0) {
#pragma unroll
        for (int j = 0; j < ncols_y; ++j) {
#pragma unroll
            for (int i = 0; i < rows_per_cuda_block; ++i) {
                tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
            }
        }
    }
    __syncthreads();
    if (threadIdx.y > 0) {
        return;
    }

    // sum up partial sums and write back result
#pragma unroll
    for (int j = 0; j < ncols_y; ++j) {
#pragma unroll
        for (int i = 0; i < rows_per_cuda_block; ++i) {
#pragma unroll
            for (int l = 0; l < nwarps-1; ++l) {
                tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
            }
            tmp[j][i] = warp_reduce_sum(tmp[j][i]);
        }

        if (threadIdx.x < rows_per_cuda_block) {
            dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
        }
    }
}
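
// Illustration only (not part of ggml): a minimal CPU sketch of the product that mul_mat_vec_q
// parallelizes, written against already-dequantized float data for clarity. The helper name
// mul_mat_vec_reference_sketch and the plain-float inputs are assumptions for exposition; the
// real kernel works directly on quantized blocks through vec_dot_q_cuda. x is row-major with
// nrows_x rows of ncols_x values, y holds ncols_y columns of nrows_y (>= ncols_x, padded) values
// each, and dst holds ncols_y columns of nrows_dst floats.
static void mul_mat_vec_reference_sketch(
        const float * x, const float * y, float * dst,
        const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst) {
    for (int j = 0; j < ncols_y; ++j) {         // one output column per column of y
        for (int r = 0; r < nrows_x; ++r) {     // one output value per row of x
            float sum = 0.0f;
            for (int k = 0; k < ncols_x; ++k) { // dot product along the shared dimension
                sum += x[r*ncols_x + k]*y[j*nrows_y + k];
            }
            dst[j*nrows_dst + r] = sum;
        }
    }
}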

template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot>
static void mul_mat_vec_q_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    GGML_ASSERT(ncols_x % qk == 0);
    GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);

    int id = ggml_cuda_get_device();

    int64_t nwarps = 1;
    int64_t rows_per_cuda_block = 1;

    if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
        switch(ncols_y) {
            case 1:
                nwarps = 4;
                rows_per_cuda_block = 1;
                break;
            case 2:
            case 3:
            case 4:
                nwarps = 4;
                rows_per_cuda_block = 2;
                break;
            case 5:
            case 6:
            case 7:
            case 8:
                nwarps = 2;
                rows_per_cuda_block = 2;
                break;
            default:
                GGML_ASSERT(false);
                break;
        }
    }

    const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
    const dim3 block_nums(nblocks, 1, 1);
    const dim3 block_dims(WARP_SIZE, nwarps, 1);

    switch (ncols_y) {
        case 1:
            mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 2:
            mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 3:
            mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 4:
            mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 5:
            mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 6:
            mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 7:
            mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 8:
            mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        default:
            GGML_ASSERT(false);
            break;
    }
}
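
// Worked example for the configuration chosen above (NVIDIA / pre-RDNA2 path): with ncols_y == 3
// the switch picks nwarps = 4 and rows_per_cuda_block = 2, so each CUDA block runs WARP_SIZE*4
// threads and produces 2 consecutive output rows for all 3 y columns; nrows_x == 4096 then yields
// (4096 + 2 - 1)/2 == 2048 blocks. On RDNA2/RDNA3 both the host-side defaults and the kernel's
// constexpr path use nwarps = 1 and rows_per_cuda_block = 1.
//
// The wrappers below instantiate mul_mat_vec_q_cuda once per quantization type: qk is the number
// of weights per quantized block, and qi/vdr sets how many threads cooperate on a single x block
// inside the kernel (tid is split into kbx and kqs using qi/vdr).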
static void mul_mat_vec_q4_0_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q4_1_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q5_0_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q5_1_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q8_0_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q2_K_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q3_K_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q4_K_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q5_K_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q6_K_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq2_xxs_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq2_xs_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq2_s_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq3_xxs_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq1_s_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq1_m_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq4_nl_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq4_xs_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq3_s_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
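
// How ggml tensor shapes map onto the kernel arguments below: ne00 (the src0 row length) becomes
// ncols_x, row_diff (the slice of src0 rows assigned to this device) becomes nrows_x,
// src1_padded_row_size becomes nrows_y (the padded length of one quantized src1 row, i.e. one y
// column), and src1_ncols becomes ncols_y.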
void ggml_cuda_op_mul_mat_vec_q(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
    const int64_t src1_padded_row_size, cudaStream_t stream) {

    const int64_t ne00 = src0->ne[0];
    const int64_t row_diff = row_high - row_low;

    const int64_t ne10 = src1->ne[0];
    GGML_ASSERT(ne10 % QK8_1 == 0);

    const int64_t ne0 = dst->ne[0];

    int id = ggml_cuda_get_device();

    // the main device has a larger memory buffer to hold the results from all GPUs
    // nrows_dst == nrows of the matrix that the kernel writes into
    const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;

    switch (src0->type) {
        case GGML_TYPE_Q4_0:
            mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q4_1:
            mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q5_0:
            mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q5_1:
            mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q8_0:
            mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q2_K:
            mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q3_K:
            mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q4_K:
            mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q5_K:
            mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q6_K:
            mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ2_XXS:
            mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ2_XS:
            mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ2_S:
            mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ3_XXS:
            mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ1_S:
            mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ1_M:
            mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ4_NL:
            mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ4_XS:
            mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ3_S:
            mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        default:
            GGML_ASSERT(false);
            break;
    }

    GGML_UNUSED(src1);
    GGML_UNUSED(dst);
    GGML_UNUSED(src1_ddf_i);
    GGML_UNUSED(src1_ncols);
    GGML_UNUSED(src1_padded_row_size);
}