02-shutdown.diff (3.0 KB)
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index a0b46970..7800c6e7 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -28,6 +28,7 @@
 #include <chrono>
 #include <condition_variable>
 #include <atomic>
+#include <signal.h>
 
 using json = nlohmann::json;
 
@@ -2511,6 +2512,9 @@ static void append_to_generated_text_from_generated_token_probs(llama_server_con
     }
 }
 
+std::function<void(int)> shutdown_handler;
+inline void signal_handler(int signal) { shutdown_handler(signal); }
+
 int main(int argc, char **argv)
 {
 #if SERVER_VERBOSE != 1
@@ -3128,8 +3132,25 @@ int main(int argc, char **argv)
             std::placeholders::_2,
             std::placeholders::_3
     ));
-    llama.queue_tasks.start_loop();
+    shutdown_handler = [&](int) {
+        llama.queue_tasks.terminate();
+    };
+
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+    struct sigaction sigint_action;
+    sigint_action.sa_handler = signal_handler;
+    sigemptyset (&sigint_action.sa_mask);
+    sigint_action.sa_flags = 0;
+    sigaction(SIGINT, &sigint_action, NULL);
+#elif defined (_WIN32)
+    auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
+        return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
+    };
+    SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
+#endif
+    llama.queue_tasks.start_loop();
+    svr.stop();
     t.join();
 
     llama_backend_free();
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 54854896..0ee670db 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -220,6 +220,7 @@ inline std::string format_chatml(std::vector<json> messages)
 struct llama_server_queue {
     int id = 0;
     std::mutex mutex_tasks;
+    bool running;
     // queues
     std::vector<task_server> queue_tasks;
     std::vector<task_server> queue_tasks_deferred;
@@ -278,9 +279,18 @@ struct llama_server_queue {
         queue_tasks_deferred.clear();
     }
 
-    // Start the main loop. This call is blocking
-    [[noreturn]]
+    // end the start_loop routine
+    void terminate() {
+        {
+            std::unique_lock<std::mutex> lock(mutex_tasks);
+            running = false;
+        }
+        condition_tasks.notify_all();
+    }
+
+    // Start the main loop.
     void start_loop() {
+        running = true;
         while (true) {
             // new task arrived
             LOG_VERBOSE("have new task", {});
@@ -324,8 +334,12 @@ struct llama_server_queue {
         {
             std::unique_lock<std::mutex> lock(mutex_tasks);
             if (queue_tasks.empty()) {
+                if (!running) {
+                    LOG_VERBOSE("ending start_loop", {});
+                    return;
+                }
                 condition_tasks.wait(lock, [&]{
-                    return !queue_tasks.empty();
+                    return (!queue_tasks.empty() || !running);
                 });
             }
         }