
Merge pull request #2422 from dhiltgen/better_kill

More robust shutdown
Daniel Hiltgen 1 year ago
parent commit 939c60473f
2 changed files with 44 additions and 1 deletion
  1. llm/dyn_ext_server.go (+1 -1)
  2. llm/ext_server/ext_server.cpp (+43 -0)

llm/dyn_ext_server.go (+1 -1)

@@ -258,7 +258,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
 					})
 				}
 
-				if p.Stop {
+				if p.Stop || bool(result.stop) {
 					fn(PredictResult{
 						Done:               true,
 						PromptEvalCount:    p.Timings.PromptN,

llm/ext_server/ext_server.cpp (+43 -0)

@@ -1,4 +1,5 @@
 #include "ext_server.h"
 #include "ext_server.h"
+#include <atomic>
 
 // Necessary evil since the server types are not defined in a header
 #include "server.cpp"
@@ -27,8 +28,24 @@
 // Expose the llama server as a callable extern "C" API
 llama_server_context *llama = NULL;
 std::thread ext_server_thread;
+bool shutting_down = false;
+std::atomic_int recv_counter;
 
+// RAII wrapper for tracking in-flight recv calls
+class atomicRecv {
+  public:
+    atomicRecv(std::atomic<int> &atomic) : atomic(atomic) {
+      ++this->atomic;
+    }
+    ~atomicRecv() {
+      --this->atomic;
+    }
+  private:
+    std::atomic<int> &atomic;
+};
+ 
 void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err) {
+  recv_counter = 0;
   assert(err != NULL && sparams != NULL);
   log_set_target(stderr);
   if (!sparams->verbose_logging) {
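
The new atomicRecv helper is an RAII guard: it increments recv_counter when a call enters the blocking queue_results.recv() path and decrements it when the scope exits, including via an exception. A minimal sketch of the same pattern, using illustrative names rather than the actual server types:

#include <atomic>
#include <cassert>

std::atomic_int in_flight{0};

// RAII guard: the counter is decremented on every exit path,
// including when the guarded call throws.
class ScopedCounter {
 public:
  explicit ScopedCounter(std::atomic_int &c) : counter(c) { ++counter; }
  ~ScopedCounter() { --counter; }
 private:
  std::atomic_int &counter;
};

void wait_for_result() {
  ScopedCounter guard(in_flight);  // counted while blocked below
  // a blocking recv() would sit here and may throw
}

int main() {
  wait_for_result();
  assert(in_flight.load() == 0);   // guard released on scope exit
  return 0;
}
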
@@ -151,7 +168,14 @@ void llama_server_start() {
 
 void llama_server_stop() {
   assert(llama != NULL);
+  // Shutdown any in-flight requests and block incoming requests.
   LOG_TEE("\ninitiating shutdown - draining remaining tasks...\n");
   LOG_TEE("\ninitiating shutdown - draining remaining tasks...\n");
+  shutting_down = true;
+
+  while (recv_counter.load() > 0) {
+    std::this_thread::sleep_for(std::chrono::milliseconds(50));
+  }
+
   // This may take a while for any pending tasks to drain
   // TODO - consider a timeout to cancel tasks if it's taking too long
   llama->queue_tasks.terminate();
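
llama_server_stop() now refuses new work and drains in-flight receivers before tearing down the queues: it sets shutting_down, then polls recv_counter until every blocked recv() has returned. A self-contained sketch of that drain sequence, where the worker thread and timings are placeholders rather than the real task queue:

#include <atomic>
#include <chrono>
#include <cstdio>
#include <thread>

std::atomic_bool shutting_down{false};
std::atomic_int in_flight{0};

// Stand-in for a caller blocked in queue_results.recv().
void receiver() {
  ++in_flight;
  std::this_thread::sleep_for(std::chrono::milliseconds(200));
  --in_flight;
}

int main() {
  std::thread t(receiver);
  while (in_flight.load() == 0) {      // wait until a request is in flight
    std::this_thread::sleep_for(std::chrono::milliseconds(1));
  }

  shutting_down = true;                // new requests now fail fast
  while (in_flight.load() > 0) {       // drain blocked receivers
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
  }
  std::printf("drained; safe to terminate the task queues\n");
  t.join();
  return 0;
}
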
@@ -166,6 +190,9 @@ void llama_server_completion(const char *json_req, ext_server_resp_t *resp) {
   resp->id = -1;
   resp->msg[0] = '\0';
   try {
+    if (shutting_down) {
+      throw std::runtime_error("server shutting down");
+    }
     json data = json::parse(json_req);
     resp->id = llama->queue_tasks.get_new_id();
     llama->queue_results.add_waiting_task_id(resp->id);
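
Each request entry point gets the same guard: once shutting_down is set, a new completion (and, in the later hunks, tokenize, detokenize, and embedding) throws immediately, and the existing catch block turns the exception into an error response instead of enqueuing work. A rough sketch of that guard-and-catch shape, with a simplified response type standing in for ext_server_resp_t:

#include <cstdio>
#include <stdexcept>
#include <string>

bool shutting_down = false;  // set by the stop path

struct Response {
  bool error = false;
  std::string msg;
};

Response handle_request(const std::string &json_req) {
  Response resp;
  try {
    if (shutting_down) {
      throw std::runtime_error("server shutting down");
    }
    (void)json_req;  // parsing and enqueuing elided in this sketch
  } catch (const std::exception &e) {
    resp.error = true;  // the caller sees an error instead of a hung request
    resp.msg = e.what();
  }
  return resp;
}

int main() {
  shutting_down = true;
  Response r = handle_request("{}");
  std::printf("error=%d msg=%s\n", r.error ? 1 : 0, r.msg.c_str());
  return 0;
}
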
@@ -187,6 +214,7 @@ void llama_server_completion_next_result(const int task_id,
   resp->json_resp = NULL;
   std::string result_json;
   try {
+    atomicRecv ar(recv_counter);
     task_result result = llama->queue_results.recv(task_id);
     result_json =
         result.result_json.dump(-1, ' ', false, json::error_handler_t::replace);
@@ -203,6 +231,11 @@ void llama_server_completion_next_result(const int task_id,
       llama->request_cancel(task_id);
       LOG_TEE("next result removing waiting task ID: %d\n", task_id);
       llama->queue_results.remove_waiting_task_id(task_id);
+    } else if (shutting_down) {
+      LOG_TEE("aborting completion due to shutdown %d\n", task_id);
+      llama->request_cancel(task_id);
+      llama->queue_results.remove_waiting_task_id(task_id);
+      resp->stop = true;
     }
   } catch (std::exception &e) {
     resp->error = true;
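
When shutdown begins while a completion is still streaming, this branch cancels the task, removes its waiting ID, and reports resp->stop = true; the Go change in the first hunk then treats that the same as a normal end of generation, so the streaming loop exits instead of waiting for more tokens. A toy model of that caller loop (next_result here is a stand-in for llama_server_completion_next_result, not the real API):

#include <cstdio>

struct Result {
  bool error = false;
  bool stop = false;
};

bool shutting_down = false;

// Stand-in for the next-result call: once shutdown starts,
// it reports stop so the caller's loop can finish cleanly.
Result next_result(int /*task_id*/) {
  Result r;
  if (shutting_down) {
    r.stop = true;
  }
  return r;
}

int main() {
  shutting_down = true;  // as if llama_server_stop() had just run
  while (true) {
    Result r = next_result(42);
    if (r.error || r.stop) {  // mirrors `if p.Stop || bool(result.stop)` on the Go side
      std::printf("stream ended (stop=%d, error=%d)\n", r.stop, r.error);
      break;
    }
  }
  return 0;
}
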
@@ -251,6 +284,9 @@ void llama_server_tokenize(const char *json_req, char **json_resp,
   err->id = 0;
   err->msg[0] = '\0';
   try {
+    if (shutting_down) {
+      throw std::runtime_error("server shutting down");
+    }
     const json body = json::parse(json_req);
     std::vector<llama_token> tokens;
     if (body.count("content") != 0) {
@@ -284,6 +320,9 @@ void llama_server_detokenize(const char *json_req, char **json_resp,
   err->id = 0;
   err->msg[0] = '\0';
   try {
+    if (shutting_down) {
+      throw std::runtime_error("server shutting down");
+    }
     const json body = json::parse(json_req);
     std::string content;
     if (body.count("tokens") != 0) {
@@ -311,6 +350,9 @@ void llama_server_embedding(const char *json_req, char **json_resp,
   err->id = 0;
   err->msg[0] = '\0';
   try {
+    if (shutting_down) {
+      throw std::runtime_error("server shutting down");
+    }
     const json body = json::parse(json_req);
     json prompt;
     if (body.count("content") != 0) {
@@ -321,6 +363,7 @@ void llama_server_embedding(const char *json_req, char **json_resp,
     const int task_id = llama->queue_tasks.get_new_id();
     llama->queue_results.add_waiting_task_id(task_id);
     llama->request_completion(task_id, {{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
+    atomicRecv ar(recv_counter);
     task_result result = llama->queue_results.recv(task_id);
     std::string result_json = result.result_json.dump();
     const std::string::size_type size = result_json.size() + 1;