rocm_shim.c 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. #include "rocm_shim.h"
  2. #include <stdio.h>
  3. #include <string.h>
  4. #ifdef __linux__
  5. #include <dlfcn.h>
  6. #define LOAD_LIBRARY(lib, flags) dlopen(lib, flags | RTLD_DEEPBIND)
  7. #define LOAD_SYMBOL(handle, sym) dlsym(handle, sym)
  8. #define LOAD_ERR() dlerror()
  9. #define UNLOAD_LIBRARY(handle) dlclose(handle)
  10. #elif _WIN32
  11. #include <windows.h>
  12. #define LOAD_LIBRARY(lib, flags) LoadLibrary(lib)
  13. #define LOAD_SYMBOL(handle, sym) GetProcAddress(handle, sym)
  14. #define UNLOAD_LIBRARY(handle) FreeLibrary(handle)
  15. // TODO - refactor this with proper error message handling on windows
  16. inline static char *LOAD_ERR() {
  17. static char errbuf[8];
  18. snprintf(errbuf, 8, "0x%lx", GetLastError());
  19. return errbuf;
  20. }
  21. #else
  22. #include <dlfcn.h>
  23. #define LOAD_LIBRARY(lib, flags) dlopen(lib, flags)
  24. #define LOAD_SYMBOL(handle, sym) dlsym(handle, sym)
  25. #define LOAD_ERR() dlerror()
  26. #define UNLOAD_LIBRARY(handle) dlclose(handle)
  27. #endif
  28. void rocm_shim_init(const char *libPath, struct rocm_llama_server *s,
  29. ext_server_resp_t *err) {
  30. int i = 0;
  31. struct lookup {
  32. char *s;
  33. void **p;
  34. } l[] = {
  35. {"llama_server_init", (void *)&s->llama_server_init},
  36. {"llama_server_start", (void *)&s->llama_server_start},
  37. {"llama_server_stop", (void *)&s->llama_server_stop},
  38. {"llama_server_completion", (void *)&s->llama_server_completion},
  39. {"llama_server_completion_next_result",
  40. (void *)&s->llama_server_completion_next_result},
  41. {"llama_server_completion_cancel",
  42. (void *)&s->llama_server_completion_cancel},
  43. {"llama_server_release_task_result",
  44. (void *)&s->llama_server_release_task_result},
  45. {"llama_server_tokenize", (void *)&s->llama_server_tokenize},
  46. {"llama_server_detokenize", (void *)&s->llama_server_detokenize},
  47. {"llama_server_embedding", (void *)&s->llama_server_embedding},
  48. {"llama_server_release_json_resp",
  49. (void *)&s->llama_server_release_json_resp},
  50. {"", NULL},
  51. };
  52. printf("Lazy loading %s library\n", libPath);
  53. s->handle = LOAD_LIBRARY(libPath, RTLD_NOW);
  54. if (!s->handle) {
  55. err->id = -1;
  56. snprintf(
  57. err->msg, err->msg_len,
  58. "Unable to load rocm server library: %s (If you have a Radeon card, "
  59. "did you install the ROCM libraries?)",
  60. LOAD_ERR());
  61. return;
  62. }
  63. for (i = 0; l[i].p != NULL; i++) {
  64. *l[i].p = LOAD_SYMBOL(s->handle, l[i].s);
  65. if (!l[i].p) {
  66. UNLOAD_LIBRARY(s->handle);
  67. err->id = -1;
  68. snprintf(err->msg, err->msg_len, "symbol lookup for %s failed: %s",
  69. l[i].s, LOAD_ERR());
  70. return;
  71. }
  72. }
  73. }
  74. inline void rocm_shim_llama_server_init(struct rocm_llama_server s,
  75. ext_server_params_t *sparams,
  76. ext_server_resp_t *err) {
  77. s.llama_server_init(sparams, err);
  78. }
  79. inline void rocm_shim_llama_server_start(struct rocm_llama_server s) {
  80. s.llama_server_start();
  81. }
  82. inline void rocm_shim_llama_server_stop(struct rocm_llama_server s) {
  83. s.llama_server_stop();
  84. }
  85. inline void rocm_shim_llama_server_completion(struct rocm_llama_server s,
  86. const char *json_req,
  87. ext_server_resp_t *resp) {
  88. s.llama_server_completion(json_req, resp);
  89. }
  90. inline void rocm_shim_llama_server_completion_next_result(
  91. struct rocm_llama_server s, const int task_id,
  92. ext_server_task_result_t *result) {
  93. s.llama_server_completion_next_result(task_id, result);
  94. }
  95. inline void rocm_shim_llama_server_completion_cancel(struct rocm_llama_server s,
  96. const int task_id,
  97. ext_server_resp_t *err) {
  98. s.llama_server_completion_cancel(task_id, err);
  99. }
  100. inline void rocm_shim_llama_server_release_task_result(
  101. struct rocm_llama_server s, ext_server_task_result_t *result) {
  102. s.llama_server_release_task_result(result);
  103. }
  104. inline void rocm_shim_llama_server_tokenize(struct rocm_llama_server s,
  105. const char *json_req,
  106. char **json_resp,
  107. ext_server_resp_t *err) {
  108. s.llama_server_tokenize(json_req, json_resp, err);
  109. }
  110. inline void rocm_shim_llama_server_detokenize(struct rocm_llama_server s,
  111. const char *json_req,
  112. char **json_resp,
  113. ext_server_resp_t *err) {
  114. s.llama_server_detokenize(json_req, json_resp, err);
  115. }
  116. inline void rocm_shim_llama_server_embedding(struct rocm_llama_server s,
  117. const char *json_req,
  118. char **json_resp,
  119. ext_server_resp_t *err) {
  120. s.llama_server_embedding(json_req, json_resp, err);
  121. }
  122. inline void rocm_shim_llama_server_release_json_resp(struct rocm_llama_server s,
  123. char **json_resp) {
  124. s.llama_server_release_json_resp(json_resp);
  125. }