@@ -1,4 +1,5 @@
 #include "ext_server.h"
+#include <atomic>
 
 // Necessary evil since the server types are not defined in a header
 #include "server.cpp"
@@ -27,8 +28,24 @@
 // Expose the llama server as a callable extern "C" API
 llama_server_context *llama = NULL;
 std::thread ext_server_thread;
+bool shutting_down = false;
+std::atomic_int recv_counter;
 
+// RAII wrapper for tracking in-flight recv calls
+class atomicRecv {
+ public:
+  atomicRecv(std::atomic<int> &atomic) : atomic(atomic) {
+    ++this->atomic;
+  }
+  ~atomicRecv() {
+    --this->atomic;
+  }
+ private:
+  std::atomic<int> &atomic;
+};
+
 void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err) {
+  recv_counter = 0;
   assert(err != NULL && sparams != NULL);
   log_set_target(stderr);
   if (!sparams->verbose_logging) {
@@ -151,7 +168,14 @@ void llama_server_start() {
 
 void llama_server_stop() {
   assert(llama != NULL);
+  // Shutdown any in-flight requests and block incoming requests.
   LOG_TEE("\ninitiating shutdown - draining remaining tasks...\n");
+  shutting_down = true;
+
+  while (recv_counter.load() > 0) {
+    std::this_thread::sleep_for(std::chrono::milliseconds(50));
+  }
+
   // This may take a while for any pending tasks to drain
   // TODO - consider a timeout to cancel tasks if it's taking too long
   llama->queue_tasks.terminate();
@@ -166,6 +190,9 @@ void llama_server_completion(const char *json_req, ext_server_resp_t *resp) {
   resp->id = -1;
   resp->msg[0] = '\0';
   try {
+    if (shutting_down) {
+      throw std::runtime_error("server shutting down");
+    }
     json data = json::parse(json_req);
     resp->id = llama->queue_tasks.get_new_id();
     llama->queue_results.add_waiting_task_id(resp->id);
@@ -187,6 +214,7 @@ void llama_server_completion_next_result(const int task_id,
   resp->json_resp = NULL;
   std::string result_json;
   try {
+    atomicRecv ar(recv_counter);
     task_result result = llama->queue_results.recv(task_id);
     result_json =
         result.result_json.dump(-1, ' ', false, json::error_handler_t::replace);
@@ -203,6 +231,11 @@ void llama_server_completion_next_result(const int task_id,
       llama->request_cancel(task_id);
       LOG_TEE("next result removing waiting task ID: %d\n", task_id);
       llama->queue_results.remove_waiting_task_id(task_id);
+    } else if (shutting_down) {
+      LOG_TEE("aborting completion due to shutdown %d\n", task_id);
+      llama->request_cancel(task_id);
+      llama->queue_results.remove_waiting_task_id(task_id);
+      resp->stop = true;
     }
   } catch (std::exception &e) {
     resp->error = true;
@@ -251,6 +284,9 @@ void llama_server_tokenize(const char *json_req, char **json_resp,
   err->id = 0;
   err->msg[0] = '\0';
   try {
+    if (shutting_down) {
+      throw std::runtime_error("server shutting down");
+    }
     const json body = json::parse(json_req);
     std::vector<llama_token> tokens;
     if (body.count("content") != 0) {
@@ -284,6 +320,9 @@ void llama_server_detokenize(const char *json_req, char **json_resp,
   err->id = 0;
   err->msg[0] = '\0';
   try {
+    if (shutting_down) {
+      throw std::runtime_error("server shutting down");
+    }
     const json body = json::parse(json_req);
     std::string content;
     if (body.count("tokens") != 0) {
@@ -311,6 +350,9 @@ void llama_server_embedding(const char *json_req, char **json_resp,
   err->id = 0;
   err->msg[0] = '\0';
   try {
+    if (shutting_down) {
+      throw std::runtime_error("server shutting down");
+    }
     const json body = json::parse(json_req);
     json prompt;
     if (body.count("content") != 0) {
@@ -321,6 +363,7 @@ void llama_server_embedding(const char *json_req, char **json_resp,
     const int task_id = llama->queue_tasks.get_new_id();
     llama->queue_results.add_waiting_task_id(task_id);
     llama->request_completion(task_id, {{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
+    atomicRecv ar(recv_counter);
     task_result result = llama->queue_results.recv(task_id);
     std::string result_json = result.result_json.dump();
     const std::string::size_type size = result_json.size() + 1;