llama-grammar.cpp 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165
  1. /**
  2. * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
  3. *
  4. * MIT License
  5. *
  6. * Copyright (c) 2023-2024 The ggml authors
  7. *
  8. * Permission is hereby granted, free of charge, to any person obtaining a copy
  9. * of this software and associated documentation files (the "Software"), to deal
  10. * in the Software without restriction, including without limitation the rights
  11. * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  12. * copies of the Software, and to permit persons to whom the Software is
  13. * furnished to do so, subject to the following conditions:
  14. *
  15. * The above copyright notice and this permission notice shall be included in all
  16. * copies or substantial portions of the Software.
  17. *
  18. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  19. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  20. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  21. * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  22. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  23. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  24. * SOFTWARE.
  25. */
  26. #include "llama-grammar.h"
  27. #include "llama-impl.h"
  28. #include "llama-vocab.h"
  29. #include "llama-sampling.h"
  30. #include <cmath>
  31. #include <algorithm>
  32. #include <stdexcept>
  33. //
  34. // helpers
  35. //
  // NOTE: assumes valid utf8 (but checks for overrun)
  // Decodes the single UTF-8 code point that starts at `src`.
  // The sequence length comes from the lead byte alone; continuation bytes are
  // not validated. Decoding stops early at a NUL byte, so a truncated sequence
  // never reads past the string terminator.
  // Returns the decoded value and a pointer just past the bytes consumed.
  static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
      // byte count of the sequence, indexed by the high nibble of the lead byte
      static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
      const uint8_t lead    = static_cast<uint8_t>(*src);
      const int     n_bytes = lookup[lead >> 4];
      uint32_t code_point = lead & ((1 << (8 - n_bytes)) - 1);
      const char * const limit = src + n_bytes; // may overrun!
      const char * cur = src + 1;
      while (cur < limit && *cur != 0) {
          code_point = (code_point << 6) + (static_cast<uint8_t>(*cur) & 0x3F);
          ++cur;
      }
      return { code_point, cur };
  }
  // Decodes a UTF-8 string into code points. It can resume a multi-byte
  // sequence left unfinished by a previous call and can itself end inside one.
  //
  // partial_start: carry-over state from a previous call.
  //   `value` holds the payload bits decoded so far.
  //   `n_remain` is the number of continuation bytes still expected.
  //   n_remain <= 0 means nothing is pending.
  // Returns: the completed code points followed by a terminating 0, plus the
  //   new partial state.
  //   - When the input ends mid-sequence, n_remain > 0 and value holds its bits.
  //   - On an invalid sequence the code point list is just {0} and n_remain is -1.
  // Decoding stops at the first NUL byte of `src`.
  // NOTE(review): only the bytes that resume `partial_start` are checked to be
  // continuation bytes (10xxxxxx); continuation bytes of sequences started in
  // this call are consumed without validation.
  static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
          const std::string & src,
          llama_partial_utf8 partial_start) {
      // sequence length indexed by the high nibble of the lead byte; 0 marks a
      // continuation byte (10xxxxxx), which is invalid in lead position
      static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
      const char * pos = src.c_str();
      std::vector<uint32_t> code_points;
      // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0.
      code_points.reserve(src.size() + 1);
      uint32_t value = partial_start.value;
      int n_remain = partial_start.n_remain;
      // continue previous decode, if applicable
      while (*pos != 0 && n_remain > 0) {
          uint8_t next_byte = static_cast<uint8_t>(*pos);
          if ((next_byte >> 6) != 2) {
              // invalid sequence, abort
              code_points.push_back(0);
              return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 });
          }
          value = (value << 6) + (next_byte & 0x3F);
          ++pos;
          --n_remain;
      }
      // the carried-over sequence was completed by this input
      if (partial_start.n_remain > 0 && n_remain == 0) {
          code_points.push_back(value);
      }
      // decode any subsequent utf-8 sequences, which may end in an incomplete one
      while (*pos != 0) {
          uint8_t first_byte = static_cast<uint8_t>(*pos);
          uint8_t highbits = first_byte >> 4;
          n_remain = lookup[highbits] - 1;
          if (n_remain < 0) {
              // invalid sequence, abort; already-decoded code points are discarded
              code_points.clear();
              code_points.push_back(0);
              return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain });
          }
          // keep only the payload bits of the lead byte (0x7F for ASCII, 0x3F/0x1F/0x0F
          // for 2/3/4-byte leads; the marker bit below the payload is always 0)
          uint8_t mask = (1 << (7 - n_remain)) - 1;
          value = first_byte & mask;
          ++pos;
          while (*pos != 0 && n_remain > 0) {
              value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
              ++pos;
              --n_remain;
          }
          // only complete sequences are emitted; an unfinished tail is returned as partial state
          if (n_remain == 0) {
              code_points.push_back(value);
          }
      }
      code_points.push_back(0);
      return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
  }
  // Returns true iff `c` is an ASCII decimal digit.
  static bool is_digit_char(char c) {
      return !(c < '0' || c > '9');
  }
  // Returns true iff `c` may appear in a grammar rule name: ASCII letters,
  // digits and '-'. Note that '_' is deliberately not a word character.
  static bool is_word_char(char c) {
      const bool lower = c >= 'a' && c <= 'z';
      const bool upper = c >= 'A' && c <= 'Z';
      const bool digit = c >= '0' && c <= '9';
      return lower || upper || digit || c == '-';
  }
  // Parses exactly `size` hexadecimal digits (either case) starting at `src`.
  // Returns the parsed value and a pointer just past the digits.
  // Throws std::runtime_error when a non-hex character or the end of the
  // string is reached before `size` digits have been read.
  static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
      const char * const end = src + size;
      const char * cur = src;
      uint32_t value = 0;
      while (cur < end && *cur != 0) {
          const char c = *cur;
          uint32_t digit;
          if (c >= '0' && c <= '9') {
              digit = c - '0';
          } else if (c >= 'a' && c <= 'f') {
              digit = c - 'a' + 10;
          } else if (c >= 'A' && c <= 'F') {
              digit = c - 'A' + 10;
          } else {
              break;
          }
          value = (value << 4) + digit;
          ++cur;
      }
      if (cur != end) {
          throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
      }
      return std::make_pair(value, cur);
  }
  // Skips blanks (' ', '\t') and '#' line comments starting at `src`.
  // CR/LF are skipped only when `newline_ok` is true. A comment runs up to, but
  // not including, the next CR/LF (or the end of the string).
  // Returns a pointer to the first character that was not skipped.
  static const char * parse_space(const char * src, bool newline_ok) {
      const char * cur = src;
      for (;;) {
          const char c = *cur;
          if (c == '#') {
              // consume the comment body; its terminating newline is handled by
              // the next iteration according to `newline_ok`
              while (*cur != 0 && *cur != '\r' && *cur != '\n') {
                  ++cur;
              }
          } else if (c == ' ' || c == '\t' || (newline_ok && (c == '\r' || c == '\n'))) {
              ++cur;
          } else {
              return cur;
          }
      }
  }
  144. static const char * parse_name(const char * src) {
  145. const char * pos = src;
  146. while (is_word_char(*pos)) {
  147. pos++;
  148. }
  149. if (pos == src) {
  150. throw std::runtime_error(std::string("expecting name at ") + src);
  151. }
  152. return pos;
  153. }
  154. static const char * parse_int(const char * src) {
  155. const char * pos = src;
  156. while (is_digit_char(*pos)) {
  157. pos++;
  158. }
  159. if (pos == src) {
  160. throw std::runtime_error(std::string("expecting integer at ") + src);
  161. }
  162. return pos;
  163. }
  164. static std::pair<uint32_t, const char *> parse_char(const char * src) {
  165. if (*src == '\\') {
  166. switch (src[1]) {
  167. case 'x': return parse_hex(src + 2, 2);
  168. case 'u': return parse_hex(src + 2, 4);
  169. case 'U': return parse_hex(src + 2, 8);
  170. case 't': return std::make_pair('\t', src + 2);
  171. case 'r': return std::make_pair('\r', src + 2);
  172. case 'n': return std::make_pair('\n', src + 2);
  173. case '\\':
  174. case '"':
  175. case '[':
  176. case ']':
  177. return std::make_pair(src[1], src + 2);
  178. default:
  179. throw std::runtime_error(std::string("unknown escape at ") + src);
  180. }
  181. } else if (*src) {
  182. return decode_utf8(src);
  183. }
  184. throw std::runtime_error("unexpected end of input");
  185. }
  // Writes grammar code point `c` to `file`.
  // Printable ASCII (0x20..0x7E) is emitted verbatim. Everything else is
  // written as a "<U+XXXX>" placeholder (UTF-8 encoding is intentionally not
  // attempted).
  static void print_grammar_char(FILE * file, uint32_t c) {
      if (0x20 <= c && c < 0x7f) {
          fprintf(file, "%c", static_cast<char>(c));
      } else {
          // cop out of encoding UTF-8. 0x7F (DEL) is a control character that
          // would otherwise be written invisibly, so it is routed here too.
          // %X expects unsigned int, hence the explicit cast.
          fprintf(file, "<U+%04X>", static_cast<unsigned int>(c));
      }
  }
  194. static bool is_char_element(llama_grammar_element elem) {
  195. switch (elem.type) {
  196. case LLAMA_GRETYPE_CHAR: return true;
  197. case LLAMA_GRETYPE_CHAR_NOT: return true;
  198. case LLAMA_GRETYPE_CHAR_ALT: return true;
  199. case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
  200. case LLAMA_GRETYPE_CHAR_ANY: return true;
  201. default: return false;
  202. }
  203. }
  204. static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
  205. for (auto elem : rule) {
  206. switch (elem.type) {
  207. case LLAMA_GRETYPE_END: fprintf(file, "END"); break;
  208. case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break;
  209. case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
  210. case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
  211. case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
  212. case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
  213. case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
  214. case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break;
  215. }
  216. switch (elem.type) {
  217. case LLAMA_GRETYPE_END:
  218. case LLAMA_GRETYPE_ALT:
  219. case LLAMA_GRETYPE_RULE_REF:
  220. fprintf(file, "(%u) ", elem.value);
  221. break;
  222. case LLAMA_GRETYPE_CHAR:
  223. case LLAMA_GRETYPE_CHAR_NOT:
  224. case LLAMA_GRETYPE_CHAR_RNG_UPPER:
  225. case LLAMA_GRETYPE_CHAR_ALT:
  226. case LLAMA_GRETYPE_CHAR_ANY:
  227. fprintf(file, "(\"");
  228. print_grammar_char(file, elem.value);
  229. fprintf(file, "\") ");
  230. break;
  231. }
  232. }
  233. fprintf(file, "\n");
  234. }
  235. static void print_rule(
  236. FILE * file,
  237. uint32_t rule_id,
  238. const llama_grammar_rule & rule,
  239. const std::map<uint32_t, std::string> & symbol_id_names) {
  240. if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
  241. throw std::runtime_error(
  242. "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
  243. }
  244. fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
  245. for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
  246. llama_grammar_element elem = rule[i];
  247. switch (elem.type) {
  248. case LLAMA_GRETYPE_END:
  249. throw std::runtime_error(
  250. "unexpected end of rule: " + std::to_string(rule_id) + "," +
  251. std::to_string(i));
  252. case LLAMA_GRETYPE_ALT:
  253. fprintf(file, "| ");
  254. break;
  255. case LLAMA_GRETYPE_RULE_REF:
  256. fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
  257. break;
  258. case LLAMA_GRETYPE_CHAR:
  259. fprintf(file, "[");
  260. print_grammar_char(file, elem.value);
  261. break;
  262. case LLAMA_GRETYPE_CHAR_NOT:
  263. fprintf(file, "[^");
  264. print_grammar_char(file, elem.value);
  265. break;
  266. case LLAMA_GRETYPE_CHAR_RNG_UPPER:
  267. if (i == 0 || !is_char_element(rule[i - 1])) {
  268. throw std::runtime_error(
  269. "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
  270. std::to_string(rule_id) + "," + std::to_string(i));
  271. }
  272. fprintf(file, "-");
  273. print_grammar_char(file, elem.value);
  274. break;
  275. case LLAMA_GRETYPE_CHAR_ALT:
  276. if (i == 0 || !is_char_element(rule[i - 1])) {
  277. throw std::runtime_error(
  278. "LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
  279. std::to_string(rule_id) + "," + std::to_string(i));
  280. }
  281. print_grammar_char(file, elem.value);
  282. break;
  283. case LLAMA_GRETYPE_CHAR_ANY:
  284. fprintf(file, ".");
  285. break;
  286. }
  287. if (is_char_element(elem)) {
  288. switch (rule[i + 1].type) {
  289. case LLAMA_GRETYPE_CHAR_ALT:
  290. case LLAMA_GRETYPE_CHAR_RNG_UPPER:
  291. case LLAMA_GRETYPE_CHAR_ANY:
  292. break;
  293. default:
  294. fprintf(file, "] ");
  295. }
  296. }
  297. }
  298. fprintf(file, "\n");
  299. }
  300. //
  301. // implementation
  302. //
  303. uint32_t llama_grammar_parser::get_symbol_id(const char * src, size_t len) {
  304. uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
  305. auto result = symbol_ids.emplace(std::string(src, len), next_id);
  306. return result.first->second;
  307. }
  308. uint32_t llama_grammar_parser::generate_symbol_id(const std::string & base_name) {
  309. uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
  310. symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
  311. return next_id;
  312. }
  313. void llama_grammar_parser::add_rule(uint32_t rule_id, const llama_grammar_rule & rule) {
  314. if (rules.size() <= rule_id) {
  315. rules.resize(rule_id + 1);
  316. }
  317. rules[rule_id] = rule;
  318. }
  319. const char * llama_grammar_parser::parse_alternates(
  320. const char * src,
  321. const std::string & rule_name,
  322. uint32_t rule_id,
  323. bool is_nested) {
  324. llama_grammar_rule rule;
  325. const char * pos = parse_sequence(src, rule_name, rule, is_nested);
  326. while (*pos == '|') {
  327. rule.push_back({LLAMA_GRETYPE_ALT, 0});
  328. pos = parse_space(pos + 1, true);
  329. pos = parse_sequence(pos, rule_name, rule, is_nested);
  330. }
  331. rule.push_back({LLAMA_GRETYPE_END, 0});
  332. add_rule(rule_id, rule);
  333. return pos;
  334. }
  335. const char * llama_grammar_parser::parse_sequence(
  336. const char * src,
  337. const std::string & rule_name,
  338. llama_grammar_rule & rule,
  339. bool is_nested) {
  340. size_t last_sym_start = rule.size();
  341. const char * pos = src;
  342. auto handle_repetitions = [&](int min_times, int max_times) {
  343. if (last_sym_start == rule.size()) {
  344. throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
  345. }
  346. // apply transformation to previous symbol (last_sym_start to end) according to
  347. // the following rewrite rules:
  348. // S{m,n} --> S S S (m times) S'(n-m)
  349. // S'(x) ::= S S'(x-1) |
  350. // (... n-m definitions of these S' rules ...)
  351. // S'(1) ::= S |
  352. // S{m,} --> S S S (m times) S'
  353. // S' ::= S S' |
  354. // S* --> S{0,}
  355. // --> S' ::= S S' |
  356. // S+ --> S{1,}
  357. // --> S S'
  358. // S' ::= S S' |
  359. // S? --> S{0,1}
  360. // --> S'
  361. // S' ::= S |
  362. llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
  363. if (min_times == 0) {
  364. rule.resize(last_sym_start);
  365. } else {
  366. // Repeat the previous elements (min_times - 1) times
  367. for (int i = 1; i < min_times; i++) {
  368. rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
  369. }
  370. }
  371. uint32_t last_rec_rule_id = 0;
  372. auto n_opt = max_times < 0 ? 1 : max_times - min_times;
  373. llama_grammar_rule rec_rule(prev_rule);
  374. for (int i = 0; i < n_opt; i++) {
  375. rec_rule.resize(prev_rule.size());
  376. uint32_t rec_rule_id = generate_symbol_id( rule_name);
  377. if (i > 0 || max_times < 0) {
  378. rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
  379. }
  380. rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
  381. rec_rule.push_back({LLAMA_GRETYPE_END, 0});
  382. add_rule( rec_rule_id, rec_rule);
  383. last_rec_rule_id = rec_rule_id;
  384. }
  385. if (n_opt > 0) {
  386. rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
  387. }
  388. };
  389. while (*pos) {
  390. if (*pos == '"') { // literal string
  391. pos++;
  392. last_sym_start = rule.size();
  393. while (*pos != '"') {
  394. if (!*pos) {
  395. throw std::runtime_error("unexpected end of input");
  396. }
  397. auto char_pair = parse_char(pos);
  398. pos = char_pair.second;
  399. rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
  400. }
  401. pos = parse_space(pos + 1, is_nested);
  402. } else if (*pos == '[') { // char range(s)
  403. pos++;
  404. enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
  405. if (*pos == '^') {
  406. pos++;
  407. start_type = LLAMA_GRETYPE_CHAR_NOT;
  408. }
  409. last_sym_start = rule.size();
  410. while (*pos != ']') {
  411. if (!*pos) {
  412. throw std::runtime_error("unexpected end of input");
  413. }
  414. auto char_pair = parse_char(pos);
  415. pos = char_pair.second;
  416. enum llama_gretype type = last_sym_start < rule.size()
  417. ? LLAMA_GRETYPE_CHAR_ALT
  418. : start_type;
  419. rule.push_back({type, char_pair.first});
  420. if (pos[0] == '-' && pos[1] != ']') {
  421. if (!pos[1]) {
  422. throw std::runtime_error("unexpected end of input");
  423. }
  424. auto endchar_pair = parse_char(pos + 1);
  425. pos = endchar_pair.second;
  426. rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
  427. }
  428. }
  429. pos = parse_space(pos + 1, is_nested);
  430. } else if (is_word_char(*pos)) { // rule reference
  431. const char * name_end = parse_name(pos);
  432. uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
  433. pos = parse_space(name_end, is_nested);
  434. last_sym_start = rule.size();
  435. rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
  436. } else if (*pos == '(') { // grouping
  437. // parse nested alternates into synthesized rule
  438. pos = parse_space(pos + 1, true);
  439. uint32_t sub_rule_id = generate_symbol_id(rule_name);
  440. pos = parse_alternates(pos, rule_name, sub_rule_id, true);
  441. last_sym_start = rule.size();
  442. // output reference to synthesized rule
  443. rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
  444. if (*pos != ')') {
  445. throw std::runtime_error(std::string("expecting ')' at ") + pos);
  446. }
  447. pos = parse_space(pos + 1, is_nested);
  448. } else if (*pos == '.') { // any char
  449. last_sym_start = rule.size();
  450. rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
  451. pos = parse_space(pos + 1, is_nested);
  452. } else if (*pos == '*') {
  453. pos = parse_space(pos + 1, is_nested);
  454. handle_repetitions(0, -1);
  455. } else if (*pos == '+') {
  456. pos = parse_space(pos + 1, is_nested);
  457. handle_repetitions(1, -1);
  458. } else if (*pos == '?') {
  459. pos = parse_space(pos + 1, is_nested);
  460. handle_repetitions(0, 1);
  461. } else if (*pos == '{') {
  462. pos = parse_space(pos + 1, is_nested);
  463. if (!is_digit_char(*pos)) {
  464. throw std::runtime_error(std::string("expecting an int at ") + pos);
  465. }
  466. const char * int_end = parse_int(pos);
  467. int min_times = std::stoul(std::string(pos, int_end - pos));
  468. pos = parse_space(int_end, is_nested);
  469. int max_times = -1;
  470. if (*pos == '}') {
  471. max_times = min_times;
  472. pos = parse_space(pos + 1, is_nested);
  473. } else if (*pos == ',') {
  474. pos = parse_space(pos + 1, is_nested);
  475. if (is_digit_char(*pos)) {
  476. const char * int_end = parse_int(pos);
  477. max_times = std::stoul(std::string(pos, int_end - pos));
  478. pos = parse_space(int_end, is_nested);
  479. }
  480. if (*pos != '}') {
  481. throw std::runtime_error(std::string("expecting '}' at ") + pos);
  482. }
  483. pos = parse_space(pos + 1, is_nested);
  484. } else {
  485. throw std::runtime_error(std::string("expecting ',' at ") + pos);
  486. }
  487. handle_repetitions(min_times, max_times);
  488. } else {
  489. break;
  490. }
  491. }
  492. return pos;
  493. }
  494. const char * llama_grammar_parser::parse_rule(const char * src) {
  495. const char * name_end = parse_name(src);
  496. const char * pos = parse_space(name_end, false);
  497. size_t name_len = name_end - src;
  498. uint32_t rule_id = get_symbol_id(src, name_len);
  499. const std::string name(src, name_len);
  500. if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
  501. throw std::runtime_error(std::string("expecting ::= at ") + pos);
  502. }
  503. pos = parse_space(pos + 3, true);
  504. pos = parse_alternates(pos, name, rule_id, false);
  505. if (*pos == '\r') {
  506. pos += pos[1] == '\n' ? 2 : 1;
  507. } else if (*pos == '\n') {
  508. pos++;
  509. } else if (*pos) {
  510. throw std::runtime_error(std::string("expecting newline or end at ") + pos);
  511. }
  512. return parse_space(pos, true);
  513. }
  514. bool llama_grammar_parser::parse(const char * src) {
  515. try {
  516. const char * pos = parse_space(src, true);
  517. while (*pos) {
  518. pos = parse_rule(pos);
  519. }
  520. // Validate the state to ensure that all rules are defined
  521. for (const auto & rule : rules) {
  522. if (rule.empty()) {
  523. throw std::runtime_error("Undefined rule");
  524. }
  525. for (const auto & elem : rule) {
  526. if (elem.type == LLAMA_GRETYPE_RULE_REF) {
  527. // Ensure that the rule at that location exists
  528. if (elem.value >= rules.size() || rules[elem.value].empty()) {
  529. // Get the name of the rule that is missing
  530. for (const auto & kv : symbol_ids) {
  531. if (kv.second == elem.value) {
  532. throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
  533. }
  534. }
  535. }
  536. }
  537. }
  538. }
  539. } catch (const std::exception & err) {
  540. fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
  541. rules.clear();
  542. return false;
  543. }
  544. return true;
  545. }
  546. void llama_grammar_parser::print(FILE * file) {
  547. try {
  548. std::map<uint32_t, std::string> symbol_id_names;
  549. for (const auto & kv : symbol_ids) {
  550. symbol_id_names[kv.second] = kv.first;
  551. }
  552. for (size_t i = 0, end = rules.size(); i < end; i++) {
  553. // fprintf(file, "%zu: ", i);
  554. // print_rule_binary(file, rules[i]);
  555. print_rule(file, uint32_t(i), rules[i], symbol_id_names);
  556. // fprintf(file, "\n");
  557. }
  558. } catch (const std::exception & err) {
  559. fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
  560. }
  561. }
  562. llama_grammar_stack llama_grammar_parser::c_rules() const {
  563. llama_grammar_stack ret;
  564. ret.reserve(rules.size());
  565. for (const auto & rule : rules) {
  566. ret.push_back(rule.data());
  567. }
  568. return ret;
  569. }
  570. // returns true iff pos points to the end of one of the definitions of a rule
  571. static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) {
  572. switch (pos->type) {
  573. case LLAMA_GRETYPE_END: return true; // NOLINT
  574. case LLAMA_GRETYPE_ALT: return true; // NOLINT
  575. default: return false;
  576. }
  577. }
  578. // returns true iff chr satisfies the char range at pos (regular or inverse range)
  579. // asserts that pos is pointing to a char range element
  580. static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
  581. const llama_grammar_element * pos,
  582. const uint32_t chr) {
  583. bool found = false;
  584. bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
  585. GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT
  586. do {
  587. if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
  588. // inclusive range, e.g. [a-z]
  589. found = found || (pos->value <= chr && chr <= pos[1].value);
  590. pos += 2;
  591. } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
  592. // Any character matches "."
  593. found = true;
  594. pos += 1;
  595. } else {
  596. // exact char match, e.g. [a] or "a"
  597. found = found || pos->value == chr;
  598. pos += 1;
  599. }
  600. } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
  601. return std::make_pair(found == is_positive_char, pos);
  602. }
  // returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
  // range at pos (regular or inverse range)
  // asserts that pos is pointing to a char range element
  //
  // `partial_utf8` holds the state decode_utf8 leaves behind for an unfinished sequence.
  // `value` is the payload bits decoded so far and `n_remain` the number of
  // continuation bytes still expected; a negative n_remain marks an invalid sequence.
  static bool llama_grammar_match_partial_char(
          const llama_grammar_element * pos,
          const llama_partial_utf8 partial_utf8) {
      bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
      GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);
      uint32_t partial_value = partial_utf8.value;
      int n_remain = partial_utf8.n_remain;
      // invalid sequence or 7-bit char split across 2 bytes (overlong)
      // (a 2-byte lead whose payload is 0 or 1, i.e. 0xC0/0xC1, can only complete to < 0x80)
      if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
          return false;
      }
      // range of possible code points this partial UTF-8 sequence could complete to:
      // each remaining continuation byte contributes 6 unknown low bits
      uint32_t low = partial_value << (n_remain * 6);
      uint32_t high = low | ((1 << (n_remain * 6)) - 1);
      // an all-zero prefix would allow overlong encodings; raise `low` to the
      // smallest code point that actually needs 3 (0x800) or 4 (0x10000) bytes.
      // NOTE(review): only the low == 0 case is clamped here; other overlong or
      // out-of-range (> 0x10FFFF) completions are not excluded.
      if (low == 0) {
          if (n_remain == 2) {
              low = 1 << 11;
          } else if (n_remain == 3) {
              low = 1 << 16;
          }
      }
      // NOTE(review): for CHAR_NOT, any overlap between [low, high] and an excluded
      // char/range returns false, even if other completions would still be allowed.
      // This looks conservative (it may over-reject partial tokens); confirm intended.
      do {
          if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
              // inclusive range, e.g. [a-z]
              if (pos->value <= high && low <= pos[1].value) {
                  return is_positive_char;
              }
              pos += 2;
          } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
              // Any character matches "."
              return true;
          } else {
              // exact char match, e.g. [a] or "a"
              if (low <= pos->value && pos->value <= high) {
                  return is_positive_char;
              }
              pos += 1;
          }
      } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
      // no element of the set intersects the reachable range
      return !is_positive_char;
  }
  // transforms a grammar pushdown stack into N possible stacks, all ending
  // at a character range (terminal element)
  //
  // `stack` holds pointers into `rules`; its back() is the element to expand next.
  // - An empty stack (the grammar has been fully matched) is kept as-is.
  // - A RULE_REF on top is replaced once per alternative of the referenced
  //   rule, and each resulting stack is expanded recursively.
  // - A character element on top is final.
  // Resulting stacks are appended to `new_stacks`, skipping exact duplicates
  // (a linear std::find per candidate).
  // NOTE(review): the recursion only terminates for grammars without left
  // recursion; presumably callers reject those via
  // llama_grammar_detect_left_recursion first — confirm at the call sites.
  static void llama_grammar_advance_stack(
          const llama_grammar_rules & rules,
          const llama_grammar_stack & stack,
          llama_grammar_stacks & new_stacks) {
      if (stack.empty()) {
          if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
              new_stacks.emplace_back(stack);
          }
          return;
      }
      const llama_grammar_element * pos = stack.back();
      switch (pos->type) {
          case LLAMA_GRETYPE_RULE_REF: {
              const size_t rule_id = static_cast<size_t>(pos->value);
              const llama_grammar_element * subpos = rules[rule_id].data();
              // one iteration per alternative of the referenced rule
              do {
                  // init new stack without the top (pos)
                  llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
                  if (!llama_grammar_is_end_of_sequence(pos + 1)) {
                      // if this rule ref is followed by another element, add that to stack
                      new_stack.push_back(pos + 1);
                  }
                  if (!llama_grammar_is_end_of_sequence(subpos)) {
                      // if alternate is nonempty, add to stack
                      new_stack.push_back(subpos);
                  }
                  llama_grammar_advance_stack(rules, new_stack, new_stacks);
                  while (!llama_grammar_is_end_of_sequence(subpos)) {
                      // scan to end of alternate def
                      subpos++;
                  }
                  if (subpos->type == LLAMA_GRETYPE_ALT) {
                      // there's another alternate def of this rule to process
                      subpos++;
                  } else {
                      break;
                  }
              } while (true);
              break;
          }
          case LLAMA_GRETYPE_CHAR:
          case LLAMA_GRETYPE_CHAR_NOT:
          case LLAMA_GRETYPE_CHAR_ANY:
              if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
                  // only add the stack if it's not a duplicate of one we already have
                  new_stacks.emplace_back(stack);
              }
              break;
          default:
              // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
              // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
              // those
              GGML_ABORT("fatal error");
      }
  }
  704. static llama_grammar_candidates llama_grammar_reject_candidates(
  705. const llama_grammar_rules & rules,
  706. const llama_grammar_stacks & stacks,
  707. const llama_grammar_candidates & candidates) {
  708. GGML_ASSERT(!stacks.empty()); // REVIEW
  709. if (candidates.empty()) {
  710. return {};
  711. }
  712. auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
  713. for (size_t i = 1, size = stacks.size(); i < size; ++i) {
  714. rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
  715. }
  716. return rejects;
  717. }
  718. static bool llama_grammar_detect_left_recursion(
  719. const llama_grammar_rules & rules,
  720. size_t rule_index,
  721. std::vector<bool> * rules_visited,
  722. std::vector<bool> * rules_in_progress,
  723. std::vector<bool> * rules_may_be_empty) {
  724. if ((*rules_in_progress)[rule_index]) {
  725. return true;
  726. }
  727. (*rules_in_progress)[rule_index] = true;
  728. const llama_grammar_rule & rule = rules[rule_index];
  729. // First check if the rule might produce the empty string. This could be done combined with the second
  730. // step but it's more readable as two steps.
  731. bool at_rule_start = true;
  732. for (size_t i = 0; i < rule.size(); i++) {
  733. if (llama_grammar_is_end_of_sequence(&rule[i])) {
  734. if (at_rule_start) {
  735. (*rules_may_be_empty)[rule_index] = true;
  736. break;
  737. }
  738. at_rule_start = true;
  739. } else {
  740. at_rule_start = false;
  741. }
  742. }
  743. // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
  744. // be empty)
  745. bool recurse_into_nonterminal = true;
  746. for (size_t i = 0; i < rule.size(); i++) {
  747. if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
  748. if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
  749. return true;
  750. }
  751. if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
  752. recurse_into_nonterminal = false;
  753. }
  754. } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
  755. recurse_into_nonterminal = true;
  756. } else {
  757. recurse_into_nonterminal = false;
  758. }
  759. }
  760. (*rules_in_progress)[rule_index] = false;
  761. (*rules_visited)[rule_index] = true;
  762. return false;
  763. }
  764. const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
  765. return grammar->rules;
  766. }
  767. llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
  768. return grammar->stacks;
  769. }
  770. void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
  771. llama_grammar_stacks stacks_new;
  772. stacks_new.reserve(grammar->stacks.size());
  773. for (const auto & stack : grammar->stacks) {
  774. if (stack.empty()) {
  775. continue;
  776. }
  777. auto match = llama_grammar_match_char(stack.back(), chr);
  778. if (match.first) {
  779. const llama_grammar_element * pos = match.second;
  780. // update top of stack to next element, if any
  781. llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
  782. if (!llama_grammar_is_end_of_sequence(pos)) {
  783. new_stack.push_back(pos);
  784. }
  785. llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new);
  786. }
  787. }
  788. grammar->stacks = std::move(stacks_new);
  789. }
  790. llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
  791. const llama_grammar_rules & rules,
  792. const llama_grammar_stack & stack,
  793. const llama_grammar_candidates & candidates) {
  794. llama_grammar_candidates rejects;
  795. rejects.reserve(candidates.size());
  796. if (stack.empty()) {
  797. for (const auto & tok : candidates) {
  798. if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
  799. rejects.push_back(tok);
  800. }
  801. }
  802. return rejects;
  803. }
  804. const llama_grammar_element * stack_pos = stack.back();
  805. llama_grammar_candidates next_candidates;
  806. next_candidates.reserve(candidates.size());
  807. for (const auto & tok : candidates) {
  808. if (*tok.code_points == 0) {
  809. // reached end of full codepoints in token, reject iff it ended in a partial sequence
  810. // that cannot satisfy this position in grammar
  811. if (tok.partial_utf8.n_remain != 0 &&
  812. !llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
  813. rejects.push_back(tok);
  814. }
  815. } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
  816. next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
  817. } else {
  818. rejects.push_back(tok);
  819. }
  820. }
  821. const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second;
  822. // update top of stack to next element, if any
  823. llama_grammar_stack stack_after(stack.begin(), stack.end() - 1);
  824. if (!llama_grammar_is_end_of_sequence(stack_pos_after)) {
  825. stack_after.push_back(stack_pos_after);
  826. }
  827. llama_grammar_stacks next_stacks;
  828. llama_grammar_advance_stack(rules, stack_after, next_stacks);
  829. auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
  830. for (const auto & tok : next_rejects) {
  831. rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
  832. }
  833. return rejects;
  834. }
  835. ////////////////////
  836. struct llama_grammar * llama_grammar_init_impl(
  837. const struct llama_vocab * vocab,
  838. const llama_grammar_element ** rules,
  839. size_t n_rules,
  840. size_t start_rule_index) {
  841. const llama_grammar_element * pos;
  842. // copy rule definitions into vectors
  843. llama_grammar_rules vec_rules(n_rules);
  844. for (size_t i = 0; i < n_rules; i++) {
  845. for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
  846. vec_rules[i].push_back(*pos);
  847. }
  848. vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
  849. }
  850. // Check for left recursion
  851. std::vector<bool> rules_visited(n_rules);
  852. std::vector<bool> rules_in_progress(n_rules);
  853. std::vector<bool> rules_may_be_empty(n_rules);
  854. for (size_t i = 0; i < n_rules; i++) {
  855. if (rules_visited[i]) {
  856. continue;
  857. }
  858. if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
  859. LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
  860. return nullptr;
  861. }
  862. }
  863. // loop over alternates of start rule to build initial stacks
  864. llama_grammar_stacks stacks;
  865. pos = vec_rules[start_rule_index].data();
  866. do {
  867. llama_grammar_stack stack;
  868. if (!llama_grammar_is_end_of_sequence(pos)) {
  869. // if alternate is nonempty, add to stack
  870. stack.push_back(pos);
  871. }
  872. llama_grammar_advance_stack(vec_rules, stack, stacks);
  873. while (!llama_grammar_is_end_of_sequence(pos)) {
  874. // scan to end of alternate def
  875. pos++;
  876. }
  877. if (pos->type == LLAMA_GRETYPE_ALT) {
  878. // there's another alternate def of this rule to process
  879. pos++;
  880. } else {
  881. break;
  882. }
  883. } while (true);
  884. // Important: vec_rules has to be moved here, not copied, because stacks contains
  885. // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
  886. // then the pointers would be invalidated when the local vec_rules goes out of scope.
  887. return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
  888. }
  889. struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
  890. llama_grammar_parser parser;
  891. // if there is a grammar, parse it
  892. if (!parser.parse(grammar_str)) {
  893. return nullptr;
  894. }
  895. // will be empty (default) if there are parse errors
  896. if (parser.rules.empty()) {
  897. fprintf(stderr, "%s: failed to parse grammar\n", __func__);
  898. return nullptr;
  899. }
  900. // Ensure that there is a "root" node.
  901. if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) {
  902. fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
  903. return nullptr;
  904. }
  905. std::vector<const llama_grammar_element *> grammar_rules(parser.c_rules());
  906. const size_t n_rules = grammar_rules.size();
  907. const size_t start_rule_index = parser.symbol_ids.at(grammar_root);
  908. const llama_grammar_element * pos;
  909. // copy rule definitions into vectors
  910. llama_grammar_rules vec_rules(n_rules);
  911. for (size_t i = 0; i < n_rules; i++) {
  912. for (pos = grammar_rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
  913. vec_rules[i].push_back(*pos);
  914. }
  915. vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
  916. }
  917. // Check for left recursion
  918. std::vector<bool> rules_visited(n_rules);
  919. std::vector<bool> rules_in_progress(n_rules);
  920. std::vector<bool> rules_may_be_empty(n_rules);
  921. for (size_t i = 0; i < n_rules; i++) {
  922. if (rules_visited[i]) {
  923. continue;
  924. }
  925. if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
  926. LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
  927. return nullptr;
  928. }
  929. }
  930. // loop over alternates of start rule to build initial stacks
  931. llama_grammar_stacks stacks;
  932. pos = vec_rules[start_rule_index].data();
  933. do {
  934. llama_grammar_stack stack;
  935. if (!llama_grammar_is_end_of_sequence(pos)) {
  936. // if alternate is nonempty, add to stack
  937. stack.push_back(pos);
  938. }
  939. llama_grammar_advance_stack(vec_rules, stack, stacks);
  940. while (!llama_grammar_is_end_of_sequence(pos)) {
  941. // scan to end of alternate def
  942. pos++;
  943. }
  944. if (pos->type == LLAMA_GRETYPE_ALT) {
  945. // there's another alternate def of this rule to process
  946. pos++;
  947. } else {
  948. break;
  949. }
  950. } while (true);
  951. // Important: vec_rules has to be moved here, not copied, because stacks contains
  952. // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
  953. // then the pointers would be invalidated when the local vec_rules goes out of scope.
  954. return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
  955. }
  // Release a grammar allocated by llama_grammar_init_impl or llama_grammar_clone_impl.
  // A nullptr argument is accepted: deleting a null pointer has no effect.
  void llama_grammar_free_impl(struct llama_grammar * grammar) {
      delete grammar;
  }
  962. struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
  963. llama_grammar * result = new llama_grammar {
  964. grammar.vocab,
  965. grammar.rules,
  966. grammar.stacks,
  967. grammar.partial_utf8,
  968. };
  969. // redirect elements in stacks to point to new rules
  970. for (size_t is = 0; is < result->stacks.size(); is++) {
  971. for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
  972. for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
  973. for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
  974. if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
  975. result->stacks[is][ie] = &result->rules[ir0][ir1];
  976. }
  977. }
  978. }
  979. }
  980. }
  981. return result;
  982. }
  983. void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
  984. GGML_ASSERT(grammar.vocab != nullptr);
  985. bool allow_eog = false;
  986. for (const auto & stack : grammar.stacks) {
  987. if (stack.empty()) {
  988. allow_eog = true;
  989. break;
  990. }
  991. }
  992. std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
  993. candidates_decoded.reserve(cur_p->size);
  994. llama_grammar_candidates candidates_grammar;
  995. candidates_grammar.reserve(cur_p->size);
  996. for (size_t i = 0; i < cur_p->size; ++i) {
  997. const llama_token id = cur_p->data[i].id;
  998. const std::string & piece = grammar.vocab->cache_token_to_piece.at(id);
  999. if (llama_token_is_eog_impl(*grammar.vocab, id)) {
  1000. if (!allow_eog) {
  1001. cur_p->data[i].logit = -INFINITY;
  1002. }
  1003. } else if (piece.empty() || piece[0] == 0) {
  1004. cur_p->data[i].logit = -INFINITY;
  1005. } else {
  1006. candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
  1007. candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
  1008. }
  1009. }
  1010. const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
  1011. for (const auto & reject : rejects) {
  1012. cur_p->data[reject.index].logit = -INFINITY;
  1013. }
  1014. }
  1015. void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
  1016. GGML_ASSERT(grammar.vocab != nullptr);
  1017. if (llama_token_is_eog_impl(*grammar.vocab, token)) {
  1018. for (const auto & stack : grammar.stacks) {
  1019. if (stack.empty()) {
  1020. return;
  1021. }
  1022. }
  1023. GGML_ABORT("fatal error");
  1024. }
  1025. const std::string & piece = grammar.vocab->cache_token_to_piece.at(token);
  1026. // Note terminating 0 in decoded string
  1027. const auto decoded = decode_utf8(piece, grammar.partial_utf8);
  1028. const auto & code_points = decoded.first;
  1029. for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
  1030. llama_grammar_accept(&grammar, *it);
  1031. }
  1032. grammar.partial_utf8 = decoded.second;
  1033. GGML_ASSERT(!grammar.stacks.empty());
  1034. }