|
@@ -12,27 +12,49 @@ kv cache once per run
|
|
|
|
|
|
remaining is to implement the cross attention mask
|
|
|
---
|
|
|
- include/llama.h | 4 +
|
|
|
- src/llama.cpp | 456 ++++++++++++++++++++++++++++++++++++++++++++++--
|
|
|
- 2 files changed, 447 insertions(+), 13 deletions(-)
|
|
|
+ examples/llava/llava.cpp | 2 +-
|
|
|
+ include/llama.h | 5 +
|
|
|
+ src/llama.cpp | 447 +++++++++++++++++++++++++++++++++++++--
|
|
|
+ 3 files changed, 436 insertions(+), 18 deletions(-)
|
|
|
|
|
|
+diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp
|
|
|
+index 8558c6bd..37b2f2e2 100644
|
|
|
+--- a/examples/llava/llava.cpp
|
|
|
++++ b/examples/llava/llava.cpp
|
|
|
+@@ -409,7 +409,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
|
|
|
+ if (n_eval > n_batch) {
|
|
|
+ n_eval = n_batch;
|
|
|
+ }
|
|
|
+- llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
|
|
|
++ llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), n_embd, nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
|
|
|
+ if (llama_decode(ctx_llama, batch)) {
|
|
|
+ LOG_ERR("%s : failed to eval\n", __func__);
|
|
|
+ return false;
|
|
|
diff --git a/include/llama.h b/include/llama.h
|
|
|
-index 7cae1bbe..122e3cf1 100644
|
|
|
+index 7cae1bbe..aca09310 100644
|
|
|
--- a/include/llama.h
|
|
|
+++ b/include/llama.h
|
|
|
-@@ -423,6 +423,10 @@ extern "C" {
|
|
|
+@@ -240,6 +240,7 @@ extern "C" {
|
|
|
+
|
|
|
+ llama_token * token;
|
|
|
+ float * embd;
|
|
|
++ int32_t n_embd;
|
|
|
+ llama_pos * pos;
|
|
|
+ int32_t * n_seq_id;
|
|
|
+ llama_seq_id ** seq_id;
|
|
|
+@@ -423,6 +424,10 @@ extern "C" {
|
|
|
struct llama_model * model,
|
|
|
struct llama_context_params params);
|
|
|
|
|
|
+ // TODO (jmorganca): this should most likely be passed in as part of a batch
|
|
|
+ // and not set on the context for all batches.
|
|
|
-+ LLAMA_API void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state);
|
|
|
++ LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state);
|
|
|
+
|
|
|
// Frees all allocated memory
|
|
|
LLAMA_API void llama_free(struct llama_context * ctx);
|
|
|
|
|
|
diff --git a/src/llama.cpp b/src/llama.cpp
|
|
|
-index 83b80b59..b189a19a 100644
|
|
|
+index 83b80b59..35748488 100644
|
|
|
--- a/src/llama.cpp
|
|
|
+++ b/src/llama.cpp
|
|
|
@@ -169,6 +169,7 @@ static std::string format(const char * fmt, ...) {
|
|
@@ -160,13 +182,23 @@ index 83b80b59..b189a19a 100644
|
|
|
GGML_ABORT("fatal error");
|
|
|
}
|
|
|
+
|
|
|
-+ bool cross_attention_layer(uint32_t il) const {
|
|
|
++ bool cross_attention_layers(uint32_t il) const {
|
|
|
+ return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
|
|
|
+ }
|
|
|
};
|
|
|
|
|
|
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
|
|
|
-@@ -2806,6 +2859,16 @@ struct llama_layer {
|
|
|
+@@ -2652,6 +2705,9 @@ struct llama_cparams {
|
|
|
+ bool offload_kqv;
|
|
|
+ bool flash_attn;
|
|
|
+ bool no_perf;
|
|
|
++ // TODO (jmorganca): this should most likely be passed in as part of a batch
|
|
|
++ // and not set on the context for all batches.
|
|
|
++ bool cross_attn = false;
|
|
|
+
|
|
|
+ enum llama_pooling_type pooling_type;
|
|
|
+
|
|
|
+@@ -2806,6 +2862,16 @@ struct llama_layer {
|
|
|
struct ggml_tensor * ffn_down_scale;
|
|
|
|
|
|
struct ggml_tensor * bskcn_tv;
|
|
@@ -183,25 +215,21 @@ index 83b80b59..b189a19a 100644
|
|
|
};
|
|
|
|
|
|
// very similar to llama_batch,
|
|
|
-@@ -3452,6 +3515,12 @@ struct llama_context {
|
|
|
+@@ -3452,6 +3518,8 @@ struct llama_context {
|
|
|
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
|
|
|
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
|
|
|
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
|
|
|
+
|
|
|
-+ // TODO (jmorganca): this should most likely be passed in as part of a batch
|
|
|
-+ // and not set on the context for all batches.
|
|
|
-+ float * cross_attn_state = nullptr;
|
|
|
-+ bool cross_attn_state_first_pass = true;
|
|
|
+ struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061]
|
|
|
};
|
|
|
|
|
|
struct llama_lora_weight {
|
|
|
-@@ -3686,6 +3755,18 @@ static bool llama_kv_cache_init(
|
|
|
+@@ -3686,6 +3754,18 @@ static bool llama_kv_cache_init(
|
|
|
cache.v_l.reserve(n_layer);
|
|
|
|
|
|
for (int i = 0; i < (int) n_layer; i++) {
|
|
|
+ // for cross attention layers
|
|
|
-+ if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layer(i)) {
|
|
|
++ if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layers(i)) {
|
|
|
+ struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
|
|
|
+ ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
|
|
|
+ ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
|
|
@@ -215,7 +243,7 @@ index 83b80b59..b189a19a 100644
|
|
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
|
|
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
|
|
|
|
|
|
-@@ -5460,12 +5541,14 @@ static void llm_load_hparams(
|
|
|
+@@ -5460,12 +5540,14 @@ static void llm_load_hparams(
|
|
|
}
|
|
|
|
|
|
// zero-out the per-layer hparams
|
|
@@ -235,7 +263,7 @@ index 83b80b59..b189a19a 100644
|
|
|
|
|
|
// n_head_kv is optional, default to n_head
|
|
|
hparams.n_head_kv_arr = hparams.n_head_arr;
|
|
|
-@@ -5514,7 +5597,7 @@ static void llm_load_hparams(
|
|
|
+@@ -5514,7 +5596,7 @@ static void llm_load_hparams(
|
|
|
|
|
|
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
|
|
|
|
|
@@ -244,7 +272,7 @@ index 83b80b59..b189a19a 100644
|
|
|
if (hparams.n_rot != hparams.n_embd_head_k) {
|
|
|
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
|
|
|
}
|
|
|
-@@ -5554,6 +5637,16 @@ static void llm_load_hparams(
|
|
|
+@@ -5554,6 +5636,16 @@ static void llm_load_hparams(
|
|
|
}
|
|
|
}
|
|
|
} break;
|
|
@@ -261,7 +289,7 @@ index 83b80b59..b189a19a 100644
|
|
|
case LLM_ARCH_MINICPM:
|
|
|
{
|
|
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
|
|
-@@ -7249,6 +7342,55 @@ static bool llm_load_tensors(
|
|
|
+@@ -7249,6 +7341,55 @@ static bool llm_load_tensors(
|
|
|
layer.rope_short = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
|
|
}
|
|
|
} break;
|
|
@@ -286,7 +314,7 @@ index 83b80b59..b189a19a 100644
|
|
|
+
|
|
|
+ auto & layer = model.layers[i];
|
|
|
+
|
|
|
-+ if (hparams.cross_attention_layer(i)) {
|
|
|
++ if (hparams.cross_attention_layers(i)) {
|
|
|
+ layer.cross_attn_k_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_NORM, "weight", i), {128});
|
|
|
+ layer.cross_attn_k_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_PROJ, "weight", i), {n_embd, 1024});
|
|
|
+ layer.cross_attn_o_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_O_PROJ, "weight", i), {n_embd, n_embd});
|
|
@@ -317,7 +345,7 @@ index 83b80b59..b189a19a 100644
|
|
|
case LLM_ARCH_GROK:
|
|
|
{
|
|
|
if (n_expert == 0) {
|
|
|
-@@ -9093,7 +9235,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
|
|
|
+@@ -9093,7 +9234,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
|
|
|
|
|
|
if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE &&
|
|
|
model.hparams.n_vocab != model.vocab.id_to_token.size()) {
|
|
@@ -326,16 +354,7 @@ index 83b80b59..b189a19a 100644
|
|
|
}
|
|
|
|
|
|
if (params.vocab_only) {
|
|
|
-@@ -9178,7 +9320,7 @@ static struct ggml_tensor * llm_build_inp_embd(
|
|
|
-
|
|
|
- inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
|
|
|
- } else {
|
|
|
-- lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
|
|
|
-+ lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
|
|
|
- inpL = lctx.inp_embd;
|
|
|
- ggml_set_input(lctx.inp_embd);
|
|
|
- }
|
|
|
-@@ -9193,6 +9335,22 @@ static struct ggml_tensor * llm_build_inp_embd(
|
|
|
+@@ -9193,6 +9334,21 @@ static struct ggml_tensor * llm_build_inp_embd(
|
|
|
return inpL;
|
|
|
}
|
|
|
|
|
@@ -346,11 +365,10 @@ index 83b80b59..b189a19a 100644
|
|
|
+ const llm_build_cb & cb) {
|
|
|
+ const int64_t n_embd = hparams.n_embd;
|
|
|
+
|
|
|
-+ struct ggml_tensor * inpCAS;
|
|
|
-+ lctx.inp_cross_attn_state = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4);
|
|
|
-+ cb(lctx.inp_cross_attn_state, "inp_cross_attn_state", -1);
|
|
|
-+ ggml_set_input(lctx.inp_cross_attn_state);
|
|
|
-+ inpCAS = lctx.inp_cross_attn_state;
|
|
|
++ struct ggml_tensor * inpCAS = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4);
|
|
|
++ cb(inpCAS, "inp_cross_attn_state", -1);
|
|
|
++ ggml_set_input(inpCAS);
|
|
|
++ lctx.inp_cross_attn_state = inpCAS;
|
|
|
+
|
|
|
+ return inpCAS;
|
|
|
+}
|
|
@@ -358,7 +376,7 @@ index 83b80b59..b189a19a 100644
|
|
|
static void llm_build_kv_store(
|
|
|
struct ggml_context * ctx,
|
|
|
const llama_hparams & hparams,
|
|
|
-@@ -10167,6 +10325,7 @@ struct llm_build_context {
|
|
|
+@@ -10167,6 +10323,7 @@ struct llm_build_context {
|
|
|
lctx.inp_pos_bucket = nullptr;
|
|
|
lctx.inp_embd_enc = nullptr;
|
|
|
lctx.inp_KQ_mask_cross = nullptr;
|
|
@@ -366,7 +384,7 @@ index 83b80b59..b189a19a 100644
|
|
|
}
|
|
|
|
|
|
void free() {
|
|
|
-@@ -10754,6 +10913,253 @@ struct llm_build_context {
|
|
|
+@@ -10754,6 +10911,239 @@ struct llm_build_context {
|
|
|
LLM_NORM_RMS, cb, -1);
|
|
|
cb(cur, "result_norm", -1);
|
|
|
|
|
@@ -410,8 +428,8 @@ index 83b80b59..b189a19a 100644
|
|
|
+ LLM_NORM_RMS, cb, il);
|
|
|
+ cb(cur, "attn_norm", il);
|
|
|
+
|
|
|
-+ if (hparams.cross_attention_layer(il)) {
|
|
|
-+ if (!lctx.cross_attn_state) {
|
|
|
++ if (hparams.cross_attention_layers(il)) {
|
|
|
++ if (!batch.embd && !cparams.cross_attn) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+
|
|
@@ -422,42 +440,28 @@ index 83b80b59..b189a19a 100644
|
|
|
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
|
|
+ cb(Qcur, "Qcur", il);
|
|
|
+
|
|
|
-+ Qcur = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
|
|
-+ cb(Qcur, "Qcur", il);
|
|
|
-+
|
|
|
-+ // TODO: is this required?
|
|
|
-+ Qcur = ggml_cont(ctx0, Qcur);
|
|
|
++ Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3));
|
|
|
+ cb(Qcur, "Qcur", il);
|
|
|
+
|
|
|
+ Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
|
|
|
+ cb(Qcur, "Qcur", il);
|
|
|
+
|
|
|
-+ struct ggml_tensor * Kcur;
|
|
|
-+ if (lctx.cross_attn_state_first_pass) {
|
|
|
++ struct ggml_tensor * Kcur, * Vcur;
|
|
|
++ if (batch.embd) {
|
|
|
+ Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS);
|
|
|
+ cb(Kcur, "Kcur", il);
|
|
|
+
|
|
|
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404);
|
|
|
+ cb(Kcur, "Kcur", il);
|
|
|
+
|
|
|
-+ Kcur = ggml_permute(ctx0, Kcur, 0, 2, 1, 3);
|
|
|
-+ cb(Kcur, "Kcur", il);
|
|
|
-+
|
|
|
-+ // TODO: is this required?
|
|
|
-+ Kcur = ggml_cont(ctx0, Kcur);
|
|
|
++ Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
|
|
|
+ cb(Kcur, "Kcur", il);
|
|
|
+
|
|
|
+ Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
|
|
|
+ cb(Kcur, "Kcur", il);
|
|
|
+
|
|
|
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il]));
|
|
|
-+ } else {
|
|
|
-+ Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]);
|
|
|
-+ cb(Kcur, "Kcur (view)", il);
|
|
|
-+ }
|
|
|
+
|
|
|
-+ struct ggml_tensor * Vcur;
|
|
|
-+ if (lctx.cross_attn_state_first_pass) {
|
|
|
+ Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS);
|
|
|
+ cb(Vcur, "Vcur", il);
|
|
|
+
|
|
@@ -469,6 +473,9 @@ index 83b80b59..b189a19a 100644
|
|
|
+
|
|
|
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il]));
|
|
|
+ } else {
|
|
|
++ Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]);
|
|
|
++ cb(Kcur, "Kcur (view)", il);
|
|
|
++
|
|
|
+ Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]);
|
|
|
+ cb(Vcur, "Vcur (view)", il);
|
|
|
+ }
|
|
@@ -476,11 +483,8 @@ index 83b80b59..b189a19a 100644
|
|
|
+ struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur);
|
|
|
+ cb(kq, "kq", il);
|
|
|
+
|
|
|
-+ kq = ggml_scale_inplace(ctx0, kq, 1.0f/sqrtf(float(n_embd_head)));
|
|
|
-+ cb(kq, "kq_scaled", il);
|
|
|
-+
|
|
|
+ // TODO: apply causal masks
|
|
|
-+ struct ggml_tensor * kq_soft_max = ggml_soft_max_inplace(ctx0, kq);
|
|
|
++ struct ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq, nullptr, 1.f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
|
|
|
+ cb(kq_soft_max, "kq_soft_max", il);
|
|
|
+
|
|
|
+ Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur));
|
|
@@ -570,8 +574,8 @@ index 83b80b59..b189a19a 100644
|
|
|
+ cb(Kcur, "Kcur", il);
|
|
|
+
|
|
|
+ cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
|
-+ model.layers[il].wo, model.layers[il].bo,
|
|
|
-+ Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
|
++ model.layers[il].wo, model.layers[il].bo,
|
|
|
++ Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
|
+
|
|
|
+
|
|
|
+ if (il == n_layer - 1) {
|
|
@@ -620,7 +624,7 @@ index 83b80b59..b189a19a 100644
|
|
|
// lm_head
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
|
cb(cur, "result_output", -1);
|
|
|
-@@ -16501,6 +16907,10 @@ static struct ggml_cgraph * llama_build_graph(
|
|
|
+@@ -16501,6 +16891,10 @@ static struct ggml_cgraph * llama_build_graph(
|
|
|
{
|
|
|
result = llm.build_llama();
|
|
|
} break;
|
|
@@ -631,33 +635,48 @@ index 83b80b59..b189a19a 100644
|
|
|
case LLM_ARCH_BAICHUAN:
|
|
|
{
|
|
|
result = llm.build_baichuan();
|
|
|
-@@ -16773,6 +17183,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
|
|
|
- ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
|
|
|
+@@ -16761,10 +17155,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
|
|
|
}
|
|
|
|
|
|
-+ // TODO (jmorganca): this might copy a lot of data on every request of a
|
|
|
-+ // single generation even though it doesn't change, so we should
|
|
|
-+ // find a way to not set this more than one time per image
|
|
|
-+ if (lctx.inp_cross_attn_state &&
|
|
|
-+ lctx.inp_cross_attn_state->buffer) {
|
|
|
-+ ggml_backend_tensor_set(lctx.inp_cross_attn_state, lctx.cross_attn_state, 0, hparams.n_embd * 1601 * 4 * ggml_element_size(lctx.inp_cross_attn_state));
|
|
|
-+ }
|
|
|
-+
|
|
|
- if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
|
|
- GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
|
|
|
- const int64_t n_tokens = batch.n_tokens;
|
|
|
-@@ -17455,6 +17873,10 @@ static int llama_decode_internal(
|
|
|
+ if (batch.embd) {
|
|
|
+- const int64_t n_embd = hparams.n_embd;
|
|
|
+- const int64_t n_tokens = batch.n_tokens;
|
|
|
++ if (lctx.inp_cross_attn_state && lctx.inp_cross_attn_state->buffer) {
|
|
|
++ ggml_backend_tensor_set(lctx.inp_cross_attn_state, batch.embd, 0, ggml_nbytes(lctx.inp_cross_attn_state));
|
|
|
++ // zero out inp_embd since it's not used
|
|
|
++ float * inp_embd_data = (float *)lctx.inp_embd->data;
|
|
|
++ for (int i = 0; i < ggml_nelements(lctx.inp_embd); ++i) {
|
|
|
++ inp_embd_data[i] = 0.0f;
|
|
|
++ }
|
|
|
++ } else {
|
|
|
++ const int64_t n_embd = hparams.n_embd;
|
|
|
++ const int64_t n_tokens = batch.n_tokens;
|
|
|
|
|
|
- llama_set_inputs(lctx, ubatch);
|
|
|
+- ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
|
|
|
++ ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
|
|
|
++ }
|
|
|
+ }
|
|
|
|
|
|
-+ // TODO: replace with something better to find out if its
|
|
|
-+ // our first actual pass
|
|
|
-+ lctx.cross_attn_state_first_pass = false;
|
|
|
-+
|
|
|
- llama_graph_compute(lctx, gf, n_threads, threadpool);
|
|
|
+ if (batch.pos && lctx.inp_pos) {
|
|
|
+@@ -17345,7 +17748,7 @@ static int llama_decode_internal(
|
|
|
+ n_outputs = 1;
|
|
|
+ }
|
|
|
+
|
|
|
+- lctx.sbatch.from_batch(batch_all, n_embd,
|
|
|
++ lctx.sbatch.from_batch(batch_all, batch_all.n_embd,
|
|
|
+ /* simple_split */ !kv_self.recurrent,
|
|
|
+ /* logits_all */ n_outputs == n_tokens_all);
|
|
|
+
|
|
|
+@@ -17638,7 +18041,7 @@ static int llama_encode_internal(
|
|
|
+
|
|
|
+ const int64_t n_embd = hparams.n_embd;
|
|
|
+
|
|
|
+- lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
|
|
|
++ lctx.sbatch.from_batch(batch, batch.n_embd, /* simple_split */ true, /* logits_all */ true);
|
|
|
+
|
|
|
+ const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens);
|
|
|
|
|
|
- // update the kv ring buffer
|
|
|
-@@ -18648,7 +19070,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
|
|
+@@ -18648,7 +19051,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
|
|
if (llama_model_has_encoder(&model)) {
|
|
|
n_attn_layer *= 3;
|
|
|
}
|
|
@@ -668,19 +687,7 @@ index 83b80b59..b189a19a 100644
|
|
|
}
|
|
|
|
|
|
size_t total_size_org = 0;
|
|
|
-@@ -19744,6 +20168,11 @@ struct llama_context * llama_new_context_with_model(
|
|
|
- return ctx;
|
|
|
- }
|
|
|
-
|
|
|
-+void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state) {
|
|
|
-+ ctx->cross_attn_state_first_pass = true;
|
|
|
-+ ctx->cross_attn_state = cross_attn_state;
|
|
|
-+}
|
|
|
-+
|
|
|
- void llama_free(struct llama_context * ctx) {
|
|
|
- delete ctx;
|
|
|
- }
|
|
|
-@@ -19814,6 +20243,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
|
|
+@@ -19814,6 +20219,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
|
|
|
|
|
// use what we call a normal RoPE, operating on pairs of consecutive head values
|
|
|
case LLM_ARCH_LLAMA:
|
|
@@ -688,3 +695,38 @@ index 83b80b59..b189a19a 100644
|
|
|
case LLM_ARCH_BAICHUAN:
|
|
|
case LLM_ARCH_STARCODER:
|
|
|
case LLM_ARCH_PLAMO:
|
|
|
+@@ -21230,6 +21636,10 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
|
|
|
+ ctx->cparams.causal_attn = causal_attn;
|
|
|
+ }
|
|
|
+
|
|
|
++void llama_set_cross_attention(struct llama_context * ctx, bool cross_attention) {
|
|
|
++ ctx->cparams.cross_attn = cross_attention;
|
|
|
++}
|
|
|
++
|
|
|
+ struct llama_batch llama_batch_get_one(
|
|
|
+ llama_token * tokens,
|
|
|
+ int32_t n_tokens,
|
|
|
+@@ -21239,6 +21649,7 @@ struct llama_batch llama_batch_get_one(
|
|
|
+ /*n_tokens =*/ n_tokens,
|
|
|
+ /*tokens =*/ tokens,
|
|
|
+ /*embd =*/ nullptr,
|
|
|
++ /*n_embd =*/ 0,
|
|
|
+ /*pos =*/ nullptr,
|
|
|
+ /*n_seq_id =*/ nullptr,
|
|
|
+ /*seq_id =*/ nullptr,
|
|
|
+@@ -21254,6 +21665,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
|
|
|
+ /*n_tokens =*/ 0,
|
|
|
+ /*tokens =*/ nullptr,
|
|
|
+ /*embd =*/ nullptr,
|
|
|
++ /*n_embd =*/ 0,
|
|
|
+ /*pos =*/ nullptr,
|
|
|
+ /*n_seq_id =*/ nullptr,
|
|
|
+ /*seq_id =*/ nullptr,
|
|
|
+@@ -21265,6 +21677,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
|
|
|
+
|
|
|
+ if (embd) {
|
|
|
+ batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
|
|
|
++ batch.n_embd = embd;
|
|
|
+ } else {
|
|
|
+ batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
|
|
|
+ }
|