/**
 * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
 *
 * MIT License
 *
 * Copyright (c) 2023-2024 The ggml authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#ifdef GGML_USE_BLAS

#include "ggml-impl.h"
#include "ggml-blas.h"
#include "ggml-backend-impl.h"

#include <future>
#include <vector>
#include <cstring>

#if defined(GGML_BLAS_USE_ACCELERATE)
#   include <Accelerate/Accelerate.h>
#elif defined(GGML_BLAS_USE_MKL)
#   include <mkl.h>
#elif defined(GGML_BLAS_USE_BLIS)
#   include <blis.h>
#elif defined(GGML_BLAS_USE_NVPL)
#   include <nvpl_blas.h>
#else
#   include <cblas.h>
#endif
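
// Per-backend state: the requested BLAS thread count, a lazily grown scratch
// buffer used to hold src0 dequantized to f32 before the sgemm calls, and
// (when not using OpenMP) the std::async futures spawned for that conversion.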
struct ggml_backend_blas_context {
    int n_threads = GGML_DEFAULT_N_THREADS;
    std::unique_ptr<char[]> work_data;
    size_t work_size = 0;
#ifndef GGML_USE_OPENMP
    std::vector<std::future<void>> tasks;
#endif
};
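
// GGML_OP_MUL_MAT via BLAS: dequantize src0 into the f32 scratch buffer when
// it is not already f32, then issue one cblas_sgemm call per 2D plane of the
// (possibly broadcast) batch.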
static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0];
    const struct ggml_tensor * src1 = dst->src[1];

    GGML_TENSOR_BINARY_OP_LOCALS

    const enum ggml_type type = src0->type;

    GGML_ASSERT(ne0 == ne01);
    GGML_ASSERT(ne1 == ne11);
    GGML_ASSERT(ne2 == ne12);
    GGML_ASSERT(ne3 == ne13);

    // we don't support permuted src0 or src1
    GGML_ASSERT(nb00 == ggml_type_size(type));
    GGML_ASSERT(nb10 == ggml_type_size(src1->type));

    // dst cannot be transposed or permuted
    GGML_ASSERT(nb0 == sizeof(float));
    GGML_ASSERT(nb0 <= nb1);
    GGML_ASSERT(nb1 <= nb2);
    GGML_ASSERT(nb2 <= nb3);

    // broadcast factors
    const int64_t r2 = ne12/ne02;
    const int64_t r3 = ne13/ne03;

    const int64_t ne_plane      = ne01*ne00;
    const size_t  desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float);

    if (ctx->work_size < desired_wsize) {
        ctx->work_data.reset(new char[desired_wsize]);
        ctx->work_size = desired_wsize;
    }

    void * wdata = ctx->work_data.get();

    // convert src0 to float
    if (type != GGML_TYPE_F32) {
        const auto * type_traits = ggml_get_type_traits(type);
        ggml_to_float_t const to_float = type_traits->to_float;

        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                const void  *       x      = (char *)  src0->data + i02*nb02     + i03*nb03;
                float       * const wplane = (float *) wdata      + i02*ne_plane + i03*ne02*ne_plane;

                const int min_cols_per_thread = 4096;
                const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1);
                const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01/min_rows_per_thread)), 1);

#ifdef GGML_USE_OPENMP
                #pragma omp parallel for num_threads(n_threads)
                for (int64_t i01 = 0; i01 < ne01; i01++) {
                    to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
                }
#else
                for (int i = 1; i < n_threads; i++) {
                    const int64_t start =       i*ne01/n_threads;
                    const int64_t end   = (i + 1)*ne01/n_threads;
                    if (start < end) {
                        ctx->tasks.push_back(std::async(std::launch::async, [=]() {
                            for (int64_t i01 = start; i01 < end; i01++) {
                                to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
                            }
                        }));
                    }
                }
                {
                    // reuse the current thread for the first task
                    const int64_t start = 0;
                    const int64_t end   = ne01/n_threads;
                    for (int64_t i01 = start; i01 < end; i01++) {
                        to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
                    }
                }
#endif
            }
        }

#ifndef GGML_USE_OPENMP
        // wait for all tasks to finish
        for (auto & task : ctx->tasks) {
            task.get();
        }
        ctx->tasks.clear();
#endif
    }

#if defined(OPENBLAS_VERSION)
    openblas_set_num_threads(ctx->n_threads);
#endif

#if defined(GGML_BLAS_USE_BLIS)
    bli_thread_set_num_threads(ctx->n_threads);
#endif

#if defined(GGML_BLAS_USE_NVPL)
    nvpl_blas_set_num_threads(ctx->n_threads);
#endif
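
    // One cblas_sgemm per 2D plane: with CblasRowMajor/CblasNoTrans/CblasTrans
    // and these leading dimensions, the call computes d = y * x^T, i.e.
    // dst (ne1 x ne01) = src1 (ne1 x ne10) * src0^T, with the src0 plane
    // selected through the broadcast factors (i02 = i12/r2, i03 = i13/r3).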
    for (int64_t i13 = 0; i13 < ne13; i13++) {
        for (int64_t i12 = 0; i12 < ne12; i12++) {
            const int64_t i03 = i13/r3;
            const int64_t i02 = i12/r2;

            const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
            const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
                  float * d = (float *) ((char *)  dst->data + i12*nb2  + i13*nb3);

            if (type != GGML_TYPE_F32) {
                x = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
            }

            cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                        ne1, ne01, ne10,
                        1.0f,   y, ne10,
                                x, ne00,
                        0.0f,   d, ne01);
        }
    }
}

static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0];
    const struct ggml_tensor * src1 = dst->src[1];

    GGML_TENSOR_BINARY_OP_LOCALS

    GGML_ASSERT(ne0  == ne00);
    GGML_ASSERT(ne1  == ne10);
    GGML_ASSERT(ne2  == ne02);
    GGML_ASSERT(ne02 == ne12);
    GGML_ASSERT(ne3  == ne13);
    GGML_ASSERT(ne03 == ne13);

    // we don't support permuted src0 or src1
    GGML_ASSERT(nb00 == sizeof(float));

    // dst cannot be transposed or permuted
    GGML_ASSERT(nb0 == sizeof(float));
    // GGML_ASSERT(nb0 <= nb1);
    // GGML_ASSERT(nb1 <= nb2);
    // GGML_ASSERT(nb2 <= nb3);

    // Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
    // src0: (k,n)
    // src1: (k,m)
    // dst:  (m,n)
    //
    // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
    // Also expressed as (major,minor)
    // a: (m,k): so src1 transposed
    // b: (k,n): so src0
    // c: (m,n)
    //
    // However, if ggml_is_transposed(src1) is true, then
    // src1->data already contains a transposed version, so sgemm mustn't
    // transpose it further.

    int n = src0->ne[0];
    int k = src0->ne[1];
    int m = src1->ne[0];

    CBLAS_TRANSPOSE transposeA;
    int lda;

    if (!ggml_is_transposed(src1)) {
        transposeA = CblasTrans;
        lda = m;
    } else {
        transposeA = CblasNoTrans;
        lda = k;
    }

    float * a = (float *) ((char *) src1->data);
    float * b = (float *) ((char *) src0->data);
    float * c = (float *) ((char *) dst->data);

    cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);

    GGML_UNUSED(ctx);
}

// backend interface

static const char * ggml_backend_blas_get_name(ggml_backend_t backend) {
    return "BLAS";

    GGML_UNUSED(backend);
}

static void ggml_backend_blas_free(ggml_backend_t backend) {
    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
    delete ctx;
    delete backend;
}
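
// Runs the graph on the host: MUL_MAT and OUT_PROD go through the BLAS
// kernels above, view-style ops (NONE/RESHAPE/VIEW/PERMUTE/TRANSPOSE) are
// no-ops here, and anything else aborts (unsupported ops are filtered out
// beforehand by supports_op).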
static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;

    for (int i = 0; i < cgraph->n_nodes; i++) {
        struct ggml_tensor * node = cgraph->nodes[i];

        switch (node->op) {
            case GGML_OP_MUL_MAT:
                ggml_backend_blas_mul_mat(ctx, node);
                break;

            case GGML_OP_OUT_PROD:
                ggml_backend_blas_out_prod(ctx, node);
                break;

            case GGML_OP_NONE:
            case GGML_OP_RESHAPE:
            case GGML_OP_VIEW:
            case GGML_OP_PERMUTE:
            case GGML_OP_TRANSPOSE:
                break;

            default:
                GGML_ABORT("%s: unsupported op %s\n", __func__, ggml_op_desc(node));
        }
    }

    return GGML_STATUS_SUCCESS;

    GGML_UNUSED(backend);
}

static struct ggml_backend_i blas_backend_i = {
    /* .get_name           = */ ggml_backend_blas_get_name,
    /* .free               = */ ggml_backend_blas_free,
    /* .set_tensor_async   = */ NULL,
    /* .get_tensor_async   = */ NULL,
    /* .cpy_tensor_async   = */ NULL,
    /* .synchronize        = */ NULL,
    /* .graph_plan_create  = */ NULL,
    /* .graph_plan_free    = */ NULL,
    /* .graph_plan_update  = */ NULL,
    /* .graph_plan_compute = */ NULL,
    /* .graph_compute      = */ ggml_backend_blas_graph_compute,
    /* .event_record       = */ NULL,
    /* .event_wait         = */ NULL,
};

static ggml_guid_t ggml_backend_blas_guid(void) {
    static ggml_guid guid = { 0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97, 0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d };
    return &guid;
}

ggml_backend_t ggml_backend_blas_init(void) {
    ggml_backend_blas_context * ctx = new ggml_backend_blas_context;

    ggml_backend_t backend = new ggml_backend {
        /* .guid      = */ ggml_backend_blas_guid(),
        /* .interface = */ blas_backend_i,
        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0),
        /* .context   = */ ctx,
    };

#if defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
    if (openblas_get_parallel() != OPENBLAS_OPENMP) {
        GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__);
    }
#endif

#if defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP)
    GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__);
#endif

    return backend;
}

bool ggml_backend_is_blas(ggml_backend_t backend) {
    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_blas_guid());
}

void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads) {
    GGML_ASSERT(ggml_backend_is_blas(backend_blas));

    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend_blas->context;
    ctx->n_threads = n_threads;
}
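
// Illustrative usage sketch (assumed caller code, not part of this file):
// create the backend, set the BLAS thread count, and compute a graph whose
// tensors live in host memory.
//
//   ggml_backend_t backend = ggml_backend_blas_init();
//   ggml_backend_blas_set_n_threads(backend, 8);
//   // ... build a ggml_cgraph and allocate its tensors in a host buffer ...
//   ggml_backend_graph_compute(backend, graph);
//   ggml_backend_free(backend);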

// device interface

static const char * ggml_backend_blas_device_get_name(ggml_backend_dev_t dev) {
    return "BLAS";

    GGML_UNUSED(dev);
}

static const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t dev) {
#if defined(GGML_BLAS_USE_ACCELERATE)
    return "Accelerate";
#elif defined(GGML_BLAS_USE_MKL)
    return "MKL";
#elif defined(GGML_BLAS_USE_BLIS)
    return "BLIS";
#elif defined(GGML_BLAS_USE_NVPL)
    return "NVPL";
#elif defined(OPENBLAS_VERSION)
    return "OpenBLAS";
#else
    return "BLAS";
#endif

    GGML_UNUSED(dev);
}

static void ggml_backend_blas_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
    // TODO
    *free = 0;
    *total = 0;

    GGML_UNUSED(dev);
}

static enum ggml_backend_dev_type ggml_backend_blas_device_get_type(ggml_backend_dev_t dev) {
    return GGML_BACKEND_DEVICE_TYPE_ACCEL;

    GGML_UNUSED(dev);
}

static void ggml_backend_blas_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
    props->name        = ggml_backend_blas_device_get_name(dev);
    props->description = ggml_backend_blas_device_get_description(dev);
    props->type        = ggml_backend_blas_device_get_type(dev);
    ggml_backend_blas_device_get_memory(dev, &props->memory_free, &props->memory_total);
    props->caps = {
        /* .async                = */ false,
        /* .host_buffer          = */ false,
        /* .buffer_from_host_ptr = */ true,
        /* .events               = */ false,
    };
}

static ggml_backend_t ggml_backend_blas_device_init_backend(ggml_backend_dev_t dev, const char * params) {
    return ggml_backend_blas_init();

    GGML_UNUSED(dev);
    GGML_UNUSED(params);
}

static ggml_backend_buffer_type_t ggml_backend_blas_device_get_buffer_type(ggml_backend_dev_t dev) {
    return ggml_backend_cpu_buffer_type();

    GGML_UNUSED(dev);
}

static ggml_backend_buffer_t ggml_backend_blas_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
    return ggml_backend_cpu_buffer_from_ptr(ptr, size);

    GGML_UNUSED(dev);
    GGML_UNUSED(max_tensor_size);
}

static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
    const struct ggml_tensor * src0 = op->src[0];
    const struct ggml_tensor * src1 = op->src[1];

    switch (op->op) {
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
        case GGML_OP_PERMUTE:
        case GGML_OP_TRANSPOSE:
            return true;

        case GGML_OP_MUL_MAT:
        {
            // BLAS usually is only faster for large matrices
            const struct ggml_tensor * src0 = op->src[0];
            const struct ggml_tensor * src1 = op->src[1];

            const int64_t ne10 = src1->ne[0];

            const int64_t ne0 = op->ne[0];
            const int64_t ne1 = op->ne[1];

            // TODO: find the optimal value
            const int64_t min_batch = 32;

            return ggml_is_contiguous(src0) &&
                   ggml_is_contiguous(src1) &&
                   src1->type == GGML_TYPE_F32 &&
                   (ne0 >= min_batch && ne1 >= min_batch && ne10 >= min_batch) &&
                   (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
        }

        case GGML_OP_OUT_PROD:
            return op->src[0]->type == GGML_TYPE_F32 &&
                   op->src[1]->type == GGML_TYPE_F32 &&
                   ggml_is_matrix(src0) &&
                   ggml_is_matrix(src1) &&
                   ggml_is_contiguous(src0) &&
                   (ggml_is_contiguous(src1) || ggml_is_transposed(src1)) &&
                   (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);

        default:
            return false;
    }

    GGML_UNUSED(dev);
}

static bool ggml_backend_blas_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
    return ggml_backend_buft_is_host(buft);

    GGML_UNUSED(dev);
}

static const struct ggml_backend_device_i ggml_backend_blas_device_i = {
    /* .get_name             = */ ggml_backend_blas_device_get_name,
    /* .get_description      = */ ggml_backend_blas_device_get_description,
    /* .get_memory           = */ ggml_backend_blas_device_get_memory,
    /* .get_type             = */ ggml_backend_blas_device_get_type,
    /* .get_props            = */ ggml_backend_blas_device_get_props,
    /* .init_backend         = */ ggml_backend_blas_device_init_backend,
    /* .get_buffer_type      = */ ggml_backend_blas_device_get_buffer_type,
    /* .get_host_buffer_type = */ NULL,
    /* .buffer_from_host_ptr = */ ggml_backend_blas_device_buffer_from_host_ptr,
    /* .supports_op          = */ ggml_backend_blas_device_supports_op,
    /* .supports_buft        = */ ggml_backend_blas_device_supports_buft,
    /* .offload_op           = */ NULL,
    /* .event_new            = */ NULL,
    /* .event_free           = */ NULL,
    /* .event_synchronize    = */ NULL,
};

// backend reg interface

static const char * ggml_backend_blas_reg_get_name(ggml_backend_reg_t reg) {
    return "BLAS";

    GGML_UNUSED(reg);
}

static size_t ggml_backend_blas_reg_get_device_count(ggml_backend_reg_t reg) {
    return 1;

    GGML_UNUSED(reg);
}

static ggml_backend_dev_t ggml_backend_blas_reg_get_device(ggml_backend_reg_t reg, size_t index) {
    GGML_ASSERT(index == 0);

    static ggml_backend_device ggml_backend_blas_device = {
        /* .iface   = */ ggml_backend_blas_device_i,
        /* .reg     = */ reg,
        /* .context = */ nullptr,
    };

    return &ggml_backend_blas_device;

    GGML_UNUSED(reg);
    GGML_UNUSED(index);
}

static void * ggml_backend_blas_get_proc_address(ggml_backend_reg_t reg, const char * name) {
    if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
        return (void *)ggml_backend_blas_set_n_threads;
    }
    return NULL;

    GGML_UNUSED(reg);
    GGML_UNUSED(name);
}

static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = {
    /* .get_name         = */ ggml_backend_blas_reg_get_name,
    /* .get_device_count = */ ggml_backend_blas_reg_get_device_count,
    /* .get_device       = */ ggml_backend_blas_reg_get_device,
    /* .get_proc_address = */ ggml_backend_blas_get_proc_address,
};

ggml_backend_reg_t ggml_backend_blas_reg(void) {
    static struct ggml_backend_reg ggml_backend_blas_reg = {
        /* .api_version = */ GGML_BACKEND_API_VERSION,
        /* .iface       = */ ggml_backend_blas_reg_i,
        /* .context     = */ NULL,
    };

    return &ggml_backend_blas_reg;
}

GGML_BACKEND_DL_IMPL(ggml_backend_blas_reg)

#endif // GGML_USE_BLAS