// llama-grammar.h (5.7 KB)
/**
 * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
 *
 * MIT License
 *
 * Copyright (c) 2023-2024 The ggml authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#pragma once

#include "llama.h"

#include <cstddef> // size_t
#include <cstdint> // uint32_t
#include <cstdio>  // FILE
#include <map>
#include <string>
#include <vector>

struct llama_vocab;
  32. // grammar element type
  33. enum llama_gretype {
  34. // end of rule definition
  35. LLAMA_GRETYPE_END = 0,
  36. // start of alternate definition for rule
  37. LLAMA_GRETYPE_ALT = 1,
  38. // non-terminal element: reference to rule
  39. LLAMA_GRETYPE_RULE_REF = 2,
  40. // terminal element: character (code point)
  41. LLAMA_GRETYPE_CHAR = 3,
  42. // inverse char(s) ([^a], [^a-b] [^abc])
  43. LLAMA_GRETYPE_CHAR_NOT = 4,
  44. // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
  45. // be an inclusive range ([a-z])
  46. LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
  47. // modifies a preceding LLAMA_GRETYPE_CHAR or
  48. // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
  49. LLAMA_GRETYPE_CHAR_ALT = 6,
  50. // any character (.)
  51. LLAMA_GRETYPE_CHAR_ANY = 7,
  52. };
  53. typedef struct llama_grammar_element {
  54. enum llama_gretype type;
  55. uint32_t value; // Unicode code point or rule ID
  56. } llama_grammar_element;
  57. struct llama_partial_utf8 {
  58. uint32_t value; // bit value so far (unshifted)
  59. int n_remain; // num bytes remaining; -1 indicates invalid sequence
  60. };
  61. struct llama_grammar_candidate {
  62. size_t index;
  63. const uint32_t * code_points;
  64. llama_partial_utf8 partial_utf8;
  65. };
  66. using llama_grammar_rule = std::vector< llama_grammar_element>;
  67. using llama_grammar_stack = std::vector<const llama_grammar_element *>;
  68. using llama_grammar_rules = std::vector<llama_grammar_rule>;
  69. using llama_grammar_stacks = std::vector<llama_grammar_stack>;
  70. using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
  71. // TODO: remove, needed for tests atm
  72. const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
  73. llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
  74. // takes a set of possible pushdown stacks on a grammar, which are required to
  75. // be positioned at a character range (see `llama_grammar_advance_stack`), and
  76. // produces the N possible stacks if the given char is accepted at those
  77. // positions
  78. void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr);
  79. std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
  80. const llama_grammar_rules & rules,
  81. const llama_grammar_stack & stack,
  82. const llama_grammar_candidates & candidates);
  83. struct llama_grammar_parser {
  84. std::map<std::string, uint32_t> symbol_ids;
  85. llama_grammar_rules rules;
  86. llama_grammar_stack c_rules() const;
  87. uint32_t get_symbol_id(const char * src, size_t len);
  88. uint32_t generate_symbol_id(const std::string & base_name);
  89. void add_rule(uint32_t rule_id, const llama_grammar_rule & rule);
  90. const char * parse_alternates(
  91. const char * src,
  92. const std::string & rule_name,
  93. uint32_t rule_id,
  94. bool is_nested);
  95. const char * parse_sequence(
  96. const char * src,
  97. const std::string & rule_name,
  98. llama_grammar_rule & rule,
  99. bool is_nested);
  100. const char * parse_rule(const char * src);
  101. bool parse(const char * src);
  102. void print(FILE * file);
  103. };
  104. struct llama_grammar {
  105. // note: allow null vocab for testing (not great)
  106. const llama_vocab * vocab;
  107. const llama_grammar_rules rules; // TODO: shared ptr
  108. llama_grammar_stacks stacks;
  109. // buffer for partially generated UTF-8 sequence from accepted tokens
  110. llama_partial_utf8 partial_utf8;
  111. };
  112. //
  113. // internal API
  114. //
  115. // note: needed for tests (not great)
  116. struct llama_grammar * llama_grammar_init_impl(
  117. const struct llama_vocab * vocab,
  118. const llama_grammar_element ** rules,
  119. size_t n_rules,
  120. size_t start_rule_index);
  121. struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root);
  122. void llama_grammar_free_impl(struct llama_grammar * grammar);
  123. struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar);
  124. // TODO: move the API below as member functions of llama_grammar
  125. void llama_grammar_apply_impl(
  126. const struct llama_grammar & grammar,
  127. llama_token_data_array * cur_p);
  128. void llama_grammar_accept_impl(
  129. struct llama_grammar & grammar,
  130. llama_token token);