ggml-mpi.c 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. //go:build mpi
  2. /**
  3. * llama.cpp - git 3ebb00935f3f0522b75df49c2769ab1774b91380
  4. *
  5. * MIT License
  6. *
  7. * Copyright (c) 2023 Georgi Gerganov
  8. *
  9. * Permission is hereby granted, free of charge, to any person obtaining a copy
  10. * of this software and associated documentation files (the "Software"), to deal
  11. * in the Software without restriction, including without limitation the rights
  12. * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  13. * copies of the Software, and to permit persons to whom the Software is
  14. * furnished to do so, subject to the following conditions:
  15. *
  16. * The above copyright notice and this permission notice shall be included in all
  17. * copies or substantial portions of the Software.
  18. *
  19. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  20. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  21. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  22. * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  23. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  24. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  25. * SOFTWARE.
  26. */
#include "ggml-mpi.h"
#include "ggml.h"

#include <limits.h>
#include <mpi.h>
#include <stdio.h>
#include <stdlib.h>
  32. #define MIN(a, b) ((a) < (b) ? (a) : (b))
  33. #define UNUSED GGML_UNUSED
  34. struct ggml_mpi_context {
  35. int rank;
  36. int size;
  37. };
  38. void ggml_mpi_backend_init(void) {
  39. MPI_Init(NULL, NULL);
  40. }
  41. void ggml_mpi_backend_free(void) {
  42. MPI_Finalize();
  43. }
  44. struct ggml_mpi_context * ggml_mpi_init(void) {
  45. struct ggml_mpi_context * ctx = calloc(1, sizeof(struct ggml_mpi_context));
  46. MPI_Comm_rank(MPI_COMM_WORLD, &ctx->rank);
  47. MPI_Comm_size(MPI_COMM_WORLD, &ctx->size);
  48. return ctx;
  49. }
  50. void ggml_mpi_free(struct ggml_mpi_context * ctx) {
  51. free(ctx);
  52. }
  53. int ggml_mpi_rank(struct ggml_mpi_context * ctx) {
  54. return ctx->rank;
  55. }
  56. void ggml_mpi_eval_init(
  57. struct ggml_mpi_context * ctx_mpi,
  58. int * n_tokens,
  59. int * n_past,
  60. int * n_threads) {
  61. UNUSED(ctx_mpi);
  62. // synchronize the worker node parameters with the root node
  63. MPI_Barrier(MPI_COMM_WORLD);
  64. MPI_Bcast(n_tokens, 1, MPI_INT, 0, MPI_COMM_WORLD);
  65. MPI_Bcast(n_past, 1, MPI_INT, 0, MPI_COMM_WORLD);
  66. MPI_Bcast(n_threads, 1, MPI_INT, 0, MPI_COMM_WORLD);
  67. }
  68. static int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) {
  69. struct ggml_tensor * t = ggml_graph_get_tensor(gf, name);
  70. if (t == NULL) {
  71. fprintf(stderr, "%s: tensor %s not found\n", __func__, name);
  72. return -1;
  73. }
  74. for (int i = 0; i < gf->n_nodes; i++) {
  75. if (gf->nodes[i] == t) {
  76. return i;
  77. }
  78. }
  79. fprintf(stderr, "%s: tensor %s not found in graph (should not happen)\n", __func__, name);
  80. return -1;
  81. }
  82. static void ggml_mpi_tensor_send(struct ggml_tensor * t, int mpi_rank_dst) {
  83. MPI_Datatype mpi_type;
  84. switch (t->type) {
  85. case GGML_TYPE_I32: mpi_type = MPI_INT32_T; break;
  86. case GGML_TYPE_F32: mpi_type = MPI_FLOAT; break;
  87. default: GGML_ASSERT(false && "not implemented");
  88. }
  89. const int retval = MPI_Send(t->data, ggml_nelements(t), mpi_type, mpi_rank_dst, 0, MPI_COMM_WORLD);
  90. GGML_ASSERT(retval == MPI_SUCCESS);
  91. }
  92. static void ggml_mpi_tensor_recv(struct ggml_tensor * t, int mpi_rank_src) {
  93. MPI_Datatype mpi_type;
  94. switch (t->type) {
  95. case GGML_TYPE_I32: mpi_type = MPI_INT32_T; break;
  96. case GGML_TYPE_F32: mpi_type = MPI_FLOAT; break;
  97. default: GGML_ASSERT(false && "not implemented");
  98. }
  99. MPI_Status status; UNUSED(status);
  100. const int retval = MPI_Recv(t->data, ggml_nelements(t), mpi_type, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
  101. GGML_ASSERT(retval == MPI_SUCCESS);
  102. }
  103. // TODO: there are many improvements that can be done to this implementation
  104. void ggml_mpi_graph_compute_pre(
  105. struct ggml_mpi_context * ctx_mpi,
  106. struct ggml_cgraph * gf,
  107. int n_layers) {
  108. const int mpi_rank = ctx_mpi->rank;
  109. const int mpi_size = ctx_mpi->size;
  110. struct ggml_tensor * inp_tokens = ggml_graph_get_tensor(gf, "inp_tokens");
  111. if (inp_tokens == NULL) {
  112. fprintf(stderr, "%s: tensor 'inp_tokens' not found\n", __func__);
  113. return;
  114. }
  115. struct ggml_tensor * inp0 = ggml_graph_get_tensor(gf, "layer_inp_0");
  116. if (inp0 == NULL) {
  117. fprintf(stderr, "%s: tensor 'inp0' not found\n", __func__);
  118. return;
  119. }
  120. GGML_ASSERT(inp0 == gf->nodes[0]);
  121. // distribute the compute graph into slices across the MPI nodes
  122. //
  123. // the main node (0) processes the last layers + the remainder of the compute graph
  124. // and is responsible to pass the input tokens to the first node (1)
  125. //
  126. // node 1: [( 0) * n_per_node, ( 1) * n_per_node)
  127. // node 2: [( 1) * n_per_node, ( 2) * n_per_node)
  128. // ...
  129. // node n-1: [(n-2) * n_per_node, (n-1) * n_per_node)
  130. // node 0: [(n-1) * n_per_node, n_nodes)
  131. //
  132. if (mpi_rank > 0) {
  133. if (mpi_rank == 1) {
  134. // the first node (1) receives the input tokens from the main node (0)
  135. ggml_mpi_tensor_recv(inp_tokens, 0);
  136. } else {
  137. // recv input data for each node into the "inp0" tensor (i.e. the first node in the compute graph)
  138. ggml_mpi_tensor_recv(inp0, mpi_rank - 1);
  139. }
  140. } else if (mpi_size > 1) {
  141. // node 0 sends the input tokens to node 1
  142. ggml_mpi_tensor_send(inp_tokens, 1);
  143. // recv the output data from the last node
  144. ggml_mpi_tensor_recv(inp0, mpi_size - 1);
  145. }
  146. {
  147. const int n_per_node = (n_layers + (mpi_size - 1)) / mpi_size;
  148. const int mpi_idx = mpi_rank > 0 ? mpi_rank - 1 : mpi_size - 1;
  149. const int il0 = (mpi_idx + 0) * n_per_node;
  150. const int il1 = MIN(n_layers, (mpi_idx + 1) * n_per_node);
  151. char name_l0[GGML_MAX_NAME];
  152. char name_l1[GGML_MAX_NAME];
  153. snprintf(name_l0, sizeof(name_l0), "layer_inp_%d", il0);
  154. snprintf(name_l1, sizeof(name_l1), "layer_inp_%d", il1);
  155. const int idx_l0 = ggml_graph_get_node_idx(gf, name_l0);
  156. const int idx_l1 = mpi_rank > 0 ? ggml_graph_get_node_idx(gf, name_l1) + 1 : gf->n_nodes;
  157. if (idx_l0 < 0 || idx_l1 < 0) {
  158. fprintf(stderr, "%s: layer input nodes not found\n", __func__);
  159. return;
  160. }
  161. // attach the input data to all nodes that need it
  162. // TODO: not great - should be able to do this without modifying the compute graph (see next TODO below)
  163. for (int i = idx_l0; i < idx_l1; i++) {
  164. if (gf->nodes[i]->src[0] == gf->nodes[idx_l0]) {
  165. gf->nodes[i]->src[0] = inp0;
  166. }
  167. if (gf->nodes[i]->src[1] == gf->nodes[idx_l0]) {
  168. gf->nodes[i]->src[1] = inp0;
  169. }
  170. }
  171. // TODO: instead of rearranging the nodes, we should be able to execute a subset of the compute graph
  172. for (int i = 1; i < idx_l1 - idx_l0; i++) {
  173. gf->nodes[i] = gf->nodes[idx_l0 + i];
  174. gf->grads[i] = gf->grads[idx_l0 + i];
  175. }
  176. // the first node performs the "get_rows" operation, the rest of the nodes get the data from the previous node
  177. if (mpi_idx != 0) {
  178. gf->nodes[0]->op = GGML_OP_NONE;
  179. }
  180. gf->n_nodes = idx_l1 - idx_l0;
  181. //fprintf(stderr, "%s: node %d: processing %d nodes [%d, %d)\n", __func__, mpi_rank, gf->n_nodes, il0, il1);
  182. }
  183. }
  184. void ggml_mpi_graph_compute_post(
  185. struct ggml_mpi_context * ctx_mpi,
  186. struct ggml_cgraph * gf,
  187. int n_layers) {
  188. UNUSED(n_layers);
  189. const int mpi_rank = ctx_mpi->rank;
  190. const int mpi_size = ctx_mpi->size;
  191. // send the output data to the next node
  192. if (mpi_rank > 0) {
  193. ggml_mpi_tensor_send(gf->nodes[gf->n_nodes - 1], (mpi_rank + 1) % mpi_size);
  194. }
  195. }