unicode.cpp 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865
  1. #if defined(_MSC_VER)
  2. #define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
  3. #endif
  4. #if defined(_WIN32)
  5. #define WIN32_LEAN_AND_MEAN
  6. #include <windows.h>
  7. #endif
  8. #include "unicode.h"
  9. #include "unicode-data.h"
  10. #include <algorithm>
  11. #include <cassert>
  12. #include <cstddef>
  13. #include <cstdint>
  14. #include <map>
  15. #include <regex>
  16. #include <stdexcept>
  17. #include <string>
  18. #include <unordered_map>
  19. #include <unordered_set>
  20. #include <utility>
  21. #include <vector>
  22. #include <locale>
  23. #include <codecvt>
  24. size_t unicode_len_utf8(char src) {
  25. const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
  26. uint8_t highbits = static_cast<uint8_t>(src) >> 4;
  27. return lookup[highbits];
  28. }
  29. static std::string unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
  30. std::string result;
  31. for (size_t i = 0; i < cps.size(); ++i) {
  32. result.append(unicode_cpt_to_utf8(cps[i]));
  33. }
  34. return result;
  35. }
  36. uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
  37. assert(offset < utf8.size());
  38. if (!(utf8[offset + 0] & 0x80)) {
  39. auto result = utf8[offset + 0];
  40. offset += 1;
  41. return result;
  42. }
  43. if (!(utf8[offset + 0] & 0x40)) {
  44. throw std::invalid_argument("invalid character");
  45. }
  46. if (!(utf8[offset + 0] & 0x20)) {
  47. if (offset + 1 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80)) {
  48. throw std::invalid_argument("invalid character");
  49. }
  50. auto result = ((utf8[offset + 0] & 0x1f) << 6) | (utf8[offset + 1] & 0x3f);
  51. offset += 2;
  52. return result;
  53. }
  54. if (!(utf8[offset + 0] & 0x10)) {
  55. if (offset + 2 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80)) {
  56. throw std::invalid_argument("invalid character");
  57. }
  58. auto result = ((utf8[offset + 0] & 0x0f) << 12) | ((utf8[offset + 1] & 0x3f) << 6) | (utf8[offset + 2] & 0x3f);
  59. offset += 3;
  60. return result;
  61. }
  62. if (!(utf8[offset + 0] & 0x08)) {
  63. if (offset + 3 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80)) {
  64. throw std::invalid_argument("invalid character");
  65. }
  66. auto result = ((utf8[offset + 0] & 0x07) << 18) | ((utf8[offset + 1] & 0x3f) << 12) | ((utf8[offset + 2] & 0x3f) << 6) | (utf8[offset + 3] & 0x3f);
  67. offset += 4;
  68. return result;
  69. }
  70. throw std::invalid_argument("failed to convert utf8 to codepoint");
  71. }
  72. //static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cpt) {
  73. // std::vector<uint16_t> result;
  74. // if (/* 0x0000 <= cpt && */ cpt <= 0xffff) {
  75. // result.emplace_back(cpt);
  76. // return result;
  77. // }
  78. // if (0x10000 <= cpt && cpt <= 0x10ffff) {
  79. // result.emplace_back(0xd800 | ((cpt - 0x10000) >> 10));
  80. // result.emplace_back(0xdc00 | ((cpt - 0x10000) & 0x03ff));
  81. // return result;
  82. // }
  83. // throw std::invalid_argument("failed to convert codepoint to utf16");
  84. //}
  85. //static std::vector<uint16_t> unicode_cpts_to_utf16(const std::vector<uint32_t> & cps) {
  86. // std::vector<uint16_t> result;
  87. // for (size_t i = 0; i < cps.size(); ++i) {
  88. // auto temp = unicode_cpt_to_utf16(cps[i]);
  89. // result.insert(result.end(), temp.begin(), temp.end());
  90. // }
  91. // return result;
  92. //}
  93. //static uint32_t unicode_cpt_from_utf16(const std::vector<uint16_t> & utf16, size_t & offset) {
  94. // assert(offset < utf16.size());
  95. // if (((utf16[0] >> 10) << 10) != 0xd800) {
  96. // auto result = utf16[offset + 0];
  97. // offset += 1;
  98. // return result;
  99. // }
  100. //
  101. // if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) {
  102. // throw std::invalid_argument("invalid character");
  103. // }
  104. //
  105. // auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff));
  106. // offset += 2;
  107. // return result;
  108. //}
  109. //static std::vector<uint32_t> unicode_cpts_from_utf16(const std::vector<uint16_t> & utf16) {
  110. // std::vector<uint32_t> result;
  111. // size_t offset = 0;
  112. // while (offset < utf16.size()) {
  113. // result.push_back(unicode_cpt_from_utf16(utf16, offset));
  114. // }
  115. // return result;
  116. //}
  117. static std::vector<unicode_cpt_flags> unicode_cpt_flags_array() {
  118. std::vector<unicode_cpt_flags> cpt_flags(MAX_CODEPOINTS, unicode_cpt_flags::UNDEFINED);
  119. assert (unicode_ranges_flags.begin()[0].first == 0);
  120. assert (unicode_ranges_flags.begin()[unicode_ranges_flags.size()-1].first == MAX_CODEPOINTS);
  121. for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
  122. const auto range_ini = unicode_ranges_flags.begin()[i-1]; // codepoint_ini, flags
  123. const auto range_end = unicode_ranges_flags.begin()[i]; // codepoint_end, flags
  124. for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
  125. cpt_flags[cpt] = range_ini.second;
  126. }
  127. }
  128. for (auto cpt : unicode_set_whitespace) {
  129. cpt_flags[cpt].is_whitespace = true;
  130. }
  131. for (auto p : unicode_map_lowercase) {
  132. cpt_flags[p.second].is_lowercase = true;
  133. }
  134. for (auto p : unicode_map_uppercase) {
  135. cpt_flags[p.second].is_uppercase = true;
  136. }
  137. for (auto &range : unicode_ranges_nfd) { // start, last, nfd
  138. cpt_flags[range.nfd].is_nfd = true;
  139. }
  140. return cpt_flags;
  141. }
  142. static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
  143. std::unordered_map<uint8_t, std::string> map;
  144. for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
  145. assert(0 <= ch && ch < 256);
  146. map[ch] = unicode_cpt_to_utf8(ch);
  147. }
  148. for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬'
  149. assert(0 <= ch && ch < 256);
  150. map[ch] = unicode_cpt_to_utf8(ch);
  151. }
  152. for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ'
  153. assert(0 <= ch && ch < 256);
  154. map[ch] = unicode_cpt_to_utf8(ch);
  155. }
  156. auto n = 0;
  157. for (int ch = 0; ch < 256; ++ch) {
  158. if (map.find(ch) == map.end()) {
  159. map[ch] = unicode_cpt_to_utf8(256 + n);
  160. ++n;
  161. }
  162. }
  163. return map;
  164. }
  165. static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
  166. std::unordered_map<std::string, uint8_t> map;
  167. for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
  168. assert(0 <= ch && ch < 256);
  169. map[unicode_cpt_to_utf8(ch)] = ch;
  170. }
  171. for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬'
  172. assert(0 <= ch && ch < 256);
  173. map[unicode_cpt_to_utf8(ch)] = ch;
  174. }
  175. for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ'
  176. assert(0 <= ch && ch < 256);
  177. map[unicode_cpt_to_utf8(ch)] = ch;
  178. }
  179. auto n = 0;
  180. for (int ch = 0; ch < 256; ++ch) {
  181. if (map.find(unicode_cpt_to_utf8(ch)) == map.end()) {
  182. map[unicode_cpt_to_utf8(256 + n)] = ch;
  183. ++n;
  184. }
  185. }
  186. return map;
  187. }
  188. static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
  189. #ifdef _WIN32
  190. int wlen = MultiByteToWideChar(CP_UTF8, 0, s.c_str(), -1, NULL, 0);
  191. if (!wlen) {
  192. throw std::invalid_argument("failed to convert regex");
  193. }
  194. wchar_t * wbuf = (wchar_t *) malloc(wlen * sizeof(wchar_t));
  195. wlen = MultiByteToWideChar(CP_UTF8, 0, s.c_str(), -1, wbuf, wlen);
  196. if (!wlen) {
  197. free(wbuf);
  198. throw std::invalid_argument("failed to convert regex");
  199. }
  200. std::wstring ret = std::wstring(wbuf);
  201. free(wbuf);
  202. return ret;
  203. #else
  204. #if defined(__clang__)
  205. // disable C++17 deprecation warning for std::codecvt_utf8
  206. # pragma clang diagnostic push
  207. # pragma clang diagnostic ignored "-Wdeprecated-declarations"
  208. #endif
  209. std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
  210. #if defined(__clang__)
  211. # pragma clang diagnostic pop
  212. #endif
  213. return conv.from_bytes(s);
  214. #endif
  215. }
  216. static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string> & bpe_words) {
  217. std::vector<std::string> bpe_encoded_words;
  218. for (const auto & word : bpe_words) {
  219. std::string text_utf;
  220. auto utf_word = unicode_cpts_from_utf8(word);
  221. for (size_t i = 0; i < utf_word.size(); ++i) {
  222. text_utf += unicode_cpt_to_utf8(utf_word[i]);
  223. }
  224. std::string encoded_token;
  225. for (char & c : text_utf) {
  226. encoded_token += unicode_byte_to_utf8(c);
  227. }
  228. bpe_encoded_words.emplace_back(encoded_token);
  229. }
  230. return bpe_encoded_words;
  231. }
  232. // GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
  233. static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & text, const std::vector<size_t> & offsets) {
  234. std::vector<size_t> bpe_offsets; // store the offset of each word
  235. bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
  236. const auto cpts = unicode_cpts_from_utf8(text);
  237. size_t start = 0;
  238. for (auto offset : offsets) {
  239. const size_t offset_ini = start;
  240. const size_t offset_end = start + offset;
  241. assert(offset_end <= cpts.size());
  242. start = offset_end;
  243. static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
  244. auto _get_cpt = [&] (const size_t pos) -> uint32_t {
  245. return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
  246. };
  247. auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
  248. return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
  249. };
  250. size_t _prev_end = offset_ini;
  251. auto _add_token = [&] (const size_t end) -> size_t {
  252. assert(_prev_end <= end && end <= offset_end);
  253. size_t len = end - _prev_end;
  254. if (len > 0) {
  255. bpe_offsets.push_back(len);
  256. }
  257. _prev_end = end;
  258. //if (len > 0) {
  259. // std::string s = "";
  260. // for(size_t p = end-len; p < end; p++)
  261. // s += unicode_cpt_to_utf8(cpts[p]);
  262. // printf(">>> '%s'\n", s.c_str());
  263. //}
  264. return len;
  265. };
  266. for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
  267. const uint32_t cpt = _get_cpt(pos);
  268. const auto flags = _get_flags(pos);
  269. // regex: 's|'t|'re|'ve|'m|'ll|'d
  270. if (cpt == '\'' && pos+1 < offset_end) {
  271. uint32_t cpt_next = _get_cpt(pos+1);
  272. if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
  273. pos += _add_token(pos+2);
  274. continue;
  275. }
  276. if (pos+2 < offset_end) {
  277. uint32_t cpt_next_next = _get_cpt(pos+2);
  278. if ((cpt_next == 'r' && cpt_next_next == 'e') ||
  279. (cpt_next == 'v' && cpt_next_next == 'e') ||
  280. (cpt_next == 'l' && cpt_next_next == 'l')) {
  281. pos += _add_token(pos+3);
  282. continue;
  283. }
  284. }
  285. }
  286. auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
  287. // regex: <space>?\p{L}+
  288. if (flags2.is_letter) {
  289. pos += (cpt == ' ');
  290. while (flags2.is_letter) {
  291. flags2 = _get_flags(++pos);
  292. }
  293. _add_token(pos);
  294. continue;
  295. }
  296. // regex: <space>?\p{N}+
  297. if (flags2.is_number) {
  298. pos += (cpt == ' ');
  299. while (flags2.is_number) {
  300. flags2 = _get_flags(++pos);
  301. }
  302. _add_token(pos);
  303. continue;
  304. }
  305. // regex: <space>?[^\s\p{L}\p{N}]+
  306. if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
  307. pos += (cpt == ' ');
  308. while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
  309. flags2 = _get_flags(++pos);
  310. }
  311. _add_token(pos);
  312. continue;
  313. }
  314. size_t num_whitespaces = 0;
  315. while (_get_flags(pos+num_whitespaces).is_whitespace) {
  316. num_whitespaces++;
  317. }
  318. // regex: \s+(?!\S)
  319. if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
  320. pos += num_whitespaces - 1;
  321. _add_token(pos);
  322. continue;
  323. }
  324. // regex: \s+
  325. if (num_whitespaces > 0) {
  326. pos += num_whitespaces;
  327. _add_token(pos);
  328. continue;
  329. }
  330. // no matches
  331. _add_token(++pos);
  332. }
  333. }
  334. return bpe_offsets;
  335. }
  336. // LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
  337. static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string & text, const std::vector<size_t> & offsets) {
  338. std::vector<size_t> bpe_offsets; // store the offset of each word
  339. bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
  340. const auto cpts = unicode_cpts_from_utf8(text);
  341. size_t start = 0;
  342. for (auto offset : offsets) {
  343. const size_t offset_ini = start;
  344. const size_t offset_end = start + offset;
  345. assert(offset_end <= cpts.size());
  346. start = offset_end;
  347. static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
  348. auto _get_cpt = [&] (const size_t pos) -> uint32_t {
  349. return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
  350. };
  351. auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
  352. return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
  353. };
  354. size_t _prev_end = offset_ini;
  355. auto _add_token = [&] (const size_t end) -> size_t {
  356. assert(_prev_end <= end && end <= offset_end);
  357. size_t len = end - _prev_end;
  358. if (len > 0) {
  359. bpe_offsets.push_back(len);
  360. }
  361. _prev_end = end;
  362. //if (len > 0) {
  363. // std::string s = "";
  364. // for(size_t p = end-len; p < end; p++)
  365. // s += unicode_cpt_to_utf8(cpts[p]);
  366. // printf(">>> '%s'\n", s.c_str());
  367. //}
  368. return len;
  369. };
  370. for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
  371. const uint32_t cpt = _get_cpt(pos);
  372. const auto flags = _get_flags(pos);
  373. // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
  374. if (cpt == '\'' && pos+1 < offset_end) {
  375. uint32_t cpt_next = unicode_tolower(_get_cpt(pos+1));
  376. if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
  377. pos += _add_token(pos+2);
  378. continue;
  379. }
  380. if (pos+2 < offset_end) {
  381. uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2));
  382. if ((cpt_next == 'r' && cpt_next_next == 'e') ||
  383. (cpt_next == 'v' && cpt_next_next == 'e') ||
  384. (cpt_next == 'l' && cpt_next_next == 'l')) {
  385. pos += _add_token(pos+3);
  386. continue;
  387. }
  388. }
  389. }
  390. // regex: [^\r\n\p{L}\p{N}]?\p{L}+
  391. if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) {
  392. if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters
  393. pos++;
  394. while (_get_flags(pos).is_letter) {
  395. pos++;
  396. }
  397. _add_token(pos);
  398. continue;
  399. }
  400. }
  401. // regex: \p{N}{1,3}
  402. if (flags.is_number) {
  403. size_t ini = pos;
  404. while (_get_flags(pos).is_number) {
  405. if (++pos - ini >= 3 ) {
  406. _add_token(pos);
  407. ini = pos;
  408. }
  409. }
  410. _add_token(pos);
  411. continue;
  412. }
  413. // regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
  414. auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
  415. if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) {
  416. pos += (cpt == ' ');
  417. while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
  418. flags2 = _get_flags(++pos);
  419. }
  420. uint32_t cpt2 = _get_cpt(pos);
  421. while (cpt2 == '\r' || cpt2 == '\n') {
  422. cpt2 = _get_cpt(++pos);
  423. }
  424. _add_token(pos);
  425. continue;
  426. }
  427. size_t num_whitespaces = 0;
  428. size_t last_end_r_or_n = 0;
  429. while (_get_flags(pos+num_whitespaces).is_whitespace) {
  430. uint32_t cpt2 = _get_cpt(pos+num_whitespaces);
  431. if (cpt2 == '\r' || cpt2 == '\n') {
  432. last_end_r_or_n = pos + num_whitespaces + 1;
  433. }
  434. num_whitespaces++;
  435. }
  436. // regex: \s*[\r\n]+
  437. if (last_end_r_or_n > 0) {
  438. pos = last_end_r_or_n;
  439. _add_token(pos);
  440. continue;
  441. }
  442. // regex: \s+(?!\S)
  443. if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
  444. pos += num_whitespaces - 1;
  445. _add_token(pos);
  446. continue;
  447. }
  448. // regex: \s+
  449. if (num_whitespaces > 0) {
  450. pos += num_whitespaces;
  451. _add_token(pos);
  452. continue;
  453. }
  454. // no matches
  455. _add_token(++pos);
  456. }
  457. }
  458. return bpe_offsets;
  459. }
  460. // use std::wregex to split the text
  461. static std::vector<size_t> unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector<size_t> & offsets) {
  462. std::wregex expr(regex_expr);
  463. std::vector<size_t> bpe_offsets; // store the offset of each word
  464. bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
  465. size_t start = 0;
  466. for (auto offset : offsets) {
  467. std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
  468. std::wcregex_iterator end;
  469. int64_t start_idx = 0;
  470. while (it != end) {
  471. std::wcmatch match = *it;
  472. if (match.position() > start_idx) {
  473. bpe_offsets.emplace_back(match.position() - start_idx);
  474. }
  475. bpe_offsets.emplace_back(match.length());
  476. start_idx = match.position() + match.length();
  477. ++it;
  478. }
  479. if (start_idx < (int64_t) offset) {
  480. bpe_offsets.emplace_back(offset - start_idx);
  481. }
  482. start += offset;
  483. }
  484. return bpe_offsets;
  485. }
  486. // use std::regex to split the text
  487. static std::vector<size_t> unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
  488. std::regex expr(regex_expr);
  489. std::vector<size_t> bpe_offsets; // store the offset of each word
  490. bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
  491. size_t start = 0;
  492. for (auto offset : offsets) {
  493. std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
  494. std::cregex_iterator end;
  495. int64_t start_idx = 0;
  496. while (it != end) {
  497. std::cmatch match = *it;
  498. if (match.position() > start_idx) {
  499. bpe_offsets.emplace_back(match.position() - start_idx);
  500. }
  501. bpe_offsets.emplace_back(match.length());
  502. start_idx = match.position() + match.length();
  503. ++it;
  504. }
  505. if (start_idx < (int64_t) offset) {
  506. bpe_offsets.emplace_back(offset - start_idx);
  507. }
  508. start += offset;
  509. }
  510. return bpe_offsets;
  511. }
  512. static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
  513. std::vector<size_t> bpe_offsets;
  514. if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
  515. bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets);
  516. } else if (
  517. regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" ||
  518. regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
  519. bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
  520. }
  521. return bpe_offsets;
  522. }
  523. //
  524. // interface
  525. //
  526. std::string unicode_cpt_to_utf8(uint32_t cpt) {
  527. std::string result;
  528. if (/* 0x00 <= cpt && */ cpt <= 0x7f) {
  529. result.push_back(cpt);
  530. return result;
  531. }
  532. if (0x80 <= cpt && cpt <= 0x7ff) {
  533. result.push_back(0xc0 | ((cpt >> 6) & 0x1f));
  534. result.push_back(0x80 | (cpt & 0x3f));
  535. return result;
  536. }
  537. if (0x800 <= cpt && cpt <= 0xffff) {
  538. result.push_back(0xe0 | ((cpt >> 12) & 0x0f));
  539. result.push_back(0x80 | ((cpt >> 6) & 0x3f));
  540. result.push_back(0x80 | (cpt & 0x3f));
  541. return result;
  542. }
  543. if (0x10000 <= cpt && cpt <= 0x10ffff) {
  544. result.push_back(0xf0 | ((cpt >> 18) & 0x07));
  545. result.push_back(0x80 | ((cpt >> 12) & 0x3f));
  546. result.push_back(0x80 | ((cpt >> 6) & 0x3f));
  547. result.push_back(0x80 | (cpt & 0x3f));
  548. return result;
  549. }
  550. throw std::invalid_argument("invalid codepoint");
  551. }
  552. std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
  553. auto comp = [] (const uint32_t cpt, const range_nfd & range) {
  554. return cpt < range.first;
  555. };
  556. std::vector<uint32_t> result(cpts.size());
  557. for (size_t i = 0; i < cpts.size(); ++i) {
  558. const uint32_t cpt = cpts[i];
  559. auto it = std::upper_bound(unicode_ranges_nfd.begin(), unicode_ranges_nfd.end(), cpt, comp) - 1;
  560. result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt;
  561. }
  562. return result;
  563. }
  564. std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
  565. std::vector<uint32_t> result;
  566. result.reserve(utf8.size());
  567. size_t offset = 0;
  568. while (offset < utf8.size()) {
  569. result.push_back(unicode_cpt_from_utf8(utf8, offset));
  570. }
  571. return result;
  572. }
  573. unicode_cpt_flags unicode_cpt_flags_from_cpt(const uint32_t cpt) {
  574. static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
  575. static const auto cpt_flags = unicode_cpt_flags_array();
  576. return cpt < cpt_flags.size() ? cpt_flags[cpt] : undef;
  577. }
  578. unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8) {
  579. static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
  580. if (utf8.empty()) {
  581. return undef; // undefined
  582. }
  583. size_t offset = 0;
  584. return unicode_cpt_flags_from_cpt(unicode_cpt_from_utf8(utf8, offset));
  585. }
  586. std::string unicode_byte_to_utf8(uint8_t byte) {
  587. static std::unordered_map<uint8_t, std::string> map = unicode_byte_to_utf8_map();
  588. return map.at(byte);
  589. }
  590. uint8_t unicode_utf8_to_byte(const std::string & utf8) {
  591. static std::unordered_map<std::string, uint8_t> map = unicode_utf8_to_byte_map();
  592. return map.at(utf8);
  593. }
  594. uint32_t unicode_tolower(uint32_t cpt) {
  595. // binary search
  596. auto it = std::lower_bound(unicode_map_lowercase.begin(), unicode_map_lowercase.end(), cpt,
  597. [](const std::pair<uint32_t, uint32_t> & pair, uint32_t value) {
  598. return pair.first < value;
  599. });
  600. if (it != unicode_map_lowercase.end() && it->first == cpt) {
  601. return it->second;
  602. }
  603. return cpt; // Return the original code point if no lowercase mapping is found
  604. }
  605. std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
  606. // unicode categories
  607. static const std::map<std::string, int> k_ucat_enum = {
  608. { "\\p{N}", unicode_cpt_flags::NUMBER },
  609. { "\\p{L}", unicode_cpt_flags::LETTER },
  610. { "\\p{P}", unicode_cpt_flags::PUNCTUATION },
  611. { "\\p{M}", unicode_cpt_flags::ACCENT_MARK },
  612. { "\\p{S}", unicode_cpt_flags::SYMBOL },
  613. };
  614. static const std::map<int, int> k_ucat_cpt = {
  615. { unicode_cpt_flags::NUMBER, 0xD1 },
  616. { unicode_cpt_flags::LETTER, 0xD2 },
  617. { unicode_cpt_flags::PUNCTUATION, 0xD3 },
  618. { unicode_cpt_flags::ACCENT_MARK, 0xD4 },
  619. { unicode_cpt_flags::SYMBOL, 0xD5 },
  620. };
  621. static const std::map<int, std::string> k_ucat_map = {
  622. { unicode_cpt_flags::NUMBER, "\x30-\x39" }, // 0-9
  623. { unicode_cpt_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
  624. { unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
  625. { unicode_cpt_flags::ACCENT_MARK, "" }, // no sub-128 codepoints
  626. { unicode_cpt_flags::SYMBOL, "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" }, // $+<=>^`|
  627. };
  628. // compute collapsed codepoints only if needed by at least one regex
  629. bool need_collapse = false;
  630. for (const auto & regex_expr : regex_exprs) {
  631. // search for unicode categories
  632. for (const auto & ucat : k_ucat_enum) {
  633. if (std::string::npos != regex_expr.find(ucat.first)) {
  634. need_collapse = true;
  635. break;
  636. }
  637. }
  638. }
  639. const auto cpts = unicode_cpts_from_utf8(text);
  640. // generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
  641. // ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
  642. std::string text_collapsed;
  643. if (need_collapse) {
  644. // collapse all unicode categories
  645. text_collapsed.resize(cpts.size());
  646. for (size_t i = 0; i < cpts.size(); ++i) {
  647. // keep single-byte codepoints as is
  648. if (cpts[i] < 128) {
  649. text_collapsed[i] = cpts[i];
  650. continue;
  651. }
  652. const auto flags = unicode_cpt_flags_from_cpt(cpts[i]);
  653. if (flags.is_whitespace) {
  654. //NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does.
  655. //text_collapsed[i] = (char) 0x85; // <Next Line> as whitespace fallback
  656. text_collapsed[i] = (char) 0x0B; // <vertical tab> as whitespace fallback
  657. } else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) {
  658. text_collapsed[i] = k_ucat_cpt.at(flags.category_flag());
  659. } else {
  660. text_collapsed[i] = (char) 0xD0; // fallback
  661. }
  662. }
  663. }
  664. std::vector<size_t> bpe_offsets = { cpts.size() };
  665. for (const auto & regex_expr : regex_exprs) {
  666. // first, see if we have an efficient custom regex implementation
  667. auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets);
  668. if (!tmp.empty()) {
  669. bpe_offsets = std::move(tmp);
  670. continue;
  671. }
  672. // fallback to general-purpose std::regex / std::wregex
  673. try {
  674. // if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
  675. // with the corresponding collapsed representation
  676. bool use_collapsed = false;
  677. for (const auto & ucat : k_ucat_enum) {
  678. if (std::string::npos != regex_expr.find(ucat.first)) {
  679. use_collapsed = true;
  680. break;
  681. }
  682. }
  683. if (use_collapsed) {
  684. // sanity-check that the original regex does not contain any non-ASCII characters
  685. const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
  686. for (size_t i = 0; i < cpts_regex.size(); ++i) {
  687. if (cpts_regex[i] >= 128) {
  688. throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported");
  689. }
  690. }
  691. // generate a collapsed representation of the regex
  692. std::string regex_expr_collapsed;
  693. // track if we are inside [], because nested [] are not allowed
  694. bool inside = false;
  695. for (size_t i = 0; i < regex_expr.size(); ++i) {
  696. if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) {
  697. regex_expr_collapsed += '[';
  698. inside = true;
  699. continue;
  700. }
  701. if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') {
  702. regex_expr_collapsed += ']';
  703. inside = false;
  704. continue;
  705. }
  706. if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() &&
  707. regex_expr[i + 1] == 'p' &&
  708. regex_expr[i + 2] == '{' &&
  709. regex_expr[i + 4] == '}') {
  710. const std::string pat = regex_expr.substr(i, 5);
  711. if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
  712. if (!inside) {
  713. regex_expr_collapsed += '[';
  714. }
  715. regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
  716. regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
  717. if (!inside) {
  718. regex_expr_collapsed += ']';
  719. }
  720. i += 4;
  721. continue;
  722. }
  723. }
  724. regex_expr_collapsed += regex_expr[i];
  725. }
  726. //printf("text_collapsed: %s\n", text_collapsed.c_str());
  727. //printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str());
  728. bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
  729. } else {
  730. // no unicode category used, we can use std::wregex directly
  731. const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
  732. // std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
  733. std::wstring wtext(cpts.begin(), cpts.end());
  734. for (size_t i = 0; i < wtext.size(); ++i) {
  735. if (wtext[i] > 0x7F && unicode_cpt_flags_from_cpt(wtext[i]).is_whitespace) {
  736. wtext[i] = 0x0B;
  737. }
  738. }
  739. //printf("text: %s\n", text.c_str());
  740. //printf("regex_expr: %s\n", regex_expr.c_str());
  741. bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);
  742. }
  743. } catch (std::regex_error & e) {
  744. fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str());
  745. fprintf(stderr, "Regex error: %s\n", e.what());
  746. throw std::runtime_error("Failed to process regex");
  747. }
  748. }
  749. std::vector<std::string> bpe_words;
  750. bpe_words.reserve(bpe_offsets.size()); // reserve memory for the approximate size
  751. size_t start = 0;
  752. for (size_t & offset : bpe_offsets) {
  753. bpe_words.emplace_back();
  754. for (size_t i = start; i < start + offset; ++i) {
  755. bpe_words.back() += unicode_cpt_to_utf8(cpts[i]);
  756. }
  757. start += offset;
  758. }
  759. return unicode_byte_encoding_process(bpe_words);
  760. }