// mmvq.cu — quantized matrix × vector (and small-batch) multiplication kernels
  1. /**
  2. * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
  3. *
  4. * MIT License
  5. *
  6. * Copyright (c) 2023-2024 The ggml authors
  7. *
  8. * Permission is hereby granted, free of charge, to any person obtaining a copy
  9. * of this software and associated documentation files (the "Software"), to deal
  10. * in the Software without restriction, including without limitation the rights
  11. * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  12. * copies of the Software, and to permit persons to whom the Software is
  13. * furnished to do so, subject to the following conditions:
  14. *
  15. * The above copyright notice and this permission notice shall be included in all
  16. * copies or substantial portions of the Software.
  17. *
  18. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  19. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  20. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  21. * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  22. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  23. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  24. * SOFTWARE.
  25. */
  26. #include "mmvq.cuh"
  27. #include "vecdotq.cuh"
  28. typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
  29. static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
  30. return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1 :
  31. type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 :
  32. type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 :
  33. type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 :
  34. type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 :
  35. type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 :
  36. type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 :
  37. type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 :
  38. type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 :
  39. type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 :
  40. type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 :
  41. type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 :
  42. type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 :
  43. type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 :
  44. type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 :
  45. type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 :
  46. type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 :
  47. type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 :
  48. type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 :
  49. nullptr;
  50. }
  51. static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
  52. return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMVQ :
  53. type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ :
  54. type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ :
  55. type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ :
  56. type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMVQ :
  57. type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ :
  58. type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ :
  59. type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ :
  60. type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ :
  61. type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ :
  62. type == GGML_TYPE_IQ2_XXS ? VDR_IQ2_XXS_Q8_1_MMVQ :
  63. type == GGML_TYPE_IQ2_XS ? VDR_IQ2_XS_Q8_1_MMVQ :
  64. type == GGML_TYPE_IQ2_S ? VDR_IQ2_S_Q8_1_MMVQ :
  65. type == GGML_TYPE_IQ3_XXS ? VDR_IQ3_XXS_Q8_1_MMVQ :
  66. type == GGML_TYPE_IQ3_S ? VDR_IQ3_S_Q8_1_MMVQ :
  67. type == GGML_TYPE_IQ4_NL ? VDR_IQ4_NL_Q8_1_MMVQ :
  68. type == GGML_TYPE_IQ4_XS ? VDR_IQ4_XS_Q8_1_MMVQ :
  69. 1;
  70. }
template <ggml_type type, int ncols_y>
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
// tell the compiler to use as many registers as it wants, see nwarps definition below
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
// dst = x * y where x is a quantized matrix of type `type` and y is quantized
// to q8_1 with ncols_y columns.
// Launch layout (must match mul_mat_vec_q_cuda below): blockDim = (WARP_SIZE, nwarps, 1),
// blockIdx.x selects a group of rows_per_cuda_block rows of x.
static __global__ void mul_mat_vec_q(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {

    constexpr int qk  = ggml_cuda_type_traits<type>::qk; // quants per x block
    constexpr int qi  = ggml_cuda_type_traits<type>::qi; // ints per x block
    constexpr int vdr = get_vdr_mmvq(type);              // values processed per vec_dot call

    constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);

    // these constants must agree with the launch configuration chosen on the
    // host in mul_mat_vec_q_cuda
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
    constexpr int nwarps              = 1;
    constexpr int rows_per_cuda_block = 1;
#else
    constexpr int nwarps              = ncols_y <= 4 ? 4 : 2;
    constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))

    const int tid  = WARP_SIZE*threadIdx.y + threadIdx.x; // flat thread index within the block
    const int row0 = rows_per_cuda_block*blockIdx.x;      // first x row handled by this block
    const int blocks_per_row_x = ncols_x / qk;
    const int blocks_per_col_y = nrows_y / QK8_1;
    // x blocks consumed per loop iteration by the whole thread block
    constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;

    // partial sum for each thread
    float tmp[ncols_y][rows_per_cuda_block] = {0.0f};

    const block_q8_1 * y = (const block_q8_1 *) vy;

    for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
        const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx

        // x block quant index when casting the quants to int
        const int kqs = vdr * (tid % (qi/vdr));

#pragma unroll
        for (int j = 0; j < ncols_y; ++j) {
#pragma unroll
            for (int i = 0; i < rows_per_cuda_block; ++i) {
                tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs);
            }
        }
    }

    // cross-warp reduction: warps 1..nwarps-1 stage their partial sums in
    // shared memory, warp 0 accumulates them after the barrier
    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
    if (threadIdx.y > 0) {
#pragma unroll
        for (int j = 0; j < ncols_y; ++j) {
#pragma unroll
            for (int i = 0; i < rows_per_cuda_block; ++i) {
                tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
            }
        }
    }
    __syncthreads();
    if (threadIdx.y > 0) {
        return; // only warp 0 performs the final reduction and writeback
    }

    // sum up partial sums and write back result
#pragma unroll
    for (int j = 0; j < ncols_y; ++j) {
#pragma unroll
        for (int i = 0; i < rows_per_cuda_block; ++i) {
#pragma unroll
            for (int l = 0; l < nwarps-1; ++l) {
                tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
            }
            tmp[j][i] = warp_reduce_sum(tmp[j][i]);
        }

        // lane i holds the result for row row0+i; guard the tail when the row
        // count is not a multiple of rows_per_cuda_block
        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
            dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
        }
    }
}
// Host-side launcher for mul_mat_vec_q<type, ncols_y>: chooses the launch
// configuration (which must mirror the constexpr nwarps/rows_per_cuda_block
// selection inside the kernel) and dispatches on the runtime ncols_y value,
// since ncols_y is a template parameter of the kernel.
template <ggml_type type>
static void mul_mat_vec_q_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
    GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);

    int id = ggml_cuda_get_device();

    // defaults correspond to the RDNA2+ kernel path (single warp, single row)
    int64_t nwarps = 1;
    int64_t rows_per_cuda_block = 1;

    if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
        switch(ncols_y) {
            case 1:
                nwarps = 4;
                rows_per_cuda_block = 1;
                break;
            case 2:
            case 3:
            case 4:
                nwarps = 4;
                rows_per_cuda_block = 2;
                break;
            case 5:
            case 6:
            case 7:
            case 8:
                nwarps = 2;
                rows_per_cuda_block = 2;
                break;
            default:
                GGML_ABORT("fatal error");
                break;
        }
    }

    // one thread block per rows_per_cuda_block rows, rounded up for the tail
    const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
    const dim3 block_nums(nblocks, 1, 1);
    const dim3 block_dims(WARP_SIZE, nwarps, 1);

    switch (ncols_y) {
        case 1:
            mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 2:
            mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 3:
            mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 4:
            mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 5:
            mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 6:
            mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 7:
            mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 8:
            mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        default:
            GGML_ABORT("fatal error");
            break;
    }
}
  206. static void mul_mat_vec_q4_0_q8_1_cuda(
  207. const void * vx, const void * vy, float * dst,
  208. const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
  209. mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  210. }
  211. static void mul_mat_vec_q4_1_q8_1_cuda(
  212. const void * vx, const void * vy, float * dst,
  213. const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
  214. mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  215. }
  216. static void mul_mat_vec_q5_0_q8_1_cuda(
  217. const void * vx, const void * vy, float * dst,
  218. const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
  219. mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  220. }
  221. static void mul_mat_vec_q5_1_q8_1_cuda(
  222. const void * vx, const void * vy, float * dst,
  223. const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
  224. mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  225. }
  226. static void mul_mat_vec_q8_0_q8_1_cuda(
  227. const void * vx, const void * vy, float * dst,
  228. const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
  229. mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  230. }
  231. static void mul_mat_vec_q2_K_q8_1_cuda(
  232. const void * vx, const void * vy, float * dst,
  233. const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
  234. mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  235. }
  236. static void mul_mat_vec_q3_K_q8_1_cuda(
  237. const void * vx, const void * vy, float * dst,
  238. const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
  239. mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  240. }
  241. static void mul_mat_vec_q4_K_q8_1_cuda(
  242. const void * vx, const void * vy, float * dst,
  243. const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
  244. mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  245. }
  246. static void mul_mat_vec_q5_K_q8_1_cuda(
  247. const void * vx, const void * vy, float * dst,
  248. const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
  249. mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  250. }
  251. static void mul_mat_vec_q6_K_q8_1_cuda(
  252. const void * vx, const void * vy, float * dst,
  253. const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
  254. mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  255. }
  256. static void mul_mat_vec_iq2_xxs_q8_1_cuda(
  257. const void * vx, const void * vy, float * dst,
  258. const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
  259. mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  260. }
  261. static void mul_mat_vec_iq2_xs_q8_1_cuda(
  262. const void * vx, const void * vy, float * dst,
  263. const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
  264. mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  265. }
  266. static void mul_mat_vec_iq2_s_q8_1_cuda(
  267. const void * vx, const void * vy, float * dst,
  268. const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
  269. mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  270. }
  271. static void mul_mat_vec_iq3_xxs_q8_1_cuda(
  272. const void * vx, const void * vy, float * dst,
  273. const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
  274. mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  275. }
  276. static void mul_mat_vec_iq1_s_q8_1_cuda(
  277. const void * vx, const void * vy, float * dst,
  278. const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
  279. mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  280. }
  281. static void mul_mat_vec_iq1_m_q8_1_cuda(
  282. const void * vx, const void * vy, float * dst,
  283. const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
  284. mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  285. }
  286. static void mul_mat_vec_iq4_nl_q8_1_cuda(
  287. const void * vx, const void * vy, float * dst,
  288. const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
  289. mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  290. }
  291. static void mul_mat_vec_iq4_xs_q8_1_cuda(
  292. const void * vx, const void * vy, float * dst,
  293. const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
  294. mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  295. }
  296. static void mul_mat_vec_iq3_s_q8_1_cuda(
  297. const void * vx, const void * vy, float * dst,
  298. const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
  299. mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  300. }
// Backend entry point for quantized mat-vec multiplication: validates the
// q8_1 padding of src1 and dispatches on src0's quantization type to the
// matching mul_mat_vec_*_q8_1_cuda wrapper. src1_ddq_i is src1 already
// quantized to q8_1; rows [row_low, row_high) of src0 are processed.
void ggml_cuda_op_mul_mat_vec_q(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
    const int64_t src1_padded_row_size, cudaStream_t stream) {

    const int64_t ne00 = src0->ne[0];
    const int64_t row_diff = row_high - row_low; // number of src0 rows this call handles

    const int64_t ne10 = src1->ne[0];
    GGML_ASSERT(ne10 % QK8_1 == 0); // src1 rows must consist of whole q8_1 blocks

    const int64_t ne0 = dst->ne[0];

    int id = ggml_cuda_get_device();

    // the main device has a larger memory buffer to hold the results from all GPUs
    // nrows_dst == nrows of the matrix that the kernel writes into
    const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;

    switch (src0->type) {
        case GGML_TYPE_Q4_0:
            mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q4_1:
            mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q5_0:
            mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q5_1:
            mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q8_0:
            mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q2_K:
            mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q3_K:
            mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q4_K:
            mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q5_K:
            mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q6_K:
            mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ2_XXS:
            mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ2_XS:
            mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ2_S:
            mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ3_XXS:
            mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ1_S:
            mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ1_M:
            mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ4_NL:
            mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ4_XS:
            mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ3_S:
            mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        default:
            GGML_ABORT("fatal error");
            break;
    }

    GGML_UNUSED(src1);
    GGML_UNUSED(dst);
    GGML_UNUSED(src1_ddf_i);
    // NOTE(review): src1_ncols and src1_padded_row_size ARE used in every case
    // above; these two markers look redundant — confirm before removing.
    GGML_UNUSED(src1_ncols);
    GGML_UNUSED(src1_padded_row_size);
}