From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 From: jmorganca Date: Thu, 17 Oct 2024 15:18:22 -0700 Subject: [PATCH] add mllama support mllama adds cross-attention layers to the standard llama architecture it also requires a way to input a new tensor: cross_attention_state once per generation cross-attention layers don't change and so they are cached in the kv cache once per run remaining is to implement the cross attention mask --- examples/llava/llava.cpp | 5 +- ggml/src/ggml-backend-reg.cpp | 6 +- include/llama.h | 6 + src/llama-arch.cpp | 44 +++++ src/llama-arch.h | 10 ++ src/llama-batch.cpp | 3 + src/llama-context.cpp | 19 ++- src/llama-context.h | 2 + src/llama-cparams.h | 1 + src/llama-hparams.cpp | 8 +- src/llama-hparams.h | 4 + src/llama-kv-cache.cpp | 33 ++++ src/llama-model-loader.cpp | 2 + src/llama-model.cpp | 59 ++----- src/llama-model.h | 51 ++++++ src/llama-quant.cpp | 4 +- src/llama.cpp | 307 +++++++++++++++++++++++++++++++++- 17 files changed, 508 insertions(+), 56 deletions(-) diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 16f30c56..0f0f3f62 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -429,7 +429,7 @@ struct llava_embd_batch { std::vector seq_ids; std::vector logits; llama_batch batch; - llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { + llava_embd_batch(float * embd, int32_t n_embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { pos .resize(n_tokens); n_seq_id.resize(n_tokens); seq_ids .resize(n_tokens + 1); @@ -441,6 +441,7 @@ struct llava_embd_batch { /*n_tokens =*/ n_tokens, /*tokens =*/ nullptr, /*embd =*/ embd, + /*n_embd =*/ n_embd, /*pos =*/ pos.data(), /*n_seq_id =*/ n_seq_id.data(), /*seq_id =*/ seq_ids.data(), @@ -464,7 +465,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_ n_eval = n_batch; } float * embd = image_embed->embed+i*n_embd; - llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0); + llava_embd_batch llava_batch = llava_embd_batch(embd, n_embd, n_eval, *n_past, 0); if (llama_decode(ctx_llama, llava_batch.batch)) { LOG_ERR("%s : failed to eval\n", __func__); return false; diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 7ddd178b..899d16f2 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -171,9 +171,9 @@ struct ggml_backend_registry { #ifdef GGML_USE_CANN register_backend(ggml_backend_cann_reg()); #endif -#ifdef GGML_USE_BLAS - register_backend(ggml_backend_blas_reg()); -#endif +// #ifdef GGML_USE_BLAS +// register_backend(ggml_backend_blas_reg()); +// #endif #ifdef GGML_USE_RPC register_backend(ggml_backend_rpc_reg()); #endif diff --git a/include/llama.h b/include/llama.h index a0d5ba5d..9f411960 100644 --- a/include/llama.h +++ b/include/llama.h @@ -250,6 +250,7 @@ extern "C" { llama_token * token; float * embd; + int32_t n_embd; llama_pos * pos; int32_t * n_seq_id; llama_seq_id ** seq_id; @@ -347,6 +348,7 @@ extern "C" { bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool flash_attn; // whether to use flash attention [EXPERIMENTAL] bool no_perf; // whether to measure performance timings + bool cross_attn; // whether to use cross attention // Abort callback // if it returns true, execution of llama_decode() will be aborted @@ -426,6 +428,10 @@ extern "C" { struct llama_model * model, struct llama_context_params params); + // TODO (jmorganca): this should most likely be passed in as part of a batch + // and not set on the context for all batches. + LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state); + // Frees all allocated memory LLAMA_API void llama_free(struct llama_context * ctx); diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 5b376c5e..b35aeb31 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -6,6 +6,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_LLAMA, "llama" }, + { LLM_ARCH_MLLAMA, "mllama" }, { LLM_ARCH_DECI, "deci" }, { LLM_ARCH_FALCON, "falcon" }, { LLM_ARCH_GROK, "grok" }, @@ -124,6 +125,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" }, + { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, @@ -220,6 +222,40 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_MLLAMA, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_CROSS_ATTN_K_NORM, "blk.%d.cross_attn_k_norm" }, + { LLM_TENSOR_CROSS_ATTN_K_PROJ, "blk.%d.cross_attn_k_proj" }, + { LLM_TENSOR_CROSS_ATTN_O_PROJ, "blk.%d.cross_attn_o_proj" }, + { LLM_TENSOR_CROSS_ATTN_Q_NORM, "blk.%d.cross_attn_q_norm" }, + { LLM_TENSOR_CROSS_ATTN_Q_PROJ, "blk.%d.cross_attn_q_proj" }, + { LLM_TENSOR_CROSS_ATTN_V_PROJ, "blk.%d.cross_attn_v_proj" }, + { LLM_TENSOR_CROSS_ATTN_ATTN_GATE, "blk.%d.cross_attn_attn_gate" }, + { LLM_TENSOR_CROSS_ATTN_MLP_GATE, "blk.%d.cross_attn_mlp_gate" }, + }, + }, { LLM_ARCH_DECI, { @@ -1393,6 +1429,14 @@ static const std::map LLM_TENSOR_INFOS = { // this tensor is loaded for T5, but never used {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}}, {LLM_TENSOR_BSKCN_TV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_CROSS_ATTN_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_CROSS_ATTN_K_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CROSS_ATTN_O_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CROSS_ATTN_Q_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_CROSS_ATTN_Q_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CROSS_ATTN_V_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CROSS_ATTN_ATTN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_CROSS_ATTN_MLP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}}, {LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index eac7055b..e8235ae0 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -10,6 +10,7 @@ enum llm_arch { LLM_ARCH_LLAMA, + LLM_ARCH_MLLAMA, LLM_ARCH_DECI, LLM_ARCH_FALCON, LLM_ARCH_BAICHUAN, @@ -128,6 +129,7 @@ enum llm_kv { LLM_KV_ATTENTION_SLIDING_WINDOW, LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, + LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS, @@ -308,6 +310,14 @@ enum llm_tensor { LLM_TENSOR_CLS, LLM_TENSOR_CLS_OUT, LLM_TENSOR_BSKCN_TV, + LLM_TENSOR_CROSS_ATTN_K_NORM, + LLM_TENSOR_CROSS_ATTN_K_PROJ, + LLM_TENSOR_CROSS_ATTN_O_PROJ, + LLM_TENSOR_CROSS_ATTN_Q_NORM, + LLM_TENSOR_CROSS_ATTN_Q_PROJ, + LLM_TENSOR_CROSS_ATTN_V_PROJ, + LLM_TENSOR_CROSS_ATTN_ATTN_GATE, + LLM_TENSOR_CROSS_ATTN_MLP_GATE, LLM_TENSOR_CONV1D, LLM_TENSOR_CONVNEXT_DW, LLM_TENSOR_CONVNEXT_NORM, diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 01d5ca57..8682b0e6 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -316,6 +316,7 @@ struct llama_batch llama_batch_get_one( /*n_tokens =*/ n_tokens, /*tokens =*/ tokens, /*embd =*/ nullptr, + /*n_embd =*/ 0, /*pos =*/ nullptr, /*n_seq_id =*/ nullptr, /*seq_id =*/ nullptr, @@ -328,6 +329,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ /*n_tokens =*/ 0, /*tokens =*/ nullptr, /*embd =*/ nullptr, + /*n_embd =*/ 0, /*pos =*/ nullptr, /*n_seq_id =*/ nullptr, /*seq_id =*/ nullptr, @@ -336,6 +338,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ if (embd) { batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd); + batch.n_embd = embd; } else { batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc); } diff --git a/src/llama-context.cpp b/src/llama-context.cpp index b9c4a5bf..9d0e7ca3 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -71,10 +71,19 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) { } if (ubatch.embd) { - const int64_t n_embd = hparams.n_embd; - const int64_t n_tokens = ubatch.n_tokens; + if (lctx.inp_cross_attn_state && lctx.inp_cross_attn_state->buffer) { + ggml_backend_tensor_set(lctx.inp_cross_attn_state, ubatch.embd, 0, ggml_nbytes(lctx.inp_cross_attn_state)); + // zero out inp_embd since it's not used + float * inp_embd_data = (float *)lctx.inp_embd->data; + for (int i = 0; i < ggml_nelements(lctx.inp_embd); ++i) { + inp_embd_data[i] = 0.0f; + } + } else { + const int64_t n_embd = hparams.n_embd; + const int64_t n_tokens = ubatch.n_tokens; - ggml_backend_tensor_set(lctx.inp_embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd)); + ggml_backend_tensor_set(lctx.inp_embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd)); + } } if (ubatch.pos && lctx.inp_pos) { @@ -653,6 +662,10 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) { ctx->cparams.causal_attn = causal_attn; } +void llama_set_cross_attention(struct llama_context * ctx, bool cross_attention) { + ctx->cparams.cross_attn = cross_attention; +} + void llama_synchronize(struct llama_context * ctx) { ggml_backend_sched_synchronize(ctx->sched.get()); diff --git a/src/llama-context.h b/src/llama-context.h index 0d163c47..4980a60e 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -107,6 +107,8 @@ struct llama_context { struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] + + struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061] }; // TODO: make these methods of llama_context diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 252012f3..9681e5a0 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -29,6 +29,7 @@ struct llama_cparams { bool offload_kqv; bool flash_attn; bool no_perf; + bool cross_attn; enum llama_pooling_type pooling_type; diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 450738da..42f8a58f 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -2,6 +2,8 @@ #include "ggml.h" +#include + uint32_t llama_hparams::n_head(uint32_t il) const { if (il < n_layer) { return n_head_arr[il]; @@ -76,4 +78,8 @@ bool llama_hparams::n_bskcn(uint32_t n, uint32_t il) const { } GGML_ABORT("fatal error"); -} \ No newline at end of file +} + +bool llama_hparams::cross_attention_layers(uint32_t il) const { + return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end(); +} diff --git a/src/llama-hparams.h b/src/llama-hparams.h index fd898e27..f826cd9a 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -53,6 +53,7 @@ struct llama_hparams { std::array n_ff_arr; std::array, 4> n_bskcn_arr = {}; + std::array cross_attn_layers; uint32_t n_layer_dense_lead = 0; uint32_t n_lora_q = 0; @@ -139,6 +140,9 @@ struct llama_hparams { // Block skip connection bool n_bskcn(uint32_t n, uint32_t il) const; + + // cross attention layers + bool cross_attention_layers(uint32_t il) const; }; static_assert(std::is_trivially_copyable::value, "llama_hparams must be trivially copyable"); diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 53379253..cf814dbe 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -72,6 +72,39 @@ bool llama_kv_cache_init( cache.v_l.reserve(n_layer); for (int i = 0; i < n_layer; i++) { + // for cross attention layers + if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layers(i)) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); + const llama_model::buft_list_t * buft_list; + if (offload) { + buft_list = model.dev_layer.at(i).buft_list; + } else { + buft_list = &model.cpu_buft_list; + } + ggml_backend_buffer_type_t buft = select_buft(*buft_list, + [&](ggml_context * ctx) { + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + if (hparams.rope_type == LLAMA_ROPE_TYPE_NONE) { + return k; + } + ggml_tensor * p = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); + return ggml_rope(ctx, k, p, hparams.n_rot, hparams.rope_type); + }); + ggml_context * ctx = ctx_for_buft(buft); + + if (!ctx) { + LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__); + return false; + } + ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i)); + ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i)); + ggml_format_name(k, "cache_k_l%d", i); + ggml_format_name(v, "cache_v_l%d", i); + cache.k_l.push_back(k); + cache.v_l.push_back(v); + continue; + } + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 422524a8..b12d6566 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -240,6 +240,8 @@ namespace GGUFMeta { return true; } + template bool llama_model_loader::get_arr>(enum llm_kv kid, std::array& result, bool required); + template bool llama_model_loader::get_arr(const std::string & key, std::array & result, bool required) { const int kid = gguf_find_key(meta.get(), key.c_str()); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 306c557d..4f9bbf90 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -146,46 +146,6 @@ std::string llama_model_ftype_name(const llama_model & model) { return llama_model_ftype_name(model.ftype); } -template -static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) { - ggml_init_params params = { - /*.mem_size =*/ ggml_tensor_overhead()*8, - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; - - ggml_context_ptr ctx { ggml_init(params) }; - if (!ctx) { - throw std::runtime_error(format("failed to create ggml context")); - } - - ggml_backend_buffer_ptr buf { ggml_backend_buft_alloc_buffer(buft, 0) }; - ggml_tensor * op_tensor = fn(ctx.get()); - for (int i = 0; i < GGML_MAX_SRC; i++) { - if (op_tensor->src[i] != nullptr) { - assert(op_tensor->src[i]->buffer == nullptr); - op_tensor->src[i]->buffer = buf.get(); - } - } - - bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor); - - return op_supported; -} - -template -static ggml_backend_buffer_type_t select_buft(const llama_model::buft_list_t & buft_list, const F & fn) { - for (const auto & cur : buft_list) { - ggml_backend_dev_t cur_dev = cur.first; - ggml_backend_buffer_type_t cur_buft = cur.second; - if (buft_supported(cur_buft, cur_dev, fn)) { - return cur_buft; - } - } - - throw std::runtime_error(format("no suitable buffer type found")); -} - ggml_backend_buffer_type_t llama_model_select_buft(const llama_model & model, int il) { return select_buft( *model.dev_layer.at(il).buft_list, @@ -312,9 +272,11 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); + std::fill(hparams.cross_attn_layers.begin(), hparams.cross_attn_layers.end(), -1); - ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); - ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); + ml.get_arr(LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, hparams.cross_attn_layers, false); // n_head_kv is optional, default to n_head hparams.n_head_kv_arr = hparams.n_head_arr; @@ -363,7 +325,7 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); - if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_DECI || model.arch == LLM_ARCH_FALCON) { + if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_MLLAMA || model.arch == LLM_ARCH_DECI || model.arch == LLM_ARCH_FALCON) { if (hparams.n_rot != hparams.n_embd_head_k) { throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); } @@ -405,6 +367,16 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { } } } break; + case LLM_ARCH_MLLAMA: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 40: model.type = e_model::MODEL_11B; break; + case 100: model.type = e_model::MODEL_90B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_DECI: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -2062,6 +2034,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { // use what we call a normal RoPE, operating on pairs of consecutive head values case LLM_ARCH_LLAMA: + case LLM_ARCH_MLLAMA: case LLM_ARCH_DECI: case LLM_ARCH_BAICHUAN: case LLM_ARCH_STARCODER: diff --git a/src/llama-model.h b/src/llama-model.h index c1b9c0a1..5b23e2ba 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -9,6 +9,7 @@ #include "ggml-cpp.h" #include +#include // available models // TODO: this enum does not follow the enum naming convention @@ -62,6 +63,7 @@ enum llm_type { MODEL_40B, MODEL_65B, MODEL_70B, + MODEL_90B, MODEL_236B, MODEL_314B, MODEL_671B, @@ -278,6 +280,16 @@ struct llama_layer { struct ggml_tensor * bskcn_tv = nullptr; + // cross attention + struct ggml_tensor * cross_attn_k_norm = nullptr; + struct ggml_tensor * cross_attn_k_proj = nullptr; + struct ggml_tensor * cross_attn_o_proj = nullptr; + struct ggml_tensor * cross_attn_q_norm = nullptr; + struct ggml_tensor * cross_attn_q_proj = nullptr; + struct ggml_tensor * cross_attn_v_proj = nullptr; + struct ggml_tensor * cross_attn_attn_gate = nullptr; + struct ggml_tensor * cross_attn_mlp_gate = nullptr; + struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; @@ -376,6 +388,45 @@ std::string llama_model_arch_name (const llama_model & model); std::string llama_model_type_name (const llama_model & model); std::string llama_model_ftype_name(const llama_model & model); +template +bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) { + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead()*8, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context_ptr ctx { ggml_init(params) }; + if (!ctx) { + throw std::runtime_error("failed to create ggml context"); + } + + ggml_backend_buffer_ptr buf { ggml_backend_buft_alloc_buffer(buft, 0) }; + ggml_tensor * op_tensor = fn(ctx.get()); + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (op_tensor->src[i] != nullptr) { + op_tensor->src[i]->buffer = buf.get(); + } + } + + bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor); + + return op_supported; +} + +template +ggml_backend_buffer_type_t select_buft(const llama_model::buft_list_t & buft_list, const F & fn) { + for (const auto & cur : buft_list) { + ggml_backend_dev_t cur_dev = cur.first; + ggml_backend_buffer_type_t cur_buft = cur.second; + if (buft_supported(cur_buft, cur_dev, fn)) { + return cur_buft; + } + } + + throw std::runtime_error("no suitable buffer type found"); +} + // used by llama_adapter_cvec ggml_backend_buffer_type_t llama_model_select_buft(const llama_model & model, int il); diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 42974f8f..27def6fd 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -629,7 +629,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (llama_model_has_encoder(&model)) { n_attn_layer *= 3; } - GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected"); + if (qs.n_attention_wv != n_attn_layer) { + LLAMA_LOG_WARN("%s: n_attention_wv is unexpected, expected: %d, found: %d\n", __func__, n_attn_layer, qs.n_attention_wv); + } } size_t total_size_org = 0; diff --git a/src/llama.cpp b/src/llama.cpp index 7dec50ae..bac66c24 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -563,6 +563,52 @@ static bool llm_load_tensors( } } } break; + case LLM_ARCH_MLLAMA: + { + model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab+8}, 0); + + // output + { + model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = model.layers[i]; + + if (hparams.cross_attention_layers(i)) { + layer.cross_attn_k_norm = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_K_NORM, "weight", i), {128}, 0); + layer.cross_attn_k_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_K_PROJ, "weight", i), {n_embd, 1024}, 0); + layer.cross_attn_o_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_O_PROJ, "weight", i), {n_embd, n_embd}, 0); + layer.cross_attn_q_norm = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_Q_NORM, "weight", i), {128}, 0); + layer.cross_attn_q_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_Q_PROJ, "weight", i), {n_embd, n_embd}, 0); + layer.cross_attn_v_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_V_PROJ, "weight", i), {n_embd, 1024}, 0); + layer.cross_attn_attn_gate = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_ATTN_GATE, i), {1}, 0); + layer.cross_attn_mlp_gate = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_MLP_GATE, i), {1}, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + } else { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } + } break; case LLM_ARCH_DECI: { model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -2514,7 +2560,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE && model.hparams.n_vocab != model.vocab.id_to_token.size()) { - throw std::runtime_error("vocab size mismatch"); + LLAMA_LOG_WARN("%s: vocab mismatch %u !- %zu ...\n", __func__, model.hparams.n_vocab, model.vocab.id_to_token.size()); } if (params.vocab_only) { @@ -2598,6 +2644,21 @@ static struct ggml_tensor * llm_build_inp_embd( return inpL; } +static struct ggml_tensor * llm_build_inp_cross_attn_state( + struct ggml_context * ctx, + struct llama_context & lctx, + const llama_hparams & hparams, + const llm_build_cb & cb) { + const int64_t n_embd = hparams.n_embd; + + struct ggml_tensor * inpCAS = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4); + cb(inpCAS, "inp_cross_attn_state", -1); + ggml_set_input(inpCAS); + lctx.inp_cross_attn_state = inpCAS; + + return inpCAS; +} + static void llm_build_kv_store( struct ggml_context * ctx, const llama_hparams & hparams, @@ -3593,6 +3654,7 @@ struct llm_build_context { lctx.inp_pos_bucket = nullptr; lctx.inp_embd_enc = nullptr; lctx.inp_KQ_mask_cross = nullptr; + lctx.inp_cross_attn_state = nullptr; } void free() { @@ -4074,6 +4136,240 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_mllama() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + struct ggml_tensor * inpCAS; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); + inpCAS = llm_build_inp_cross_attn_state(ctx0, lctx, hparams, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + if (hparams.cross_attention_layers(il)) { + if (!ubatch.embd && !cparams.cross_attn) { + continue; + } + + // cross attention layer + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_q_proj, cur); + cb(Qcur, "Qcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + cb(Qcur, "Qcur", il); + + Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3)); + cb(Qcur, "Qcur", il); + + Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur, * Vcur; + if (ubatch.embd) { + Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS); + cb(Kcur, "Kcur", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404); + cb(Kcur, "Kcur", il); + + Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); + cb(Kcur, "Kcur", il); + + Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il); + cb(Kcur, "Kcur", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il])); + + Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS); + cb(Vcur, "Vcur", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, 6404); + cb(Vcur, "Vcur", il); + + Vcur = ggml_permute(ctx0, Vcur, 0, 2, 1, 3); + cb(Vcur, "Vcur", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il])); + } else { + Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]); + cb(Kcur, "Kcur (view)", il); + + Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]); + cb(Vcur, "Vcur (view)", il); + } + + struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur); + cb(kq, "kq", il); + + // TODO: apply causal masks + struct ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq, nullptr, 1.f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias); + cb(kq_soft_max, "kq_soft_max", il); + + Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur)); + cb(Vcur, "Vcur", il); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, Vcur, kq_soft_max); + cb(kqv, "kqv", il); + + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + + cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens); + cb(cur, "kqv_merged_cont", il); + + cur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_o_proj, cur); + cb(cur, "cur", il); + + // TODO: do this in place once? + cur = ggml_mul(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_attn_gate)); + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + // TODO: do this inplace once? + cur = ggml_add_inplace(ctx0, ggml_mul_inplace(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_mlp_gate)), ffn_inp); + cb(cur, "ffn_out", il); + + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } else { + // self attention layer + + // rope freq factors for llama3; may return nullptr for llama2 and other models + struct ggml_tensor * rope_factors = build_rope_factors(il); + + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + struct ggml_cgraph * build_deci() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); @@ -10646,6 +10942,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_llama(); } break; + case LLM_ARCH_MLLAMA: + { + result = llm.build_mllama(); + } break; case LLM_ARCH_DECI: { result = llm.build_deci(); @@ -10971,7 +11271,7 @@ static int llama_decode_internal( n_outputs = 1; } - lctx.sbatch.from_batch(batch, n_embd, + lctx.sbatch.from_batch(batch, batch.n_embd, /* simple_split */ !kv_self.recurrent, /* logits_all */ n_outputs == n_tokens_all); @@ -11282,7 +11582,7 @@ static int llama_encode_internal( const int64_t n_embd = hparams.n_embd; - lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true); + lctx.sbatch.from_batch(batch, batch.n_embd, /* simple_split */ true, /* logits_all */ true); const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens); @@ -11775,6 +12075,7 @@ struct llama_context_params llama_context_default_params() { /*.offload_kqv =*/ true, /*.flash_attn =*/ false, /*.no_perf =*/ true, + /*.cross_attn =*/ false, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, };