dmmv.cuh 584 B

123456789101112131415161718
  1. #include "common.cuh"
  2. // dmmv = dequantize_mul_mat_vec
  3. // TODO: remove this?
  4. #ifndef GGML_CUDA_DMMV_X
  5. #define GGML_CUDA_DMMV_X 32
  6. #endif
  7. #ifndef GGML_CUDA_MMV_Y
  8. #define GGML_CUDA_MMV_Y 1
  9. #endif
  10. void ggml_cuda_op_dequantize_mul_mat_vec(
  11. ggml_backend_cuda_context & ctx,
  12. const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
  13. const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
  14. const int64_t src1_padded_row_size, cudaStream_t stream);