// utils.hpp
  1. // MIT License
  2. // Copyright (c) 2023 Georgi Gerganov
  3. // Permission is hereby granted, free of charge, to any person obtaining a copy
  4. // of this software and associated documentation files (the "Software"), to deal
  5. // in the Software without restriction, including without limitation the rights
  6. // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  7. // copies of the Software, and to permit persons to whom the Software is
  8. // furnished to do so, subject to the following conditions:
  9. // The above copyright notice and this permission notice shall be included in all
  10. // copies or substantial portions of the Software.
  11. // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  12. // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  13. // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  14. // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  15. // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  16. // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  17. // SOFTWARE.
  18. #pragma once
#include <algorithm>
#include <cassert>
#include <condition_variable>
#include <cstdint>
#include <cstring>
#include <ctime>
#include <functional>
#include <iostream>
#include <iterator>
#include <mutex>
#include <random>
#include <set>
#include <sstream>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>
#include "json.hpp"
#include "../llava/clip.h"
  30. using json = nlohmann::json;
// Global logging configuration flags, defined in the server translation unit.
extern bool server_verbose;  // when true, VERB (and DEBUG) messages are emitted
extern bool server_log_json; // when true, each log line is a single JSON object
// SERVER_VERBOSE can be defined to 0 at build time to compile LOG_VERBOSE out entirely.
#ifndef SERVER_VERBOSE
#define SERVER_VERBOSE 1
#endif
#if SERVER_VERBOSE != 1
// Verbose logging compiled out: calls expand to nothing.
#define LOG_VERBOSE(MSG, ...)
#else
// Verbose logging compiled in, but still gated at runtime by server_verbose.
#define LOG_VERBOSE(MSG, ...) \
do \
{ \
if (server_verbose) \
{ \
server_log("VERB", __func__, __LINE__, MSG, __VA_ARGS__); \
} \
} while (0)
#endif
// Unconditional log helpers; the trailing argument is a json object with extra fields.
#define LOG_ERROR( MSG, ...) server_log("ERROR", __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_DEBUG( MSG, ...) server_log("DEBUG", __func__, __LINE__, MSG, __VA_ARGS__)
// Lifecycle state of the server with respect to model loading.
enum server_state {
    SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
    SERVER_STATE_READY,         // Server is ready and model is loaded
    SERVER_STATE_ERROR          // An error occurred, load_model failed
};
// Kind of work item carried by a task_server (see llama_server_queue).
enum task_type {
    TASK_TYPE_COMPLETION,    // a completion / generation request
    TASK_TYPE_CANCEL,        // cancel another task — presumably the one in target_id; confirm with caller
    TASK_TYPE_NEXT_RESPONSE, // wake the loop to produce the next response chunk
    TASK_TYPE_METRICS        // request for server metrics
};
// A single unit of work posted to llama_server_queue.
struct task_server {
    int id = -1; // to be filled by llama_server_queue
    int target_id;           // NOTE(review): presumably the id of the task this one targets (e.g. for TASK_TYPE_CANCEL) — confirm with caller
    task_type type;          // what kind of work this task represents
    json data;               // request payload
    bool infill_mode = false;    // presumably marks an infill request — confirm with caller
    bool embedding_mode = false; // presumably marks an embedding request — confirm with caller
    int multitask_id = -1;   // id of the parent multitask, or -1 if standalone
};
  72. struct task_result {
  73. int id;
  74. int multitask_id = -1;
  75. bool stop;
  76. bool error;
  77. json result_json;
  78. };
// Book-keeping for a task that fans out into several subtasks.
struct task_multi {
    int id; // multitask id; set by add_multitask
    std::set<int> subtasks_remaining{};  // ids of subtasks that have not finished yet
    std::vector<task_result> results{};  // results collected from finished subtasks
};
// completion token output with probabilities
struct completion_token_output {
    // One candidate token together with its probability.
    struct token_prob
    {
        llama_token tok;
        float prob;
    };
    std::vector<token_prob> probs; // candidate tokens with their probabilities
    llama_token tok;               // the selected token
    std::string text_to_send;      // text piece of this token to send to the client
};
  95. struct token_translator {
  96. llama_context * ctx;
  97. std::string operator()(llama_token tok) const { return llama_token_to_piece(ctx, tok); }
  98. std::string operator()(const completion_token_output &cto) const { return (*this)(cto.tok); }
  99. };
  100. static inline void server_log(const char *level, const char *function, int line, const char *message, const nlohmann::ordered_json &extra) {
  101. std::stringstream ss_tid;
  102. ss_tid << std::this_thread::get_id();
  103. json log = nlohmann::ordered_json{
  104. {"tid", ss_tid.str()},
  105. {"timestamp", time(nullptr)},
  106. };
  107. if (strncmp("DEBUG", level, strlen(level)) == 0 && !server_verbose) {
  108. return;
  109. }
  110. if (server_log_json) {
  111. log.merge_patch(
  112. {
  113. {"level", level},
  114. {"function", function},
  115. {"line", line},
  116. {"msg", message},
  117. });
  118. if (!extra.empty()) {
  119. log.merge_patch(extra);
  120. }
  121. std::cout << log.dump(-1, ' ', false, json::error_handler_t::replace) << "\n" << std::flush;
  122. } else {
  123. if (!extra.empty()) {
  124. log.merge_patch(extra);
  125. }
  126. std::stringstream ss;
  127. ss << level << " [" << function << "] " << message << " |";
  128. for (const auto& el : log.items())
  129. {
  130. const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace);
  131. ss << " " << el.key() << "=" << value;
  132. }
  133. const std::string str = ss.str();
  134. printf("%.*s\n", (int)str.size(), str.data());
  135. fflush(stdout);
  136. }
  137. }
  138. //
  139. // server utils
  140. //
  141. template <typename T>
  142. static T json_value(const json &body, const std::string &key, const T &default_value) {
  143. // Fallback null to default value
  144. return body.contains(key) && !body.at(key).is_null()
  145. ? body.value(key, default_value)
  146. : default_value;
  147. }
  148. // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
  149. inline bool verify_custom_template(const std::string & tmpl) {
  150. llama_chat_message chat[] = {{"user", "test"}};
  151. std::vector<char> buf(1);
  152. int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, buf.data(), buf.size());
  153. return res >= 0;
  154. }
  155. // Format given chat. If tmpl is empty, we take the template from model metadata
  156. inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
  157. size_t alloc_size = 0;
  158. // vector holding all allocated string to be passed to llama_chat_apply_template
  159. std::vector<std::string> str(messages.size() * 2);
  160. std::vector<llama_chat_message> chat(messages.size());
  161. for (size_t i = 0; i < messages.size(); ++i) {
  162. auto &curr_msg = messages[i];
  163. str[i*2 + 0] = json_value(curr_msg, "role", std::string(""));
  164. str[i*2 + 1] = json_value(curr_msg, "content", std::string(""));
  165. alloc_size += str[i*2 + 1].length();
  166. chat[i].role = str[i*2 + 0].c_str();
  167. chat[i].content = str[i*2 + 1].c_str();
  168. }
  169. const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
  170. std::vector<char> buf(alloc_size * 2);
  171. // run the first time to get the total output length
  172. int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
  173. // if it turns out that our buffer is too small, we resize it
  174. if ((size_t) res > buf.size()) {
  175. buf.resize(res);
  176. res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
  177. }
  178. std::string formatted_chat(buf.data(), res);
  179. LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});
  180. return formatted_chat;
  181. }
  182. //
  183. // work queue utils
  184. //
  185. struct llama_server_queue {
  186. int id = 0;
  187. std::mutex mutex_tasks;
  188. bool running;
  189. // queues
  190. std::vector<task_server> queue_tasks;
  191. std::vector<task_server> queue_tasks_deferred;
  192. std::vector<task_multi> queue_multitasks;
  193. std::condition_variable condition_tasks;
  194. // callback functions
  195. std::function<void(task_server&)> callback_new_task;
  196. std::function<void(task_multi&)> callback_finish_multitask;
  197. std::function<void(void)> callback_run_slots;
  198. // Add a new task to the end of the queue
  199. int post(task_server task) {
  200. std::unique_lock<std::mutex> lock(mutex_tasks);
  201. if (task.id == -1) {
  202. task.id = id++;
  203. LOG_VERBOSE("new task id", {{"new_id", task.id}});
  204. }
  205. queue_tasks.push_back(std::move(task));
  206. condition_tasks.notify_one();
  207. return task.id;
  208. }
  209. // Add a new task, but defer until one slot is available
  210. void defer(task_server task) {
  211. std::unique_lock<std::mutex> lock(mutex_tasks);
  212. queue_tasks_deferred.push_back(std::move(task));
  213. }
  214. // Get the next id for creating anew task
  215. int get_new_id() {
  216. std::unique_lock<std::mutex> lock(mutex_tasks);
  217. int new_id = id++;
  218. LOG_VERBOSE("new task id", {{"new_id", new_id}});
  219. return new_id;
  220. }
  221. // Register function to process a new task
  222. void on_new_task(std::function<void(task_server&)> callback) {
  223. callback_new_task = callback;
  224. }
  225. // Register function to process a multitask when it is finished
  226. void on_finish_multitask(std::function<void(task_multi&)> callback) {
  227. callback_finish_multitask = callback;
  228. }
  229. // Register the function to be called when all slots data is ready to be processed
  230. void on_run_slots(std::function<void(void)> callback) {
  231. callback_run_slots = callback;
  232. }
  233. // Call when the state of one slot is changed
  234. void notify_slot_changed() {
  235. // move deferred tasks back to main loop
  236. std::unique_lock<std::mutex> lock(mutex_tasks);
  237. for (auto & task : queue_tasks_deferred) {
  238. queue_tasks.push_back(std::move(task));
  239. }
  240. queue_tasks_deferred.clear();
  241. }
  242. // end the start_loop routine
  243. void terminate() {
  244. {
  245. std::unique_lock<std::mutex> lock(mutex_tasks);
  246. running = false;
  247. }
  248. condition_tasks.notify_all();
  249. }
  250. /**
  251. * Main loop consists of these steps:
  252. * - Wait until a new task arrives
  253. * - Process the task (i.e. maybe copy data into slot)
  254. * - Check if multitask is finished
  255. * - Run all slots
  256. */
  257. void start_loop() {
  258. running = true;
  259. while (true) {
  260. LOG_VERBOSE("new task may arrive", {});
  261. {
  262. while (true)
  263. {
  264. std::unique_lock<std::mutex> lock(mutex_tasks);
  265. if (queue_tasks.empty()) {
  266. lock.unlock();
  267. break;
  268. }
  269. task_server task = queue_tasks.front();
  270. queue_tasks.erase(queue_tasks.begin());
  271. lock.unlock();
  272. LOG_VERBOSE("callback_new_task", {{"task_id", task.id}});
  273. callback_new_task(task);
  274. }
  275. LOG_VERBOSE("update_multitasks", {});
  276. // check if we have any finished multitasks
  277. auto queue_iterator = queue_multitasks.begin();
  278. while (queue_iterator != queue_multitasks.end())
  279. {
  280. if (queue_iterator->subtasks_remaining.empty())
  281. {
  282. // all subtasks done == multitask is done
  283. task_multi current_multitask = *queue_iterator;
  284. callback_finish_multitask(current_multitask);
  285. // remove this multitask
  286. queue_iterator = queue_multitasks.erase(queue_iterator);
  287. }
  288. else
  289. {
  290. ++queue_iterator;
  291. }
  292. }
  293. // all tasks in the current loop is processed, slots data is now ready
  294. LOG_VERBOSE("callback_run_slots", {});
  295. callback_run_slots();
  296. }
  297. LOG_VERBOSE("wait for new task", {});
  298. // wait for new task
  299. {
  300. std::unique_lock<std::mutex> lock(mutex_tasks);
  301. if (queue_tasks.empty()) {
  302. if (!running) {
  303. LOG_VERBOSE("ending start_loop", {});
  304. return;
  305. }
  306. condition_tasks.wait(lock, [&]{
  307. return (!queue_tasks.empty() || !running);
  308. });
  309. }
  310. }
  311. }
  312. }
  313. //
  314. // functions to manage multitasks
  315. //
  316. // add a multitask by specifying the id of all subtask (subtask is a task_server)
  317. void add_multitask(int multitask_id, std::vector<int>& sub_ids)
  318. {
  319. std::lock_guard<std::mutex> lock(mutex_tasks);
  320. task_multi multi;
  321. multi.id = multitask_id;
  322. std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
  323. queue_multitasks.push_back(multi);
  324. }
  325. // updatethe remaining subtasks, while appending results to multitask
  326. void update_multitask(int multitask_id, int subtask_id, task_result& result)
  327. {
  328. std::lock_guard<std::mutex> lock(mutex_tasks);
  329. for (auto& multitask : queue_multitasks)
  330. {
  331. if (multitask.id == multitask_id)
  332. {
  333. multitask.subtasks_remaining.erase(subtask_id);
  334. multitask.results.push_back(result);
  335. }
  336. }
  337. }
  338. };
  339. struct llama_server_response {
  340. typedef std::function<void(int, int, task_result&)> callback_multitask_t;
  341. callback_multitask_t callback_update_multitask;
  342. // for keeping track of all tasks waiting for the result
  343. std::set<int> waiting_task_ids;
  344. // the main result queue
  345. std::vector<task_result> queue_results;
  346. std::mutex mutex_results;
  347. std::condition_variable condition_results;
  348. // add the task_id to the list of tasks waiting for response
  349. void add_waiting_task_id(int task_id) {
  350. LOG_VERBOSE("waiting for task id", {{"task_id", task_id}});
  351. std::unique_lock<std::mutex> lock(mutex_results);
  352. waiting_task_ids.insert(task_id);
  353. }
  354. // when the request is finished, we can remove task associated with it
  355. void remove_waiting_task_id(int task_id) {
  356. LOG_VERBOSE("remove waiting for task id", {{"task_id", task_id}});
  357. std::unique_lock<std::mutex> lock(mutex_results);
  358. waiting_task_ids.erase(task_id);
  359. }
  360. // This function blocks the thread until there is a response for this task_id
  361. task_result recv(int task_id) {
  362. while (true)
  363. {
  364. std::unique_lock<std::mutex> lock(mutex_results);
  365. condition_results.wait(lock, [&]{
  366. return !queue_results.empty();
  367. });
  368. for (int i = 0; i < (int) queue_results.size(); i++)
  369. {
  370. if (queue_results[i].id == task_id)
  371. {
  372. assert(queue_results[i].multitask_id == -1);
  373. task_result res = queue_results[i];
  374. queue_results.erase(queue_results.begin() + i);
  375. return res;
  376. }
  377. }
  378. }
  379. // should never reach here
  380. }
  381. // Register the function to update multitask
  382. void on_multitask_update(callback_multitask_t callback) {
  383. callback_update_multitask = callback;
  384. }
  385. // Send a new result to a waiting task_id
  386. void send(task_result result) {
  387. std::unique_lock<std::mutex> lock(mutex_results);
  388. LOG_VERBOSE("send new result", {{"task_id", result.id}});
  389. for (auto& task_id : waiting_task_ids) {
  390. // LOG_TEE("waiting task id %i \n", task_id);
  391. // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
  392. if (result.multitask_id == task_id)
  393. {
  394. LOG_VERBOSE("callback_update_multitask", {{"task_id", task_id}});
  395. callback_update_multitask(task_id, result.id, result);
  396. continue;
  397. }
  398. if (result.id == task_id)
  399. {
  400. LOG_VERBOSE("queue_results.push_back", {{"task_id", task_id}});
  401. queue_results.push_back(result);
  402. condition_results.notify_all();
  403. return;
  404. }
  405. }
  406. }
  407. };
  408. //
  409. // base64 utils (TODO: move to common in the future)
  410. //
// Standard base64 alphabet (RFC 4648); index in this string == 6-bit value.
static const std::string base64_chars =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "abcdefghijklmnopqrstuvwxyz"
    "0123456789+/";
  415. static inline bool is_base64(uint8_t c)
  416. {
  417. return (isalnum(c) || (c == '+') || (c == '/'));
  418. }
  419. static inline std::vector<uint8_t> base64_decode(const std::string & encoded_string)
  420. {
  421. int i = 0;
  422. int j = 0;
  423. int in_ = 0;
  424. int in_len = encoded_string.size();
  425. uint8_t char_array_4[4];
  426. uint8_t char_array_3[3];
  427. std::vector<uint8_t> ret;
  428. while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_]))
  429. {
  430. char_array_4[i++] = encoded_string[in_]; in_++;
  431. if (i == 4)
  432. {
  433. for (i = 0; i <4; i++)
  434. {
  435. char_array_4[i] = base64_chars.find(char_array_4[i]);
  436. }
  437. char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4);
  438. char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
  439. char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
  440. for (i = 0; (i < 3); i++)
  441. {
  442. ret.push_back(char_array_3[i]);
  443. }
  444. i = 0;
  445. }
  446. }
  447. if (i)
  448. {
  449. for (j = i; j <4; j++)
  450. {
  451. char_array_4[j] = 0;
  452. }
  453. for (j = 0; j <4; j++)
  454. {
  455. char_array_4[j] = base64_chars.find(char_array_4[j]);
  456. }
  457. char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4);
  458. char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
  459. char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
  460. for (j = 0; (j < i - 1); j++)
  461. {
  462. ret.push_back(char_array_3[j]);
  463. }
  464. }
  465. return ret;
  466. }
  467. //
  468. // random string / id
  469. //
  470. static std::string random_string()
  471. {
  472. static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
  473. std::random_device rd;
  474. std::mt19937 generator(rd());
  475. std::string result(32, ' ');
  476. for (int i = 0; i < 32; ++i) {
  477. result[i] = str[generator() % str.size()];
  478. }
  479. return result;
  480. }
  481. static std::string gen_chatcmplid()
  482. {
  483. std::stringstream chatcmplid;
  484. chatcmplid << "chatcmpl-" << random_string();
  485. return chatcmplid.str();
  486. }
  487. //
  488. // other common utils
  489. //
  490. static size_t common_part(const std::vector<llama_token> &a, const std::vector<llama_token> &b)
  491. {
  492. size_t i;
  493. for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++)
  494. {
  495. }
  496. return i;
  497. }
  498. static bool ends_with(const std::string &str, const std::string &suffix)
  499. {
  500. return str.size() >= suffix.size() &&
  501. 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
  502. }
  503. static size_t find_partial_stop_string(const std::string &stop,
  504. const std::string &text)
  505. {
  506. if (!text.empty() && !stop.empty())
  507. {
  508. const char text_last_char = text.back();
  509. for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--)
  510. {
  511. if (stop[char_index] == text_last_char)
  512. {
  513. const std::string current_partial = stop.substr(0, char_index + 1);
  514. if (ends_with(text, current_partial))
  515. {
  516. return text.size() - char_index - 1;
  517. }
  518. }
  519. }
  520. }
  521. return std::string::npos;
  522. }
  523. // TODO: reuse llama_detokenize
  524. template <class Iter>
  525. static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end)
  526. {
  527. std::string ret;
  528. for (; begin != end; ++begin)
  529. {
  530. ret += llama_token_to_piece(ctx, *begin);
  531. }
  532. return ret;
  533. }
  534. // format incomplete utf-8 multibyte character for output
  535. static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token)
  536. {
  537. std::string out = token == -1 ? "" : llama_token_to_piece(ctx, token);
  538. // if the size is 1 and first bit is 1, meaning it's a partial character
  539. // (size > 1 meaning it's already a known token)
  540. if (out.size() == 1 && (out[0] & 0x80) == 0x80)
  541. {
  542. std::stringstream ss;
  543. ss << std::hex << (out[0] & 0xff);
  544. std::string res(ss.str());
  545. out = "byte: \\x" + res;
  546. }
  547. return out;
  548. }
  549. // convert a vector of completion_token_output to json
  550. static json probs_vector_to_json(const llama_context *ctx, const std::vector<completion_token_output> &probs)
  551. {
  552. json out = json::array();
  553. for (const auto &prob : probs)
  554. {
  555. json probs_for_token = json::array();
  556. for (const auto &p : prob.probs)
  557. {
  558. std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
  559. probs_for_token.push_back(json
  560. {
  561. {"tok_str", tok_str},
  562. {"prob", p.prob},
  563. });
  564. }
  565. std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
  566. out.push_back(json{
  567. {"content", tok_str},
  568. {"probs", probs_for_token},
  569. });
  570. }
  571. return out;
  572. }