llama-kv-cache.h

/**
 * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
 *
 * MIT License
 *
 * Copyright (c) 2023-2024 The ggml authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#pragma once

#include "llama.h"
#include "ggml-cpp.h"

#include <algorithm> // for std::max in llama_kv_cache::max_pos()
#include <set>
#include <vector>

struct llama_kv_cell {
    llama_pos pos   = -1;
    llama_pos delta =  0;
    int32_t   src   = -1; // used by recurrent state models to copy states
    int32_t   tail  = -1;

    std::set<llama_seq_id> seq_id;

    bool has_seq_id(const llama_seq_id & id) const {
        return seq_id.find(id) != seq_id.end();
    }

    bool is_empty() const {
        return seq_id.empty();
    }

    bool is_same_seq(const llama_kv_cell & other) const {
        return seq_id == other.seq_id;
    }
};
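
// Usage sketch (illustrative only, not part of the original header): a cell is
// free when no sequence references it; sequence membership is checked through
// has_seq_id(). The position and seq_id values below are hypothetical.
//
//     llama_kv_cell cell;
//     cell.pos = 42;
//     cell.seq_id.insert(0);
//
//     const bool owned = cell.has_seq_id(0); // true
//     const bool free  = cell.is_empty();    // false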

// ring-buffer of cached KV data
struct llama_kv_cache {
    bool has_shift = false;
    bool do_defrag = false;
    bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
    bool v_trans   = true;  // the value tensor is transposed
    bool can_shift = false;

    // Note: The value of head isn't only used to optimize searching
    // for a free KV slot. llama_decode_internal also uses it, so it
    // cannot be freely changed after a slot has been allocated.
    uint32_t head = 0;
    uint32_t size = 0;
    uint32_t used = 0; // used cells (i.e. at least one seq_id)

    // computed before each graph build
    uint32_t n = 0;

    ggml_type type_k = GGML_TYPE_F16;
    ggml_type type_v = GGML_TYPE_F16;

    std::vector<llama_kv_cell> cells;

    std::vector<struct ggml_tensor *> k_l; // per layer
    std::vector<struct ggml_tensor *> v_l;

    std::vector<ggml_context_ptr>        ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;

    size_t total_size() const {
        size_t size = 0;
        for (const auto & buf : bufs) {
            size += ggml_backend_buffer_get_size(buf.get());
        }
        return size;
    }

    // TODO: better data structures to reduce the cost of this operation
    llama_pos max_pos() const {
        llama_pos max_pos = -1;
        for (const auto & cell : cells) {
            max_pos = std::max(max_pos, cell.pos);
        }
        return max_pos;
    }
};
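
// Usage sketch (illustrative only): once a cache has been set up with
// llama_kv_cache_init(), the helpers report the total backend buffer size and
// the highest position stored so far. The cache reference is assumed to be
// owned by the surrounding llama_context.
//
//     const llama_kv_cache & cache = ...;
//
//     const size_t    bytes = cache.total_size(); // sum over all backend buffers
//     const llama_pos p_max = cache.max_pos();    // -1 when the cache is empty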

// a structure that holds information about the slot found in llama_kv_cache_find_slot
struct llama_kv_cache_slot_info {
    std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
    bool found = false;                       // the slot was found

    explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
    llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}

    operator bool() const { return found; }
};

// TODO: maybe not needed
uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams);

bool llama_kv_cache_init(
        struct llama_kv_cache & cache,
        const llama_model & model,
        const llama_cparams & cparams,
        ggml_type type_k,
        ggml_type type_v,
        uint32_t kv_size,
        bool offload);
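
// Usage sketch (illustrative only): initializing a cache for a hypothetical
// context of 4096 cells with F16 K/V types and offloading enabled. The model
// and cparams objects are assumed to come from an already loaded context.
//
//     llama_kv_cache cache;
//     if (!llama_kv_cache_init(cache, model, cparams,
//                              GGML_TYPE_F16, GGML_TYPE_F16,
//                              /*kv_size*/ 4096, /*offload*/ true)) {
//         // initialization failed, e.g. buffer allocation error
//     }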

// find an empty slot of size "n_tokens" in the cache
// updates the cache head
// returns a structure holding information about the slot found
// Note: On success, it's important that cache.head points
// to the first cell of the slot.
struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
        struct llama_kv_cache & cache,
        const struct llama_ubatch & batch);
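
// Usage sketch (illustrative only): the returned slot info converts to bool,
// so a failed search can be detected directly. The ubatch below is assumed to
// come from the batch-splitting logic elsewhere in llama.cpp.
//
//     const auto slot = llama_kv_cache_find_slot(cache, ubatch);
//     if (!slot) {
//         // no contiguous run of free cells large enough for the ubatch
//     } else {
//         // slot.boundaries is the [begin, end) cell range; cache.head now
//         // points at the first cell of the slot
//     }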

// find how many cells are currently in use
uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache);

void llama_kv_cache_clear(struct llama_kv_cache & cache);

bool llama_kv_cache_seq_rm(
        struct llama_kv_cache & cache,
        llama_seq_id seq_id,
        llama_pos p0,
        llama_pos p1);

void llama_kv_cache_seq_cp(
        struct llama_kv_cache & cache,
        llama_seq_id seq_id_src,
        llama_seq_id seq_id_dst,
        llama_pos p0,
        llama_pos p1);

void llama_kv_cache_seq_keep(
        struct llama_kv_cache & cache,
        llama_seq_id seq_id);

void llama_kv_cache_seq_add(
        struct llama_kv_cache & cache,
        llama_seq_id seq_id,
        llama_pos p0,
        llama_pos p1,
        llama_pos delta);

void llama_kv_cache_seq_div(
        struct llama_kv_cache & cache,
        llama_seq_id seq_id,
        llama_pos p0,
        llama_pos p1,
        int d);

llama_pos llama_kv_cache_seq_pos_max(
        struct llama_kv_cache & cache,
        llama_seq_id seq_id);
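
// Usage sketch (illustrative only): forking sequence 0 into sequence 1, then
// dropping everything at or beyond a hypothetical position 128 from the
// original sequence. Negative p0/p1 are treated as "from the start" / "to the
// end" of the sequence.
//
//     llama_kv_cache_seq_cp(cache, /*seq_id_src*/ 0, /*seq_id_dst*/ 1, -1, -1);
//     llama_kv_cache_seq_rm(cache, /*seq_id*/ 0, /*p0*/ 128, /*p1*/ -1);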

void llama_kv_cache_defrag(struct llama_kv_cache & cache);

int32_t llama_get_kv_cache_token_count(const struct llama_kv_cache & kv);

int32_t llama_get_kv_cache_used_cells(const struct llama_kv_cache & kv);

bool llama_kv_cache_can_shift(const struct llama_kv_cache & kv);

//
// kv cache view
//

struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max);

void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv);
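
// Usage sketch (illustrative only): llama_kv_cache_view is declared in llama.h;
// a view is initialized once and refreshed after decoding to inspect cell
// occupancy. Releasing it is assumed to go through llama_kv_cache_view_free()
// from the public API in llama.h.
//
//     struct llama_kv_cache_view view = llama_kv_cache_view_init(cache, /*n_seq_max*/ 4);
//     llama_kv_cache_view_update(&view, cache);
//     // ... inspect the view's cell occupancy counters ...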

//
// kv cache restore
//

// saves the kv_cache state for future recovery
// used to rollback llama_kv_cache_find_slot changes
struct llama_kv_slot_restorer {
    struct llama_kv_cache_state {
        uint32_t head = 0;
        uint32_t n    = 0;
    } old_state;

    // for non-recurrent models only
    // list of slots to restore
    std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;

    bool do_restore = false;

    explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
        old_state.head = cache.head;
        old_state.n    = cache.n;
    }

    // saves the information of a slot for future restoration
    void save(const struct llama_kv_cache_slot_info & slot) {
        if (slot) {
            do_restore = true;
            if (slot.boundaries.first != slot.boundaries.second) {
                slot_boundaries.push_back(slot.boundaries);
            }
        }
    }

    // must be explicitly called to restore the kv_cache state
    // and rollback changes from all llama_kv_cache_find_slot calls
    void restore(struct llama_kv_cache & cache) {
        if (do_restore) {
            cache.head = old_state.head;
            cache.n    = old_state.n;

            if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
                llama_kv_cache_seq_rm(cache, -1, -1, -1);
            } else {
                for (auto & slot : slot_boundaries) {
                    llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second);
                }
            }
        }
    }
};
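
// Usage sketch (illustrative only): the intended pattern is to snapshot the
// cache state before decoding, save each slot returned by
// llama_kv_cache_find_slot(), and call restore() if a later step fails so the
// cells claimed for the batch are released again. The failure flag below is
// hypothetical.
//
//     llama_kv_slot_restorer restorer(cache);
//
//     const auto slot = llama_kv_cache_find_slot(cache, ubatch);
//     if (slot) {
//         restorer.save(slot);
//     }
//
//     if (decode_failed) {
//         restorer.restore(cache); // rolls back head/n and frees the saved slots
//     }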