llama-vocab.h 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. #pragma once
  2. #include "llama-impl.h"
  3. #include <string>
  4. #include <vector>
  5. #include <unordered_map>
  6. #include <map>
  7. #include <set>
  8. struct llm_tokenizer;
  9. struct llama_vocab {
  10. using id = llama_token;
  11. using token = std::string;
  12. using tattr = llama_token_attr;
  13. struct token_data {
  14. token text;
  15. float score;
  16. tattr attr;
  17. };
  18. uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab
  19. enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
  20. enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
  21. int max_token_len = 0; // used for optimizing longest token search
  22. std::unordered_map<token, id> token_to_id;
  23. std::vector<token_data> id_to_token;
  24. std::vector<id> cache_special_tokens;
  25. std::vector<token> cache_token_to_piece; // llama_token_to_piece(special = true);
  26. std::map<std::pair<std::string, std::string>, int> bpe_ranks;
  27. // default LLaMA special tokens
  28. // TODO: should we set all of these to LLAMA_TOKEN_NULL?
  29. id special_bos_id = 1;
  30. id special_eos_id = 2;
  31. id special_eot_id = LLAMA_TOKEN_NULL;
  32. id special_eom_id = LLAMA_TOKEN_NULL;
  33. id special_unk_id = 0;
  34. id special_sep_id = LLAMA_TOKEN_NULL;
  35. id special_pad_id = LLAMA_TOKEN_NULL;
  36. id special_cls_id = LLAMA_TOKEN_NULL;
  37. id special_mask_id = LLAMA_TOKEN_NULL;
  38. id linefeed_id = 13;
  39. // fim tokens
  40. id special_fim_pre_id = LLAMA_TOKEN_NULL;
  41. id special_fim_suf_id = LLAMA_TOKEN_NULL;
  42. id special_fim_mid_id = LLAMA_TOKEN_NULL;
  43. id special_fim_pad_id = LLAMA_TOKEN_NULL;
  44. id special_fim_rep_id = LLAMA_TOKEN_NULL; // repo
  45. id special_fim_sep_id = LLAMA_TOKEN_NULL; // file separator
  46. // set of all tokens that cause "end of generation"
  47. std::set<id> special_eog_ids;
  48. // tokenizer flags
  49. bool tokenizer_add_space_prefix = false;
  50. bool tokenizer_add_bos = false;
  51. bool tokenizer_add_eos = false;
  52. bool tokenizer_ignore_merges = false;
  53. bool tokenizer_clean_spaces = false; // clean_up_tokenization_spaces
  54. bool tokenizer_remove_extra_whitespaces = false;
  55. bool tokenizer_escape_whitespaces = true;
  56. bool tokenizer_treat_whitespace_as_suffix = false;
  57. std::vector<char> precompiled_charsmap;
  58. llm_tokenizer * tokenizer = nullptr;
  59. llama_vocab() = default;
  60. ~llama_vocab();
  61. int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
  62. void init_tokenizer();
  63. };
  64. //
  65. // internal API
  66. //
  67. // TODO: rename to llama_tokenize_impl
  68. // TODO: This should probably be in llama.h
  69. std::vector<llama_vocab::id> llama_tokenize_internal(
  70. const llama_vocab & vocab,
  71. std::string raw_text,
  72. bool add_special,
  73. bool parse_special = false);
  74. // TODO: move the API below as member functions of llama_vocab
  75. llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
  76. const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
  77. float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token);
  78. llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token);
  79. bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token);
  80. bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token);
  81. llama_token llama_token_bos_impl(const struct llama_vocab & vocab);
  82. llama_token llama_token_eos_impl(const struct llama_vocab & vocab);
  83. llama_token llama_token_eot_impl(const struct llama_vocab & vocab);
  84. llama_token llama_token_eom_impl(const struct llama_vocab & vocab);
  85. llama_token llama_token_cls_impl(const struct llama_vocab & vocab);
  86. llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
  87. llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
  88. llama_token llama_token_pad_impl(const struct llama_vocab & vocab);
  89. llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
  90. llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
  91. llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
  92. llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab);
  93. llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab);
  94. llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab);
  95. llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab);
  96. llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab);
  97. llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab);
  98. bool llama_add_bos_token_impl(const struct llama_vocab & vocab);
  99. bool llama_add_eos_token_impl(const struct llama_vocab & vocab);
  100. int32_t llama_tokenize_impl(
  101. const struct llama_vocab & vocab,
  102. const char * text,
  103. int32_t text_len,
  104. llama_token * tokens,
  105. int32_t n_tokens_max,
  106. bool add_special,
  107. bool parse_special);
  108. // does not write null-terminator to buf
  109. int32_t llama_token_to_piece_impl(
  110. const struct llama_vocab & vocab,
  111. llama_token token,
  112. char * buf,
  113. int32_t length,
  114. int32_t lstrip,
  115. bool special);
  116. // check if token0 is contained as a prefix in token1
  117. bool llama_token_is_prefix_impl(
  118. const struct llama_vocab & vocab,
  119. llama_token token0,
  120. llama_token token1);
  121. int32_t llama_detokenize_impl(
  122. const struct llama_vocab & vocab,
  123. const llama_token * tokens,
  124. int32_t n_tokens,
  125. char * text,
  126. int32_t text_len_max,
  127. bool remove_special,
  128. bool unparse_special);
  129. std::string llama_detokenize(
  130. const struct llama_vocab & vocab,
  131. const std::vector<llama_token> & tokens,
  132. bool special);