// llama-grammar.h
  /**
   * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
   *
   * MIT License
   *
   * Copyright (c) 2023-2024 The ggml authors
   *
   * Permission is hereby granted, free of charge, to any person obtaining a copy
   * of this software and associated documentation files (the "Software"), to deal
   * in the Software without restriction, including without limitation the rights
   * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
   * copies of the Software, and to permit persons to whom the Software is
   * furnished to do so, subject to the following conditions:
   *
   * The above copyright notice and this permission notice shall be included in all
   * copies or substantial portions of the Software.
   *
   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
   * SOFTWARE.
   */
  #pragma once

  #include "llama-impl.h"

  #include <cstddef>
  #include <cstdint>
  #include <cstdio>
  #include <map>
  #include <string>
  #include <vector>

  struct llama_vocab;
  // grammar element type
  //
  // Tag stored in the `type` field of every llama_grammar_element. The numeric
  // value of each enumerator is assigned explicitly.
  enum llama_gretype {
      // end of rule definition
      LLAMA_GRETYPE_END            = 0,

      // start of alternate definition for rule
      LLAMA_GRETYPE_ALT            = 1,

      // non-terminal element: reference to rule
      LLAMA_GRETYPE_RULE_REF       = 2,

      // terminal element: character (code point)
      LLAMA_GRETYPE_CHAR           = 3,

      // inverse char(s) ([^a], [^a-b], [^abc])
      LLAMA_GRETYPE_CHAR_NOT       = 4,

      // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
      // be an inclusive range ([a-z])
      LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,

      // modifies a preceding LLAMA_GRETYPE_CHAR or
      // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
      LLAMA_GRETYPE_CHAR_ALT       = 6,

      // any character (.)
      LLAMA_GRETYPE_CHAR_ANY       = 7,
  };
  51. typedef struct llama_grammar_element {
  52. enum llama_gretype type;
  53. uint32_t value; // Unicode code point or rule ID
  54. } llama_grammar_element;
  // Decoding state of a UTF-8 sequence that has not been fully consumed yet.
  struct llama_partial_utf8 {
      uint32_t value;    // bit value so far (unshifted)
      int      n_remain; // num bytes remaining; -1 indicates invalid sequence
  };
  59. struct llama_grammar_candidate {
  60. size_t index;
  61. const uint32_t * code_points;
  62. llama_partial_utf8 partial_utf8;
  63. };
  64. using llama_grammar_rule = std::vector< llama_grammar_element>;
  65. using llama_grammar_stack = std::vector<const llama_grammar_element *>;
  66. using llama_grammar_rules = std::vector<llama_grammar_rule>;
  67. using llama_grammar_stacks = std::vector<llama_grammar_stack>;
  68. using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
  69. const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
  70. llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
  71. // takes a set of possible pushdown stacks on a grammar, which are required to
  72. // be positioned at a character range (see `llama_grammar_advance_stack`), and
  73. // produces the N possible stacks if the given char is accepted at those
  74. // positions
  75. void llama_grammar_accept(
  76. const llama_grammar_rules & rules,
  77. const llama_grammar_stacks & stacks,
  78. uint32_t chr,
  79. llama_grammar_stacks & stacks_new);
  80. std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
  81. const llama_grammar_rules & rules,
  82. const llama_grammar_stack & stack,
  83. const llama_grammar_candidates & candidates);
  84. struct llama_grammar_parser {
  85. std::map<std::string, uint32_t> symbol_ids;
  86. llama_grammar_rules rules;
  87. llama_grammar_stack c_rules() const;
  88. uint32_t get_symbol_id(const char * src, size_t len);
  89. uint32_t generate_symbol_id(const std::string & base_name);
  90. void add_rule(uint32_t rule_id, const llama_grammar_rule & rule);
  91. const char * parse_alternates(
  92. const char * src,
  93. const std::string & rule_name,
  94. uint32_t rule_id,
  95. bool is_nested);
  96. const char * parse_sequence(
  97. const char * src,
  98. const std::string & rule_name,
  99. llama_grammar_rule & rule,
  100. bool is_nested);
  101. const char * parse_rule(const char * src);
  102. bool parse(const char * src);
  103. void print(FILE * file);
  104. };
  105. struct llama_grammar {
  106. // note: allow null vocab for testing (not great)
  107. const llama_vocab * vocab;
  108. const llama_grammar_rules rules; // TODO: shared ptr
  109. llama_grammar_stacks stacks;
  110. // buffer for partially generated UTF-8 sequence from accepted tokens
  111. llama_partial_utf8 partial_utf8;
  112. };
  113. //
  114. // internal API
  115. //
  116. // note: needed for tests (not great)
  117. struct llama_grammar * llama_grammar_init_impl(
  118. const struct llama_vocab * vocab,
  119. const llama_grammar_element ** rules,
  120. size_t n_rules,
  121. size_t start_rule_index);
  122. struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root);
  // NOTE(review): judging by the name, releases a grammar obtained from
  // llama_grammar_init_impl / llama_grammar_clone_impl - confirm
  void llama_grammar_free_impl(struct llama_grammar * grammar);
  // Takes the source grammar by const reference and returns a grammar pointer.
  // NOTE(review): presumably a newly allocated copy owned by the caller - confirm
  struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar);
  125. // TODO: move the API below as member functions of llama_grammar
  126. void llama_grammar_apply_impl(
  127. const struct llama_grammar & grammar,
  128. llama_token_data_array * cur_p);
  129. void llama_grammar_accept_impl(
  130. struct llama_grammar & grammar,
  131. llama_token token);