llama-batch.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. /**
  2. * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
  3. *
  4. * MIT License
  5. *
  6. * Copyright (c) 2023-2024 The ggml authors
  7. *
  8. * Permission is hereby granted, free of charge, to any person obtaining a copy
  9. * of this software and associated documentation files (the "Software"), to deal
  10. * in the Software without restriction, including without limitation the rights
  11. * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  12. * copies of the Software, and to permit persons to whom the Software is
  13. * furnished to do so, subject to the following conditions:
  14. *
  15. * The above copyright notice and this permission notice shall be included in all
  16. * copies or substantial portions of the Software.
  17. *
  18. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  19. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  20. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  21. * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  22. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  23. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  24. * SOFTWARE.
  25. */
  26. #include "llama-batch.h"
  27. #include <cstring>
  28. #include <algorithm>
  29. llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
  30. // clear empty sequences
  31. // the previous ubatch is assumed to be gone,
  32. // so nothing should refer to values in these sequences anymore.
  33. for (size_t i = seq.size(); i-- > 0;) {
  34. if (seq[i].length == 0) {
  35. seq.pop_back();
  36. } else {
  37. break;
  38. }
  39. }
  40. ubatch_token.resize(!has_embd ? n_ubatch : 0);
  41. ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
  42. ubatch_pos.resize(n_ubatch);
  43. ubatch_n_seq_id.resize(n_ubatch);
  44. ubatch_seq_id.resize(n_ubatch);
  45. ubatch_output.resize(n_ubatch);
  46. llama_ubatch ubatch = {
  47. /*equal_seqs =*/ true,
  48. /*n_tokens =*/ 0,
  49. /*n_seq_tokens =*/ 0,
  50. /*n_seqs =*/ 0,
  51. /*token =*/ !has_embd ? ubatch_token.data() : nullptr,
  52. /*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
  53. /*pos =*/ ubatch_pos.data(),
  54. /*n_seq_id =*/ ubatch_n_seq_id.data(),
  55. /*seq_id =*/ ubatch_seq_id.data(),
  56. /*output =*/ ubatch_output.data(),
  57. };
  58. return ubatch;
  59. }
  60. void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
  61. GGML_ASSERT(batch != nullptr);
  62. GGML_ASSERT(length <= seq.length);
  63. // Can only add sequences of equal lengths to a batch,
  64. // otherwise it isn't clear to which sequence a token belongs
  65. GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
  66. GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
  67. // NOTE: loops are separated for cache-friendliness
  68. if (batch->token) {
  69. if (ubatch.equal_seqs) {
  70. for (size_t i = 0; i < length; ++i) {
  71. ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
  72. }
  73. } else {
  74. // simple split
  75. ubatch.token = batch->token + seq.offset;
  76. }
  77. } else {
  78. ubatch.token = nullptr;
  79. }
  80. if (batch->embd) {
  81. if (ubatch.equal_seqs) {
  82. for (size_t i = 0; i < length; ++i) {
  83. memcpy(
  84. ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
  85. batch->embd + (n_embd * ids[seq.offset + i]),
  86. n_embd * sizeof(float)
  87. );
  88. }
  89. } else {
  90. // simple split
  91. ubatch.embd = batch->embd + (n_embd * seq.offset);
  92. }
  93. } else {
  94. ubatch.embd = nullptr;
  95. }
  96. if (ubatch.equal_seqs) {
  97. for (size_t i = 0; i < length; ++i) {
  98. ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
  99. }
  100. } else {
  101. // simple split
  102. ubatch.pos = batch->pos + seq.offset;
  103. }
  104. if (ubatch.equal_seqs) {
  105. ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
  106. if (seq.seq_id) {
  107. ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
  108. }
  109. } else {
  110. // simple split
  111. if (batch->n_seq_id) {
  112. ubatch.n_seq_id = batch->n_seq_id + seq.offset;
  113. } else {
  114. for (size_t i = 0; i < length; ++i) {
  115. ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
  116. }
  117. }
  118. if (batch->seq_id) {
  119. ubatch.seq_id = batch->seq_id + seq.offset;
  120. }
  121. }
  122. if (logits_all) {
  123. for (size_t i = 0; i < length; ++i) {
  124. ubatch.output[ubatch.n_tokens + i] = 1;
  125. out_ids.push_back(ids[seq.offset + i]);
  126. }
  127. } else if (batch->logits) {
  128. if (ubatch.equal_seqs) {
  129. for (size_t i = 0; i < length; ++i) {
  130. size_t id = ids[seq.offset + i];
  131. int8_t is_output = batch->logits[id];
  132. ubatch.output[ubatch.n_tokens + i] = is_output;
  133. if (is_output) { out_ids.push_back(id); }
  134. }
  135. } else {
  136. // simple split
  137. ubatch.output = batch->logits + seq.offset;
  138. for (size_t i = 0; i < length; ++i) {
  139. if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
  140. }
  141. }
  142. } else {
  143. // only get last output
  144. for (size_t i = 0; i < length; ++i) {
  145. size_t id = ids[seq.offset + i];
  146. int8_t is_last = id == ids.size() - 1;
  147. ubatch.output[ubatch.n_tokens + i] = is_last;
  148. if (is_last) { out_ids.push_back(id); }
  149. }
  150. }
  151. if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
  152. ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
  153. }
  154. ubatch.n_tokens += length;
  155. ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
  156. seq.offset += length;
  157. seq.length -= length;
  158. n_tokens -= length;
  159. GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
  160. }
  161. llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
  162. n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
  163. llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
  164. ubatch.equal_seqs = false;
  165. if (!seq.empty()) {
  166. llama_sbatch_seq & s = seq[0];
  167. size_t length = s.length < n_ubatch ? s.length : n_ubatch;
  168. GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
  169. add_seq_to_ubatch(ubatch, s, length);
  170. }
  171. return ubatch;
  172. }
  173. llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
  174. n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
  175. llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
  176. if (!seq.empty()) {
  177. size_t length = 0;
  178. size_t n_tokens_in_ubatch = 0;
  179. GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
  180. // smallest first, because it's easier to split this way;
  181. // starting from the end to pop in constant time.
  182. for (size_t i = seq.size(); i-- > 0;) {
  183. llama_sbatch_seq & s = seq[i];
  184. GGML_ASSERT(s.length > 0);
  185. if (length == 0) {
  186. length = s.length < n_ubatch ? s.length : n_ubatch;
  187. }
  188. add_seq_to_ubatch(ubatch, s, length);
  189. n_tokens_in_ubatch += length;
  190. // shared prompts can't be mixed with any of their sequences,
  191. // so it's safer to compute them in their own ubatch
  192. if (s.n_seq_id > 1) { break; }
  193. // stop when there isn't enough space for another sequence
  194. if (length + n_tokens_in_ubatch > n_ubatch) { break; }
  195. }
  196. }
  197. return ubatch;
  198. }
  199. llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
  200. n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
  201. llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
  202. if (!seq.empty()) {
  203. llama_sbatch_seq & s = seq[seq.size() - 1];
  204. size_t length = s.length < n_ubatch ? s.length : n_ubatch;
  205. GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
  206. add_seq_to_ubatch(ubatch, s, length);
  207. }
  208. return ubatch;
  209. }
  210. void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
  211. GGML_ASSERT(batch.n_tokens >= 0);
  212. this->batch = &batch;
  213. this->n_embd = n_embd;
  214. this->logits_all = logits_all;
  215. n_tokens = batch.n_tokens;
  216. ids.resize(n_tokens);
  217. out_ids.clear();
  218. // TODO: reserve out_ids and seq
  219. for (size_t i = 0; i < n_tokens; ++i) {
  220. ids[i] = i;
  221. }
  222. if (simple_split) {
  223. seq.resize(1);
  224. llama_sbatch_seq & s = seq[0];
  225. s.n_seq_id = 0;
  226. s.seq_id = nullptr;
  227. s.offset = 0;
  228. s.length = n_tokens;
  229. return;
  230. }
  231. std::sort(ids.begin(), ids.end(),
  232. [&batch](size_t a, size_t b) {
  233. int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
  234. int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
  235. // sort by seq_id, then by pos
  236. if (n_seq_a == n_seq_b) {
  237. if (batch.seq_id) {
  238. for (int32_t i = 0; i < n_seq_a; ++i) {
  239. llama_seq_id seq_id_a = batch.seq_id[a][i];
  240. llama_seq_id seq_id_b = batch.seq_id[b][i];
  241. // smaller seq_ids go first
  242. if (seq_id_a != seq_id_b) {
  243. return seq_id_a < seq_id_b;
  244. }
  245. }
  246. }
  247. // when all else is equal, sort by pos
  248. if (batch.pos) {
  249. return batch.pos[a] < batch.pos[b];
  250. }
  251. // no pos, sort by id
  252. return a < b;
  253. }
  254. // shared prompts go first
  255. return n_seq_a > n_seq_b;
  256. }
  257. );
  258. // init seq
  259. llama_sbatch_seq * last_seq = nullptr;
  260. for (size_t i = 0; i < n_tokens; ++i) {
  261. const size_t bi = ids[i];
  262. const int32_t n_seqs = batch.n_seq_id[bi];
  263. llama_seq_id * seq_ids = batch.seq_id[bi];
  264. if (last_seq != nullptr) {
  265. bool same = n_seqs == last_seq->n_seq_id;
  266. for (int32_t j = 0; same && j < n_seqs; ++j) {
  267. if (seq_ids[j] != last_seq->seq_id[j]) {
  268. same = false;
  269. }
  270. }
  271. if (same) {
  272. last_seq->length += 1;
  273. continue;
  274. }
  275. }
  276. llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
  277. seq.push_back(new_seq);
  278. last_seq = &seq.back();
  279. }
  280. // keep shared prompts first at the end, then sort by length descending.
  281. std::sort(seq.begin(), seq.end(),
  282. [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
  283. if (a.n_seq_id == b.n_seq_id) {
  284. return a.length > b.length;
  285. }
  286. return a.n_seq_id < b.n_seq_id;
  287. }
  288. );
  289. }
  290. llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
  291. batch = in_batch;
  292. GGML_ASSERT(batch.n_tokens > 0);
  293. if (!batch.pos) {
  294. pos.resize(batch.n_tokens);
  295. for (int32_t i = 0; i < batch.n_tokens; i++) {
  296. pos[i] = i + p0;
  297. }
  298. batch.pos = pos.data();
  299. }
  300. if (!batch.n_seq_id) {
  301. n_seq_id.resize(batch.n_tokens);
  302. for (int32_t i = 0; i < batch.n_tokens; i++) {
  303. n_seq_id[i] = seq_id_0.size();
  304. }
  305. batch.n_seq_id = n_seq_id.data();
  306. }
  307. if (!batch.seq_id) {
  308. seq_id.resize(batch.n_tokens + 1);
  309. seq_id[batch.n_tokens] = NULL;
  310. for (int32_t i = 0; i < batch.n_tokens; i++) {
  311. seq_id[i] = seq_id_0.data();
  312. }
  313. batch.seq_id = seq_id.data();
  314. }
  315. if (!batch.logits) {
  316. logits.resize(batch.n_tokens);
  317. logits[logits.size() - 1] = true;
  318. batch.logits = logits.data();
  319. }
  320. }
  321. //
  322. // interface implementation
  323. //
  324. struct llama_batch llama_batch_get_one(
  325. llama_token * tokens,
  326. int32_t n_tokens) {
  327. return {
  328. /*n_tokens =*/ n_tokens,
  329. /*tokens =*/ tokens,
  330. /*embd =*/ nullptr,
  331. /*n_embd =*/ 0,
  332. /*pos =*/ nullptr,
  333. /*n_seq_id =*/ nullptr,
  334. /*seq_id =*/ nullptr,
  335. /*logits =*/ nullptr,
  336. };
  337. }
  338. struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
  339. llama_batch batch = {
  340. /*n_tokens =*/ 0,
  341. /*tokens =*/ nullptr,
  342. /*embd =*/ nullptr,
  343. /*n_embd =*/ 0,
  344. /*pos =*/ nullptr,
  345. /*n_seq_id =*/ nullptr,
  346. /*seq_id =*/ nullptr,
  347. /*logits =*/ nullptr,
  348. };
  349. if (embd) {
  350. batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
  351. batch.n_embd = embd;
  352. } else {
  353. batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
  354. }
  355. batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc);
  356. batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc);
  357. batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
  358. for (int i = 0; i < n_tokens_alloc; ++i) {
  359. batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
  360. }
  361. batch.seq_id[n_tokens_alloc] = nullptr;
  362. batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
  363. return batch;
  364. }
  365. void llama_batch_free(struct llama_batch batch) {
  366. if (batch.token) free(batch.token);
  367. if (batch.embd) free(batch.embd);
  368. if (batch.pos) free(batch.pos);
  369. if (batch.n_seq_id) free(batch.n_seq_id);
  370. if (batch.seq_id) {
  371. for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
  372. free(batch.seq_id[i]);
  373. }
  374. free(batch.seq_id);
  375. }
  376. if (batch.logits) free(batch.logits);
  377. }