// llama-batch.h
/**
 * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
 *
 * MIT License
 *
 * Copyright (c) 2023-2024 The ggml authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
  26. #pragma once
  27. #include "llama.h"
  28. #include <array>
  29. #include <vector>
  30. // very similar to llama_batch,
  31. // but has more metadata about sequences
  32. struct llama_ubatch {
  33. bool equal_seqs;
  34. // TODO: whole_seqs for embeddings?
  35. uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
  36. uint32_t n_seq_tokens; // tokens per sequence
  37. uint32_t n_seqs;
  38. llama_token * token; // [n_tokens]
  39. float * embd; // [n_embd, n_tokens]
  40. llama_pos * pos; // [n_tokens]
  41. int32_t * n_seq_id; // [n_seqs]
  42. llama_seq_id ** seq_id; // [n_seqs]
  43. int8_t * output; // [n_tokens]
  44. };
  45. struct llama_sbatch_seq {
  46. int32_t n_seq_id;
  47. llama_seq_id * seq_id;
  48. size_t offset;
  49. size_t length;
  50. };
  51. // sequence-length-aware batch splitting
  52. struct llama_sbatch {
  53. // tokens left in this batch
  54. size_t n_tokens;
  55. size_t n_embd;
  56. bool logits_all; // TODO: remove once lctx.logits_all is removed too
  57. // sorted indices into the batch
  58. std::vector<size_t> ids;
  59. // batch indices of the output
  60. std::vector<size_t> out_ids;
  61. std::vector<llama_sbatch_seq> seq;
  62. const llama_batch * batch = nullptr;
  63. // buffers for the ubatch
  64. std::vector<llama_token> ubatch_token;
  65. std::vector<float> ubatch_embd;
  66. std::vector<llama_pos> ubatch_pos;
  67. std::vector<int32_t> ubatch_n_seq_id;
  68. std::vector<llama_seq_id *> ubatch_seq_id;
  69. std::vector<int8_t> ubatch_output;
  70. llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
  71. void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);
  72. // simple split, unknown number of sequences of unequal lengths
  73. llama_ubatch split_simple(size_t n_ubatch);
  74. // make batches of equal-length sequences
  75. llama_ubatch split_equal(size_t n_ubatch);
  76. // sequence-wise split
  77. llama_ubatch split_seq(size_t n_ubatch);
  78. void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
  79. };
  80. // temporary allocate memory for the input batch if needed
  81. struct llama_batch_allocr {
  82. struct llama_batch batch;
  83. std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
  84. std::vector<llama_pos> pos;
  85. std::vector<int32_t> n_seq_id;
  86. std::vector<llama_seq_id *> seq_id;
  87. std::vector<int8_t> logits;
  88. // optionally fulfill the batch returned by llama_batch_get_one
  89. llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
  90. };