123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786 |
- From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
- From: jmorganca <jmorganca@gmail.com>
- Date: Thu, 17 Oct 2024 15:18:22 -0700
- Subject: [PATCH] add mllama support
- mllama adds cross-attention layers to the standard llama architecture
- it also requires a way to input a new tensor: cross_attention_state
- once per generation
- cross-attention layers don't change and so they are cached in the
- kv cache once per run
- remaining is to implement the cross attention mask
- ---
- examples/llava/llava.cpp | 5 +-
- include/llama.h | 5 +
- src/llama.cpp | 477 +++++++++++++++++++++++++++++++++++++--
- 3 files changed, 467 insertions(+), 20 deletions(-)
- diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp
- index 16f30c56..0f0f3f62 100644
- --- a/examples/llava/llava.cpp
- +++ b/examples/llava/llava.cpp
- @@ -429,7 +429,7 @@ struct llava_embd_batch {
- std::vector<llama_seq_id *> seq_ids;
- std::vector<int8_t> logits;
- llama_batch batch;
- - llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
- + llava_embd_batch(float * embd, int32_t n_embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
- pos .resize(n_tokens);
- n_seq_id.resize(n_tokens);
- seq_ids .resize(n_tokens + 1);
- @@ -441,6 +441,7 @@ struct llava_embd_batch {
- /*n_tokens =*/ n_tokens,
- /*tokens =*/ nullptr,
- /*embd =*/ embd,
- + /*n_embd =*/ n_embd,
- /*pos =*/ pos.data(),
- /*n_seq_id =*/ n_seq_id.data(),
- /*seq_id =*/ seq_ids.data(),
- @@ -464,7 +465,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
- n_eval = n_batch;
- }
- float * embd = image_embed->embed+i*n_embd;
- - llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
- + llava_embd_batch llava_batch = llava_embd_batch(embd, n_embd, n_eval, *n_past, 0);
- if (llama_decode(ctx_llama, llava_batch.batch)) {
- LOG_ERR("%s : failed to eval\n", __func__);
- return false;
- diff --git a/include/llama.h b/include/llama.h
- index c67988a3..0f266283 100644
- --- a/include/llama.h
- +++ b/include/llama.h
- @@ -249,6 +249,7 @@ extern "C" {
-
- llama_token * token;
- float * embd;
- + int32_t n_embd;
- llama_pos * pos;
- int32_t * n_seq_id;
- llama_seq_id ** seq_id;
- @@ -423,6 +424,10 @@ extern "C" {
- struct llama_model * model,
- struct llama_context_params params);
-
- + // TODO (jmorganca): this should most likely be passed in as part of a batch
- + // and not set on the context for all batches.
- + LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state);
- +
- // Frees all allocated memory
- LLAMA_API void llama_free(struct llama_context * ctx);
-
- diff --git a/src/llama.cpp b/src/llama.cpp
- index 26be6254..4778a9ed 100644
- --- a/src/llama.cpp
- +++ b/src/llama.cpp
- @@ -146,6 +146,7 @@ static std::string format(const char * fmt, ...) {
-
- enum llm_arch {
- LLM_ARCH_LLAMA,
- + LLM_ARCH_MLLAMA,
- LLM_ARCH_FALCON,
- LLM_ARCH_BAICHUAN,
- LLM_ARCH_GROK,
- @@ -202,6 +203,7 @@ enum llm_arch {
-
- static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
- { LLM_ARCH_LLAMA, "llama" },
- + { LLM_ARCH_MLLAMA, "mllama" },
- { LLM_ARCH_FALCON, "falcon" },
- { LLM_ARCH_GROK, "grok" },
- { LLM_ARCH_GPT2, "gpt2" },
- @@ -311,6 +313,7 @@ enum llm_kv {
- LLM_KV_ATTENTION_SLIDING_WINDOW,
- LLM_KV_ATTENTION_SCALE,
- LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
- + LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS,
-
- LLM_KV_ROPE_DIMENSION_COUNT,
- LLM_KV_ROPE_DIMENSION_SECTIONS,
- @@ -429,6 +432,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
- { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
- { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
- { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection.%d" },
- + { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers" },
-
- { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
- { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
- @@ -612,6 +616,14 @@ enum llm_tensor {
- LLM_TENSOR_CLS,
- LLM_TENSOR_CLS_OUT,
- LLM_TENSOR_BSKCN_TV,
- + LLM_TENSOR_CROSS_ATTN_K_NORM,
- + LLM_TENSOR_CROSS_ATTN_K_PROJ,
- + LLM_TENSOR_CROSS_ATTN_O_PROJ,
- + LLM_TENSOR_CROSS_ATTN_Q_NORM,
- + LLM_TENSOR_CROSS_ATTN_Q_PROJ,
- + LLM_TENSOR_CROSS_ATTN_V_PROJ,
- + LLM_TENSOR_CROSS_ATTN_ATTN_GATE,
- + LLM_TENSOR_CROSS_ATTN_MLP_GATE,
- };
-
- static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
- @@ -641,6 +653,40 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
- { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
- },
- },
- + {
- + LLM_ARCH_MLLAMA,
- + {
- + { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
- + { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
- + { LLM_TENSOR_OUTPUT, "output" },
- + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
- + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
- + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
- + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
- + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
- + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
- + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
- + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
- + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
- + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
- + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
- + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
- + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
- + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
- + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
- + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
- + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
- + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
- + { LLM_TENSOR_CROSS_ATTN_K_NORM, "blk.%d.cross_attn_k_norm" },
- + { LLM_TENSOR_CROSS_ATTN_K_PROJ, "blk.%d.cross_attn_k_proj" },
- + { LLM_TENSOR_CROSS_ATTN_O_PROJ, "blk.%d.cross_attn_o_proj" },
- + { LLM_TENSOR_CROSS_ATTN_Q_NORM, "blk.%d.cross_attn_q_norm" },
- + { LLM_TENSOR_CROSS_ATTN_Q_PROJ, "blk.%d.cross_attn_q_proj" },
- + { LLM_TENSOR_CROSS_ATTN_V_PROJ, "blk.%d.cross_attn_v_proj" },
- + { LLM_TENSOR_CROSS_ATTN_ATTN_GATE, "blk.%d.cross_attn_attn_gate" },
- + { LLM_TENSOR_CROSS_ATTN_MLP_GATE, "blk.%d.cross_attn_mlp_gate" },
- + },
- + },
- {
- LLM_ARCH_BAICHUAN,
- {
- @@ -2456,6 +2502,7 @@ enum e_model {
- MODEL_40B,
- MODEL_65B,
- MODEL_70B,
- + MODEL_90B,
- MODEL_236B,
- MODEL_314B,
- MODEL_SMALL,
- @@ -2500,6 +2547,7 @@ struct llama_hparams {
- std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
-
- std::array<std::array<uint32_t, LLAMA_MAX_LAYERS>, 4> n_bskcn_arr;
- + std::array<uint32_t, LLAMA_MAX_LAYERS> cross_attn_layers;
-
- uint32_t n_layer_dense_lead = 0;
- uint32_t n_lora_q = 0;
- @@ -2569,10 +2617,11 @@ struct llama_hparams {
- if (this->n_expert != other.n_expert) return true;
- if (this->n_expert_used != other.n_expert_used) return true;
-
- - if (this->n_head_arr != other.n_head_arr) return true;
- - if (this->n_head_kv_arr != other.n_head_kv_arr) return true;
- - if (this->n_ff_arr != other.n_ff_arr) return true;
- - if (this->n_bskcn_arr != other.n_bskcn_arr) return true;
- + if (this->n_head_arr != other.n_head_arr) return true;
- + if (this->n_head_kv_arr != other.n_head_kv_arr) return true;
- + if (this->n_ff_arr != other.n_ff_arr) return true;
- + if (this->n_bskcn_arr != other.n_bskcn_arr) return true;
- + if (this->cross_attn_layers != other.cross_attn_layers) return true;
-
- if (this->n_rel_attn_bkts != other.n_rel_attn_bkts) return true;
- if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true;
- @@ -2693,6 +2742,10 @@ struct llama_hparams {
-
- GGML_ABORT("fatal error");
- }
- +
- + bool cross_attention_layers(uint32_t il) const {
- + return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
- + }
- };
-
- static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
- @@ -2722,6 +2775,9 @@ struct llama_cparams {
- bool offload_kqv;
- bool flash_attn;
- bool no_perf;
- + // TODO (jmorganca): this should most likely be passed in as part of a batch
- + // and not set on the context for all batches.
- + bool cross_attn = false;
-
- enum llama_pooling_type pooling_type;
-
- @@ -2881,6 +2937,16 @@ struct llama_layer {
- struct ggml_tensor * ffn_down_scale;
-
- struct ggml_tensor * bskcn_tv;
- +
- + // cross attention
- + struct ggml_tensor * cross_attn_k_norm;
- + struct ggml_tensor * cross_attn_k_proj;
- + struct ggml_tensor * cross_attn_o_proj;
- + struct ggml_tensor * cross_attn_q_norm;
- + struct ggml_tensor * cross_attn_q_proj;
- + struct ggml_tensor * cross_attn_v_proj;
- + struct ggml_tensor * cross_attn_attn_gate;
- + struct ggml_tensor * cross_attn_mlp_gate;
- };
-
- // very similar to llama_batch,
- @@ -3472,6 +3538,8 @@ struct llama_context {
- struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
- struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
- struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
- +
- + struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061]
- };
-
- struct llama_lora_weight {
- @@ -3610,6 +3678,39 @@ static bool llama_kv_cache_init(
- cache.v_l.reserve(n_layer);
-
- for (int i = 0; i < (int) n_layer; i++) {
- + // for cross attention layers
- + if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layers(i)) {
- + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
- + const llama_model::buft_list_t * buft_list;
- + if (offload) {
- + buft_list = model.dev_layer.at(i).buft_list;
- + } else {
- + buft_list = &model.cpu_buft_list;
- + }
- + ggml_backend_buffer_type_t buft = select_buft(*buft_list,
- + [&](ggml_context * ctx) {
- + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
- + if (hparams.rope_type == LLAMA_ROPE_TYPE_NONE) {
- + return k;
- + }
- + ggml_tensor * p = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
- + return ggml_rope(ctx, k, p, hparams.n_rot, hparams.rope_type);
- + });
- + ggml_context * ctx = ctx_for_buft(buft);
- +
- + if (!ctx) {
- + LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
- + return false;
- + }
- + ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
- + ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
- + ggml_format_name(k, "cache_k_l%d", i);
- + ggml_format_name(v, "cache_v_l%d", i);
- + cache.k_l.push_back(k);
- + cache.v_l.push_back(v);
- + continue;
- + }
- +
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
-
- @@ -5547,12 +5648,14 @@ static void llm_load_hparams(
- }
-
- // zero-out the per-layer hparams
- - std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
- - std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
- - std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
- + std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
- + std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
- + std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
- + std::fill(hparams.cross_attn_layers.begin(), hparams.cross_attn_layers.end(), -1);
-
- - ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer);
- - ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer);
- + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer);
- + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer);
- + ml.get_arr(LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, hparams.cross_attn_layers, false);
-
- // n_head_kv is optional, default to n_head
- hparams.n_head_kv_arr = hparams.n_head_arr;
- @@ -5601,7 +5704,7 @@ static void llm_load_hparams(
-
- ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
-
- - if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
- + if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_MLLAMA || model.arch == LLM_ARCH_FALCON) {
- if (hparams.n_rot != hparams.n_embd_head_k) {
- throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
- }
- @@ -5641,6 +5744,16 @@ static void llm_load_hparams(
- }
- }
- } break;
- + case LLM_ARCH_MLLAMA:
- + {
- + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
- +
- + switch (hparams.n_layer) {
- + case 40: model.type = e_model::MODEL_11B; break;
- + case 100: model.type = e_model::MODEL_90B; break;
- + default: model.type = e_model::MODEL_UNKNOWN;
- + }
- + } break;
- case LLM_ARCH_MINICPM:
- {
- ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
- @@ -7291,7 +7404,15 @@ static const std::map<llm_tensor, llm_tensor_info> llm_tensor_info_mapping = {
- {LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
- // this tensor is loaded for T5, but never used
- {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
- - {LLM_TENSOR_BSKCN_TV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}
- + {LLM_TENSOR_BSKCN_TV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
- + {LLM_TENSOR_CROSS_ATTN_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
- + {LLM_TENSOR_CROSS_ATTN_K_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- + {LLM_TENSOR_CROSS_ATTN_O_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- + {LLM_TENSOR_CROSS_ATTN_Q_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
- + {LLM_TENSOR_CROSS_ATTN_Q_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- + {LLM_TENSOR_CROSS_ATTN_V_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- + {LLM_TENSOR_CROSS_ATTN_ATTN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
- + {LLM_TENSOR_CROSS_ATTN_MLP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
- };
-
- // checks if the weight tensor can be used with the specified buffer type and device
- @@ -7801,6 +7922,53 @@ static bool llm_load_tensors(
- }
- }
- } break;
- + case LLM_ARCH_MLLAMA:
- + {
- + model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab+8}, 0);
- +
- + // output
- + {
- + model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
- + model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
- +
- + // if output is NULL, init from the input tok embed
- + if (model.output == NULL) {
- + model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
- + }
- + }
- +
- + for (int i = 0; i < n_layer; ++i) {
- +
- + auto & layer = model.layers[i];
- +
- + if (hparams.cross_attention_layers(i)) {
- + layer.cross_attn_k_norm = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_K_NORM, "weight", i), {128}, 0);
- + layer.cross_attn_k_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_K_PROJ, "weight", i), {n_embd, 1024}, 0);
- + layer.cross_attn_o_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_O_PROJ, "weight", i), {n_embd, n_embd}, 0);
- + layer.cross_attn_q_norm = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_Q_NORM, "weight", i), {128}, 0);
- + layer.cross_attn_q_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_Q_PROJ, "weight", i), {n_embd, n_embd}, 0);
- + layer.cross_attn_v_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_V_PROJ, "weight", i), {n_embd, 1024}, 0);
- + layer.cross_attn_attn_gate = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_ATTN_GATE, i), {1}, 0);
- + layer.cross_attn_mlp_gate = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_MLP_GATE, i), {1}, 0);
- + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
- + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
- + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
- + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
- + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
- + } else {
- + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
- + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
- + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
- + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
- + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
- + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
- + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
- + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
- + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
- + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
- + }
- + }
- + } break;
- case LLM_ARCH_MINICPM3:
- {
- const int64_t n_embd_head_qk_rope = hparams.n_rot;
- @@ -9511,7 +9679,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
-
- if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE &&
- model.hparams.n_vocab != model.vocab.id_to_token.size()) {
- - throw std::runtime_error("vocab size mismatch");
- + LLAMA_LOG_WARN("%s: vocab mismatch %u !- %zu ...\n", __func__, model.hparams.n_vocab, model.vocab.id_to_token.size());
- }
-
- if (params.vocab_only) {
- @@ -9594,6 +9762,21 @@ static struct ggml_tensor * llm_build_inp_embd(
- return inpL;
- }
-
- +static struct ggml_tensor * llm_build_inp_cross_attn_state(
- + struct ggml_context * ctx,
- + struct llama_context & lctx,
- + const llama_hparams & hparams,
- + const llm_build_cb & cb) {
- + const int64_t n_embd = hparams.n_embd;
- +
- + struct ggml_tensor * inpCAS = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4);
- + cb(inpCAS, "inp_cross_attn_state", -1);
- + ggml_set_input(inpCAS);
- + lctx.inp_cross_attn_state = inpCAS;
- +
- + return inpCAS;
- +}
- +
- static void llm_build_kv_store(
- struct ggml_context * ctx,
- const llama_hparams & hparams,
- @@ -10561,6 +10744,7 @@ struct llm_build_context {
- lctx.inp_pos_bucket = nullptr;
- lctx.inp_embd_enc = nullptr;
- lctx.inp_KQ_mask_cross = nullptr;
- + lctx.inp_cross_attn_state = nullptr;
- }
-
- void free() {
- @@ -11040,6 +11224,240 @@ struct llm_build_context {
- return gf;
- }
-
- + struct ggml_cgraph * build_mllama() {
- + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
- +
- + // mutable variable, needed during the last layer of the computation to skip unused tokens
- + int32_t n_tokens = this->n_tokens;
- +
- + const int64_t n_embd_head = hparams.n_embd_head_v;
- + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
- + GGML_ASSERT(n_embd_head == hparams.n_rot);
- +
- + struct ggml_tensor * cur;
- + struct ggml_tensor * inpL;
- + struct ggml_tensor * inpCAS;
- +
- + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
- + inpCAS = llm_build_inp_cross_attn_state(ctx0, lctx, hparams, cb);
- +
- + // inp_pos - contains the positions
- + struct ggml_tensor * inp_pos = build_inp_pos();
- +
- + // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
- + struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
- +
- + for (int il = 0; il < n_layer; ++il) {
- + struct ggml_tensor * inpSA = inpL;
- +
- + // norm
- + cur = llm_build_norm(ctx0, inpL, hparams,
- + model.layers[il].attn_norm, NULL,
- + LLM_NORM_RMS, cb, il);
- + cb(cur, "attn_norm", il);
- +
- + if (hparams.cross_attention_layers(il)) {
- + if (!ubatch.embd && !cparams.cross_attn) {
- + continue;
- + }
- +
- + // cross attention layer
- + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_q_proj, cur);
- + cb(Qcur, "Qcur", il);
- +
- + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
- + cb(Qcur, "Qcur", il);
- +
- + Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3));
- + cb(Qcur, "Qcur", il);
- +
- + Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
- + cb(Qcur, "Qcur", il);
- +
- + struct ggml_tensor * Kcur, * Vcur;
- + if (ubatch.embd) {
- + Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS);
- + cb(Kcur, "Kcur", il);
- +
- + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404);
- + cb(Kcur, "Kcur", il);
- +
- + Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
- + cb(Kcur, "Kcur", il);
- +
- + Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
- + cb(Kcur, "Kcur", il);
- +
- + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il]));
- +
- + Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS);
- + cb(Vcur, "Vcur", il);
- +
- + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, 6404);
- + cb(Vcur, "Vcur", il);
- +
- + Vcur = ggml_permute(ctx0, Vcur, 0, 2, 1, 3);
- + cb(Vcur, "Vcur", il);
- +
- + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il]));
- + } else {
- + Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]);
- + cb(Kcur, "Kcur (view)", il);
- +
- + Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]);
- + cb(Vcur, "Vcur (view)", il);
- + }
- +
- + struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur);
- + cb(kq, "kq", il);
- +
- + // TODO: apply causal masks
- + struct ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq, nullptr, 1.f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
- + cb(kq_soft_max, "kq_soft_max", il);
- +
- + Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur));
- + cb(Vcur, "Vcur", il);
- +
- + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, Vcur, kq_soft_max);
- + cb(kqv, "kqv", il);
- +
- + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
- + cb(kqv_merged, "kqv_merged", il);
- +
- + cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
- + cb(cur, "kqv_merged_cont", il);
- +
- + cur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_o_proj, cur);
- + cb(cur, "cur", il);
- +
- + // TODO: do this in place once?
- + cur = ggml_mul(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_attn_gate));
- +
- + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
- + cb(ffn_inp, "ffn_inp", il);
- +
- + // feed-forward network
- + cur = llm_build_norm(ctx0, ffn_inp, hparams,
- + model.layers[il].ffn_norm, NULL,
- + LLM_NORM_RMS, cb, il);
- + cb(cur, "ffn_norm", il);
- +
- + cur = llm_build_ffn(ctx0, lctx, cur,
- + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
- + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
- + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
- + NULL,
- + LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
- + cb(cur, "ffn_out", il);
- +
- + // TODO: do this inplace once?
- + cur = ggml_add_inplace(ctx0, ggml_mul_inplace(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_mlp_gate)), ffn_inp);
- + cb(cur, "ffn_out", il);
- +
- + cur = lctx.cvec.apply_to(ctx0, cur, il);
- + cb(cur, "l_out", il);
- +
- + // input for next layer
- + inpL = cur;
- + } else {
- + // self attention layer
- +
- + // rope freq factors for llama3; may return nullptr for llama2 and other models
- + struct ggml_tensor * rope_factors = build_rope_factors(il);
- +
- + // compute Q and K and RoPE them
- + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
- + cb(Qcur, "Qcur", il);
- + if (model.layers[il].bq) {
- + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
- + cb(Qcur, "Qcur", il);
- + }
- +
- + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
- + cb(Kcur, "Kcur", il);
- + if (model.layers[il].bk) {
- + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
- + cb(Kcur, "Kcur", il);
- + }
- +
- + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
- + cb(Vcur, "Vcur", il);
- + if (model.layers[il].bv) {
- + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
- + cb(Vcur, "Vcur", il);
- + }
- +
- + Qcur = ggml_rope_ext(
- + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
- + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- + ext_factor, attn_factor, beta_fast, beta_slow
- + );
- + cb(Qcur, "Qcur", il);
- +
- + Kcur = ggml_rope_ext(
- + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
- + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- + ext_factor, attn_factor, beta_fast, beta_slow
- + );
- + cb(Kcur, "Kcur", il);
- +
- + cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- + model.layers[il].wo, model.layers[il].bo,
- + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
- +
- +
- + if (il == n_layer - 1) {
- + // skip computing output for unused tokens
- + struct ggml_tensor * inp_out_ids = build_inp_out_ids();
- + n_tokens = n_outputs;
- + cur = ggml_get_rows(ctx0, cur, inp_out_ids);
- + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
- + }
- +
- + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
- + cb(ffn_inp, "ffn_inp", il);
- +
- + // feed-forward network
- + cur = llm_build_norm(ctx0, ffn_inp, hparams,
- + model.layers[il].ffn_norm, NULL,
- + LLM_NORM_RMS, cb, il);
- + cb(cur, "ffn_norm", il);
- +
- + cur = llm_build_ffn(ctx0, lctx, cur,
- + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
- + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
- + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
- + NULL,
- + LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
- + cb(cur, "ffn_out", il);
- +
- + cur = ggml_add(ctx0, cur, ffn_inp);
- + cb(cur, "ffn_out", il);
- +
- + cur = lctx.cvec.apply_to(ctx0, cur, il);
- + cb(cur, "l_out", il);
- +
- + // input for next layer
- + inpL = cur;
- + }
- + }
- +
- + cur = inpL;
- +
- + cur = llm_build_norm(ctx0, cur, hparams,
- + model.output_norm, NULL,
- + LLM_NORM_RMS, cb, -1);
- + cb(cur, "result_norm", -1);
- +
- + // lm_head
- + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
- + cb(cur, "result_output", -1);
- +
- + ggml_build_forward_expand(gf, cur);
- +
- + return gf;
- + }
- +
- struct ggml_cgraph * build_baichuan() {
- struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
-
- @@ -16993,6 +17411,10 @@ static struct ggml_cgraph * llama_build_graph(
- {
- result = llm.build_llama();
- } break;
- + case LLM_ARCH_MLLAMA:
- + {
- + result = llm.build_mllama();
- + } break;
- case LLM_ARCH_BAICHUAN:
- {
- result = llm.build_baichuan();
- @@ -17258,10 +17680,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch)
- }
-
- if (ubatch.embd) {
- - const int64_t n_embd = hparams.n_embd;
- - const int64_t n_tokens = ubatch.n_tokens;
- + if (lctx.inp_cross_attn_state && lctx.inp_cross_attn_state->buffer) {
- + ggml_backend_tensor_set(lctx.inp_cross_attn_state, ubatch.embd, 0, ggml_nbytes(lctx.inp_cross_attn_state));
- + // zero out inp_embd since it's not used
- + float * inp_embd_data = (float *)lctx.inp_embd->data;
- + for (int i = 0; i < ggml_nelements(lctx.inp_embd); ++i) {
- + inp_embd_data[i] = 0.0f;
- + }
- + } else {
- + const int64_t n_embd = hparams.n_embd;
- + const int64_t n_tokens = ubatch.n_tokens;
-
- - ggml_backend_tensor_set(lctx.inp_embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
- + ggml_backend_tensor_set(lctx.inp_embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
- + }
- }
-
- if (ubatch.pos && lctx.inp_pos) {
- @@ -17862,7 +18293,7 @@ static int llama_decode_internal(
- n_outputs = 1;
- }
-
- - lctx.sbatch.from_batch(batch, n_embd,
- + lctx.sbatch.from_batch(batch, batch.n_embd,
- /* simple_split */ !kv_self.recurrent,
- /* logits_all */ n_outputs == n_tokens_all);
-
- @@ -18172,7 +18603,7 @@ static int llama_encode_internal(
-
- const int64_t n_embd = hparams.n_embd;
-
- - lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
- + lctx.sbatch.from_batch(batch, batch.n_embd, /* simple_split */ true, /* logits_all */ true);
-
- const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens);
-
- @@ -19203,7 +19634,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
- if (llama_model_has_encoder(&model)) {
- n_attn_layer *= 3;
- }
- - GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
- + if (qs.n_attention_wv != n_attn_layer) {
- + LLAMA_LOG_WARN("%s: n_attention_wv is unexpected, expected: %d, found: %d\n", __func__, n_attn_layer, qs.n_attention_wv);
- + }
- }
-
- size_t total_size_org = 0;
- @@ -20360,6 +20793,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
-
- // use what we call a normal RoPE, operating on pairs of consecutive head values
- case LLM_ARCH_LLAMA:
- + case LLM_ARCH_MLLAMA:
- case LLM_ARCH_BAICHUAN:
- case LLM_ARCH_STARCODER:
- case LLM_ARCH_PLAMO:
- @@ -21790,6 +22224,10 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
- ctx->cparams.causal_attn = causal_attn;
- }
-
- +void llama_set_cross_attention(struct llama_context * ctx, bool cross_attention) {
- + ctx->cparams.cross_attn = cross_attention;
- +}
- +
- struct llama_batch llama_batch_get_one(
- llama_token * tokens,
- int32_t n_tokens) {
- @@ -21797,6 +22235,7 @@ struct llama_batch llama_batch_get_one(
- /*n_tokens =*/ n_tokens,
- /*tokens =*/ tokens,
- /*embd =*/ nullptr,
- + /*n_embd =*/ 0,
- /*pos =*/ nullptr,
- /*n_seq_id =*/ nullptr,
- /*seq_id =*/ nullptr,
- @@ -21809,6 +22248,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
- /*n_tokens =*/ 0,
- /*tokens =*/ nullptr,
- /*embd =*/ nullptr,
- + /*n_embd =*/ 0,
- /*pos =*/ nullptr,
- /*n_seq_id =*/ nullptr,
- /*seq_id =*/ nullptr,
- @@ -21817,6 +22257,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
-
- if (embd) {
- batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
- + batch.n_embd = embd;
- } else {
- batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
- }
|