/**
 * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
 *
 * MIT License
 *
 * Copyright (c) 2023-2024 The ggml authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#pragma once

#include "llama.h"

#include <array>
#include <vector>
// very similar to llama_batch,
// but has more metadata about sequences
//
// NOTE(review): the pointers below are non-owning; they presumably point into
// the ubatch_* buffers of llama_sbatch (see reserve_ubatch) — confirm in the .cpp
struct llama_ubatch {
    // whether all sequences in this ubatch have the same number of tokens
    // (presumably true for batches produced by split_equal/split_seq — confirm in the .cpp)
    bool equal_seqs;
    // TODO: whole_seqs for embeddings?

    uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
    uint32_t n_seq_tokens; // tokens per sequence
    uint32_t n_seqs;       // number of sequences

    llama_token  *  token;    // [n_tokens]         token ids
    float        *  embd;     // [n_embd, n_tokens] token embeddings (used instead of token ids)
    llama_pos    *  pos;      // [n_tokens]         token positions
    int32_t      *  n_seq_id; // [n_seqs]           number of sequence ids per sequence
    llama_seq_id ** seq_id;   // [n_seqs]           sequence ids
    int8_t       *  output;   // [n_tokens]         per-token output flag (whether to produce logits/embeddings)
};
// a contiguous run of tokens belonging to the same set of sequence ids,
// described as an offset/length window (presumably into the sorted token
// order tracked by llama_sbatch::ids — confirm in the .cpp)
struct llama_sbatch_seq {
    int32_t        n_seq_id; // number of sequence ids in seq_id
    llama_seq_id * seq_id;   // non-owning pointer to the sequence ids
    size_t         offset;   // start of this run
    size_t         length;   // number of tokens in this run
};
// sequence-length-aware batch splitting
//
// wraps a llama_batch and carves it into llama_ubatch micro-batches via one of
// the split_* methods; the ubatch_* vectors below back the pointers handed out
// in each llama_ubatch
struct llama_sbatch {
    // tokens left in this batch (decreases as ubatches are split off)
    size_t n_tokens;

    size_t n_embd; // embedding size (second dimension of llama_ubatch::embd)

    bool logits_all; // TODO: remove once lctx.logits_all is removed too

    // sorted indices into the batch
    std::vector<size_t> ids;
    // batch indices of the output
    std::vector<size_t> out_ids;
    // per-sequence views into the batch (see llama_sbatch_seq)
    std::vector<llama_sbatch_seq> seq;

    // non-owning pointer to the batch being split; set by from_batch
    const llama_batch * batch = nullptr;

    // buffers for the ubatch
    // (storage behind the pointers of the llama_ubatch returned by the split_* methods)
    std::vector<llama_token>    ubatch_token;
    std::vector<float>          ubatch_embd;
    std::vector<llama_pos>      ubatch_pos;
    std::vector<int32_t>        ubatch_n_seq_id;
    std::vector<llama_seq_id *> ubatch_seq_id;
    std::vector<int8_t>         ubatch_output;

    // prepare an empty ubatch with capacity for n_ubatch tokens
    // (embd buffer only allocated when has_embd is true)
    llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);

    // move up to `length` tokens of `seq` into `ubatch`
    void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);

    // simple split, unknown number of sequences of unequal lengths
    llama_ubatch split_simple(size_t n_ubatch);

    // make batches of equal-length sequences
    llama_ubatch split_equal(size_t n_ubatch);

    // sequence-wise split
    llama_ubatch split_seq(size_t n_ubatch);

    // initialize from a batch; must be called before the split_* methods
    // simple_split: restrict to the layout split_simple expects — TODO confirm in the .cpp
    void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
};
// temporary allocate memory for the input batch if needed
//
// owns the vectors that back any fields missing from the input batch; `batch`
// is the fully-populated result and must not outlive this object
struct llama_batch_allocr {
    struct llama_batch batch; // the (possibly patched-up) batch to use

    // default sequence id used when the input batch provides none
    std::array<llama_seq_id, 1> seq_id_0 = { 0 };

    // backing storage for fields that were null in the input batch
    std::vector<llama_pos>      pos;
    std::vector<int32_t>        n_seq_id;
    std::vector<llama_seq_id *> seq_id;
    std::vector<int8_t>         logits;

    // optionally fulfill the batch returned by llama_batch_get_one
    // p0: starting position for the generated pos values — TODO confirm in the .cpp
    llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
};