// llama-kv-cache.cpp

/**
 * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
 *
 * MIT License
 *
 * Copyright (c) 2023-2024 The ggml authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "llama-kv-cache.h"

#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-cparams.h"
#include "llama-model.h"

#include <algorithm>
#include <limits>
#include <map>

static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};

uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
    // the FA kernels require padding to avoid extra runtime boundary checks
    return cparams.flash_attn ? 256u : 32u;
}
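
// initialize the KV cache: one K and one V tensor per layer, allocated in
// backend buffers (device buffers when offload is true, otherwise CPU);
// returns false if a ggml context or a backend buffer could not be created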
bool llama_kv_cache_init(
        struct llama_kv_cache & cache,
        const llama_model     & model,
        const llama_cparams   & cparams,
        ggml_type               type_k,
        ggml_type               type_v,
        uint32_t                kv_size,
        bool                    offload) {
    const struct llama_hparams & hparams = model.hparams;

    const int32_t n_layer = hparams.n_layer;

    cache.has_shift = false;

    cache.recurrent = llama_model_is_recurrent(&model);
    cache.v_trans   = !cache.recurrent && !cparams.flash_attn;
    cache.can_shift = !cache.recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA

    LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
            __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, cache.can_shift);

    cache.head = 0;
    cache.size = kv_size;
    cache.used = 0;

    cache.type_k = type_k;
    cache.type_v = type_v;

    cache.cells.clear();
    cache.cells.resize(kv_size);

    // create a context for each buffer type
    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
        auto it = ctx_map.find(buft);
        if (it == ctx_map.end()) {
            struct ggml_init_params params = {
                /*.mem_size   =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
                /*.mem_buffer =*/ NULL,
                /*.no_alloc   =*/ true,
            };
            ggml_context * ctx = ggml_init(params);
            if (!ctx) {
                return nullptr;
            }
            ctx_map[buft] = ctx;
            cache.ctxs.emplace_back(ctx);
            return ctx;
        }
        return it->second;
    };

    cache.k_l.reserve(n_layer);
    cache.v_l.reserve(n_layer);

    for (int i = 0; i < n_layer; i++) {
        // for cross attention layers
        if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layers(i)) {
            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
            const llama_model::buft_list_t * buft_list;
            if (offload) {
                buft_list = model.dev_layer.at(i).buft_list;
            } else {
                buft_list = &model.cpu_buft_list;
            }
            ggml_backend_buffer_type_t buft = select_buft(*buft_list,
                [&](ggml_context * ctx) {
                    ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
                    if (hparams.rope_type == LLAMA_ROPE_TYPE_NONE) {
                        return k;
                    }
                    ggml_tensor * p = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
                    return ggml_rope(ctx, k, p, hparams.n_rot, hparams.rope_type);
                });
            ggml_context * ctx = ctx_for_buft(buft);
            if (!ctx) {
                LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
                return false;
            }
            ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
            ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
            ggml_format_name(k, "cache_k_l%d", i);
            ggml_format_name(v, "cache_v_l%d", i);
            cache.k_l.push_back(k);
            cache.v_l.push_back(v);
            continue;
        }

        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();

        LLAMA_LOG_DEBUG("%s: layer %d: n_embd_k_gqa = %d, n_embd_v_gqa = %d\n", __func__, i, n_embd_k_gqa, n_embd_v_gqa);

        ggml_backend_buffer_type_t buft;
        if (offload) {
            auto * dev = model.dev_layer.at(i).dev;
            buft = ggml_backend_dev_buffer_type(dev);
        } else {
            buft = ggml_backend_cpu_buffer_type();
        }
        ggml_context * ctx = ctx_for_buft(buft);
        if (!ctx) {
            LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
            return false;
        }

        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
        ggml_format_name(k, "cache_k_l%d", i);
        ggml_format_name(v, "cache_v_l%d", i);
        cache.k_l.push_back(k);
        cache.v_l.push_back(v);
    }

    // allocate tensors and initialize the buffers to avoid NaNs in the padding
    for (auto it : ctx_map) {
        auto * buft = it.first;
        auto * ctx  = it.second;

        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
        if (!buf) {
            LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__);
            return false;
        }
        ggml_backend_buffer_clear(buf, 0);
        LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
        cache.bufs.emplace_back(buf);
    }

    return true;
}
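
// find a slot for the ubatch in the cache and update cache.head/cache.n.
// for recurrent models, one cell per sequence is used and cells are re-ordered
// so that the used range is contiguous; otherwise a contiguous run of
// n_tokens empty cells is searched for. returns a failed slot_info on error.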
struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
        struct llama_kv_cache     & cache,
        const struct llama_ubatch & batch) {
    const uint32_t n_tokens     = batch.n_tokens;
    const uint32_t n_seqs       = batch.n_seqs;
    const uint32_t n_seq_tokens = batch.n_seq_tokens;

    if (cache.recurrent) {
        // For recurrent state architectures (like Mamba or RWKV),
        // each cache cell can store the state for a whole sequence.
        // A slot should always be contiguous.

        // can only process batches with an equal number of new tokens in each sequence
        GGML_ASSERT(batch.equal_seqs);

        int32_t min = cache.size - 1;
        int32_t max = 0;

        // everything should fit if all seq_ids are smaller than the max
        for (uint32_t s = 0; s < n_seqs; ++s) {
            const uint32_t n_seq_id = batch.n_seq_id[s];
            for (uint32_t j = 0; j < n_seq_id; ++j) {
                const llama_seq_id seq_id = batch.seq_id[s][j];

                if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
                    // too big seq_id
                    // TODO: would it be possible to resize the cache instead?
                    LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
                    return llama_kv_cache_slot_info_failed;
                }
                if (j > 0) {
                    llama_kv_cell & seq = cache.cells[seq_id];
                    if (seq.tail >= 0) {
                        llama_kv_cell & cell = cache.cells[seq.tail];
                        // clear cells from seq_ids that become shared
                        // (should not normally happen, but let's handle it anyway)
                        cell.seq_id.erase(seq_id);
                        seq.tail = -1;
                        if (cell.seq_id.empty()) {
                            cell.pos = -1;
                            cell.src = -1;
                            cache.used -= 1;
                        }
                    }
                }
            }
        }

#ifndef NDEBUG
        {
            std::vector<int32_t> tails_verif;
            tails_verif.assign(cache.size, -1);
            for (uint32_t i = 0; i < cache.size; ++i) {
                llama_kv_cell & cell = cache.cells[i];
                for (llama_seq_id seq_id : cell.seq_id) {
                    if (tails_verif[seq_id] != -1) {
                        LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
                    }
                    tails_verif[seq_id] = i;
                }
            }
            for (uint32_t i = 0; i < cache.size; ++i) {
                if (tails_verif[i] != cache.cells[i].tail) {
                    LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cache.cells[i].tail, tails_verif[i]);
                }
            }
        }
#endif

        // find next empty cell
        uint32_t next_empty_cell = cache.head;

        for (uint32_t i = 0; i < cache.size; ++i) {
            if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
            llama_kv_cell & cell = cache.cells[next_empty_cell];
            if (cell.is_empty()) { break; }
            next_empty_cell += 1;
        }

        // find usable cell range
        for (uint32_t s = 0; s < n_seqs; ++s) {
            const llama_seq_id seq_id = batch.seq_id[s][0];
            llama_kv_cell & seq_meta = cache.cells[seq_id];
            bool has_cell = false;
            if (seq_meta.tail >= 0) {
                llama_kv_cell & cell = cache.cells[seq_meta.tail];
                GGML_ASSERT(cell.has_seq_id(seq_id));
                // does this seq_id "own" the cell?
                if (cell.seq_id.size() == 1) { has_cell = true; }
            }
            if (!has_cell) {
                llama_kv_cell & empty_cell = cache.cells[next_empty_cell];
                GGML_ASSERT(empty_cell.is_empty());
                // copy old tail into the empty cell
                if (seq_meta.tail >= 0) {
                    llama_kv_cell & orig_cell = cache.cells[seq_meta.tail];
                    empty_cell.pos = orig_cell.pos;
                    empty_cell.src = orig_cell.src;
                    orig_cell.seq_id.erase(seq_id);
                    empty_cell.seq_id.insert(seq_id); // will be overwritten
                }
                seq_meta.tail = next_empty_cell;
                // find next empty cell
                if (s + 1 < n_seqs) {
                    next_empty_cell += 1;
                    for (uint32_t i = 0; i < cache.size; ++i) {
                        if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
                        llama_kv_cell & cell = cache.cells[next_empty_cell];
                        if (cell.is_empty()) { break; }
                        next_empty_cell += 1;
                    }
                }
            }
            if (min > seq_meta.tail) { min = seq_meta.tail; }
            if (max < seq_meta.tail) { max = seq_meta.tail; }
        }

        // gather and re-order
        for (uint32_t s = 0; s < n_seqs; ++s) {
            int32_t dst_id = s + min;
            int32_t src_id = cache.cells[batch.seq_id[s][0]].tail;
            if (dst_id != src_id) {
                llama_kv_cell & dst_cell = cache.cells[dst_id];
                llama_kv_cell & src_cell = cache.cells[src_id];

                std::swap(dst_cell.pos, src_cell.pos);
                std::swap(dst_cell.src, src_cell.src);
                std::swap(dst_cell.seq_id, src_cell.seq_id);

                // swap tails (assuming they NEVER overlap)
                for (const llama_seq_id seq_id : src_cell.seq_id) {
                    cache.cells[seq_id].tail = src_id;
                }
                for (const llama_seq_id seq_id : dst_cell.seq_id) {
                    cache.cells[seq_id].tail = dst_id;
                }
            }
        }

        // update the pos of the used seqs
        for (uint32_t s = 0; s < n_seqs; ++s) {
            const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1];
            int32_t cell_id = s + min;
            llama_kv_cell & cell = cache.cells[cell_id];

            if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
                // What should happen when the pos backtracks or skips a value?
                // Clearing the state mid-batch would require special-casing which isn't done.
                LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
                    __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
            }
            cell.pos = last_pos;
            cell.seq_id.clear();
            for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) {
                const llama_seq_id seq_id = batch.seq_id[s][j];
                cell.seq_id.insert(seq_id);
                cache.cells[seq_id].tail = cell_id;
            }
        }

        // allow getting the range of used cells, from head to head + n
        cache.head = min;
        cache.n    = max - min + 1;
        cache.used = std::count_if(cache.cells.begin(), cache.cells.end(),
            [](const llama_kv_cell & cell){ return !cell.is_empty(); });

        // sanity check
        return llama_kv_cache_slot_info(cache.n >= n_seqs);
    }
    // otherwise, one cell per token.

    if (n_tokens > cache.size) {
        LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
        return llama_kv_cache_slot_info_failed;
    }

    uint32_t n_tested = 0;

    while (true) {
        if (cache.head + n_tokens > cache.size) {
            n_tested += cache.size - cache.head;
            cache.head = 0;
            continue;
        }

        bool found = true;
        for (uint32_t i = 0; i < n_tokens; i++) {
            if (cache.cells[cache.head + i].pos >= 0) {
                found = false;
                cache.head += i + 1;
                n_tested   += i + 1;
                break;
            }
        }

        if (found) {
            break;
        }

        if (n_tested >= cache.size) {
            //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
            return llama_kv_cache_slot_info_failed;
        }
    }

    for (uint32_t s = 0; s < n_seqs; s++) {
        for (uint32_t i = 0; i < n_seq_tokens; ++i) {
            uint32_t k = s*n_seq_tokens + i;
            cache.cells[cache.head + k].pos = batch.pos[k];

            for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
                cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
            }
        }
    }

    cache.used += n_tokens;

    return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens);
}
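
// returns 1 + the index of the last non-empty cell (0 if the cache is empty),
// i.e. an upper bound on the cells that are currently in use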
uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
    for (uint32_t i = cache.size; i > 0; --i) {
        const llama_kv_cell & cell = cache.cells[i - 1];

        if (cell.pos >= 0 && !cell.is_empty()) {
            return i;
        }
    }

    return 0;
}
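
// reset every cell and zero all backing buffers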
void llama_kv_cache_clear(struct llama_kv_cache & cache) {
    for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
        cache.cells[i].pos = -1;
        cache.cells[i].seq_id.clear();
        cache.cells[i].src = -1;
        cache.cells[i].tail = -1;
    }
    cache.head = 0;
    cache.used = 0;

    for (auto & buf : cache.bufs) {
        ggml_backend_buffer_clear(buf.get(), 0);
    }
}
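
// remove seq_id from the cells with pos in [p0, p1); seq_id < 0 matches all
// sequences, and negative p0/p1 mean 0 and +infinity respectively.
// returns false when a recurrent state would only be partially erased.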
bool llama_kv_cache_seq_rm(
        struct llama_kv_cache & cache,
                 llama_seq_id   seq_id,
                    llama_pos   p0,
                    llama_pos   p1) {
    uint32_t new_head = cache.size;

    if (p0 < 0) p0 = 0;
    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

    // models like Mamba or RWKV can't have a state partially erased
    if (cache.recurrent) {
        if (seq_id >= (int64_t) cache.size) {
            // could be fatal
            return false;
        }
        if (0 <= seq_id) {
            int32_t & tail_id = cache.cells[seq_id].tail;
            if (tail_id >= 0) {
                const llama_kv_cell & cell = cache.cells[tail_id];
                // partial intersection is invalid
                if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
                    return false;
                }
                // invalidate tails which will be cleared
                if (p0 <= cell.pos && cell.pos < p1) {
                    tail_id = -1;
                }
            }
        } else {
            // seq_id is negative, so the range should include everything or nothing
            if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
                return false;
            }
        }
    }

    for (uint32_t i = 0; i < cache.size; ++i) {
        if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
            if (seq_id < 0) {
                cache.cells[i].seq_id.clear();
            } else if (cache.cells[i].has_seq_id(seq_id)) {
                cache.cells[i].seq_id.erase(seq_id);
            } else {
                continue;
            }
            if (cache.cells[i].is_empty()) {
                // keep count of the number of used cells
                if (cache.cells[i].pos >= 0) cache.used--;

                cache.cells[i].pos = -1;
                cache.cells[i].src = -1;
                if (new_head == cache.size) new_head = i;
            }
        }
    }

    // If we freed up a slot, set head to it so searching can start there.
    if (new_head != cache.size && new_head < cache.head) cache.head = new_head;

    return true;
}
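
// make the cells of seq_id_src in [p0, p1) also belong to seq_id_dst;
// for recurrent models the destination simply points at the source tail cell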
void llama_kv_cache_seq_cp(
        struct llama_kv_cache & cache,
                 llama_seq_id   seq_id_src,
                 llama_seq_id   seq_id_dst,
                    llama_pos   p0,
                    llama_pos   p1) {
    if (p0 < 0) p0 = 0;
    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

    if (cache.recurrent) {
        if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
            llama_kv_cell & tail_src = cache.cells[seq_id_src];
            llama_kv_cell & tail_dst = cache.cells[seq_id_dst];
            if (tail_dst.tail >= 0) {
                // clear destination seq_id if it wasn't empty
                llama_kv_cell & cell_dst = cache.cells[tail_dst.tail];

                cell_dst.seq_id.erase(seq_id_dst);
                tail_dst.tail = -1;
                if (cell_dst.seq_id.empty()) {
                    cell_dst.pos   = -1;
                    cell_dst.delta = -1;
                    cell_dst.src   = -1;
                    cache.used -= 1;
                }
            }
            if (tail_src.tail >= 0) {
                llama_kv_cell & cell_src = cache.cells[tail_src.tail];

                cell_src.seq_id.insert(seq_id_dst);
                tail_dst.tail = tail_src.tail;
            }
        }

        return;
    }
    // otherwise, this is the KV cache of a Transformer-like model

    cache.head = 0;

    for (uint32_t i = 0; i < cache.size; ++i) {
        if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
            cache.cells[i].seq_id.insert(seq_id_dst);
        }
    }
}
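
// drop every sequence except seq_id, clearing the cells that become unused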
void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
    uint32_t new_head = cache.size;

    for (uint32_t i = 0; i < cache.size; ++i) {
        if (cache.recurrent && (llama_seq_id) i != seq_id) {
            cache.cells[i].tail = -1;
        }
        if (!cache.cells[i].has_seq_id(seq_id)) {
            if (cache.cells[i].pos >= 0) cache.used--;
            cache.cells[i].pos = -1;
            cache.cells[i].src = -1;
            cache.cells[i].seq_id.clear();
            if (new_head == cache.size) new_head = i;
        } else {
            cache.cells[i].seq_id.clear();
            cache.cells[i].seq_id.insert(seq_id);
        }
    }

    // If we freed up a slot, set head to it so searching can start there.
    if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
}
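
// shift the positions of seq_id in [p0, p1) by delta; cells whose position
// becomes negative are cleared. illustrative sketch (not from this file) of a
// typical "discard the oldest n_discard tokens" context shift, assuming a
// cache and sequence 0 exist:
//   llama_kv_cache_seq_rm (cache, 0, 0, n_discard);
//   llama_kv_cache_seq_add(cache, 0, n_discard, -1, -n_discard);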
void llama_kv_cache_seq_add(
        struct llama_kv_cache & cache,
                 llama_seq_id   seq_id,
                    llama_pos   p0,
                    llama_pos   p1,
                    llama_pos   delta) {
    uint32_t new_head = cache.size;

    if (p0 < 0) p0 = 0;
    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
    // If there is no range then return early to avoid looping over the cache.
    if (p0 == p1) return;

    if (cache.recurrent) {
        // for Mamba-like or RWKV models, only the pos needs to be shifted
        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
            const int32_t tail_id = cache.cells[seq_id].tail;
            if (tail_id >= 0) {
                llama_kv_cell & cell = cache.cells[tail_id];
                if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
                    cell.pos += delta;
                }
            }
        }
        return;
    }

    for (uint32_t i = 0; i < cache.size; ++i) {
        if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
            cache.has_shift = true;
            cache.cells[i].pos   += delta;
            cache.cells[i].delta += delta;

            if (cache.cells[i].pos < 0) {
                if (!cache.cells[i].is_empty()) {
                    cache.used--;
                }
                cache.cells[i].pos = -1;
                cache.cells[i].seq_id.clear();
                if (new_head == cache.size) {
                    new_head = i;
                }
            }
        }
    }

    // If we freed up a slot, set head to it so searching can start there.
    // Otherwise we just start the next search from the beginning.
    cache.head = new_head != cache.size ? new_head : 0;
}
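
// integer-divide the positions of seq_id in [p0, p1) by d, accumulating the
// change in each cell's delta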
void llama_kv_cache_seq_div(
        struct llama_kv_cache & cache,
                 llama_seq_id   seq_id,
                    llama_pos   p0,
                    llama_pos   p1,
                          int   d) {
    if (p0 < 0) p0 = 0;
    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
    // If there is no range then return early to avoid looping over the cache.
    if (p0 == p1) return;

    if (cache.recurrent) {
        // for Mamba-like or RWKV models, only the pos needs to be changed
        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
            const int32_t tail_id = cache.cells[seq_id].tail;
            if (tail_id >= 0) {
                llama_kv_cell & cell = cache.cells[tail_id];
                if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
                    cell.pos /= d;
                }
            }
        }

        return;
    }

    for (uint32_t i = 0; i < cache.size; ++i) {
        if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
            cache.has_shift = true;

            {
                llama_pos p_old = cache.cells[i].pos;
                cache.cells[i].pos   /= d;
                cache.cells[i].delta += cache.cells[i].pos - p_old;
            }
        }
    }
}
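
// largest position stored for seq_id (0 if the sequence has no cells)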
llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) {
    llama_pos result = 0;

    for (uint32_t i = 0; i < cache.size; ++i) {
        if (cache.cells[i].has_seq_id(seq_id)) {
            result = std::max(result, cache.cells[i].pos);
        }
    }

    return result;
}
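
// request a defragmentation pass; ignored for recurrent caches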
void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
    if (!cache.recurrent) {
        cache.do_defrag = true;
    }
}
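
// number of (cell, sequence) pairs in the cache; a cell shared by several
// sequences is counted once per sequence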
int32_t llama_get_kv_cache_token_count(const struct llama_kv_cache & kv) {
    int result = 0;

    for (uint32_t i = 0; i < kv.size; i++) {
        result += kv.cells[i].seq_id.size();
    }

    return result;
}

int32_t llama_get_kv_cache_used_cells(const struct llama_kv_cache & kv) {
    return kv.used;
}

bool llama_kv_cache_can_shift(const struct llama_kv_cache & kv) {
    return kv.can_shift;
}

//
// kv cache view
//
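
// create an empty view; the cell arrays are allocated lazily by
// llama_kv_cache_view_update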
struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max) {
    struct llama_kv_cache_view result = {
        /*.n_cells            = */ 0,
        /*.n_seq_max          = */ n_seq_max,
        /*.token_count        = */ 0,
        /*.used_cells         = */ llama_get_kv_cache_used_cells(kv),
        /*.max_contiguous     = */ 0,
        /*.max_contiguous_idx = */ -1,
        /*.cells              = */ nullptr,
        /*.cells_sequences    = */ nullptr,
    };

    return result;
}
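
// free the arrays owned by the view (the view struct itself is caller-owned)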
void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
    if (view->cells != nullptr) {
        free(view->cells);
        view->cells = nullptr;
    }
    if (view->cells_sequences != nullptr) {
        free(view->cells_sequences);
        view->cells_sequences = nullptr;
    }
}
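
// refresh the view from the cache: per-cell positions and sequence ids,
// token/used-cell counts, and the longest contiguous empty range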
void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv) {
    if (uint32_t(view->n_cells) < kv.size || view->cells == nullptr) {
        view->n_cells = int32_t(kv.size);
        void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
        view->cells = (struct llama_kv_cache_view_cell *)p;
        p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells);
        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
        view->cells_sequences = (llama_seq_id *)p;
    }

    const std::vector<llama_kv_cell> & kv_cells = kv.cells;
    llama_kv_cache_view_cell * c_curr = view->cells;
    llama_seq_id * cs_curr = view->cells_sequences;
    int32_t used_cells = 0;
    int32_t token_count = 0;
    int32_t curr_contig_idx = -1;
    uint32_t max_contig = 0;
    int32_t max_contig_idx = -1;

    for (int32_t i = 0; i < int32_t(kv.size); i++, c_curr++, cs_curr += view->n_seq_max) {
        const size_t curr_size = kv_cells[i].seq_id.size();
        token_count += curr_size;
        c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;

        if (curr_size > 0) {
            if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
                max_contig = i - curr_contig_idx;
                max_contig_idx = curr_contig_idx;
            }
            curr_contig_idx = -1;
        } else if (curr_contig_idx < 0) {
            curr_contig_idx = i;
        }

        int seq_idx = 0;
        for (const llama_seq_id it : kv_cells[i].seq_id) {
            if (seq_idx >= view->n_seq_max) {
                break;
            }
            cs_curr[seq_idx] = it;
            seq_idx++;
        }
        if (seq_idx != 0) {
            used_cells++;
        }
        for (; seq_idx < view->n_seq_max; seq_idx++) {
            cs_curr[seq_idx] = -1;
        }
    }
    if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
        max_contig_idx = curr_contig_idx;
        max_contig = kv_cells.size() - curr_contig_idx;
    }
    view->max_contiguous = max_contig;
    view->max_contiguous_idx = max_contig_idx;
    view->token_count = token_count;
    view->used_cells = used_cells;
    if (uint32_t(used_cells) != kv.used) {
        LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
            __func__, kv.used, used_cells);
    }
}