0010-add-mllama-support.patch 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690
  1. From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
  2. From: jmorganca <jmorganca@gmail.com>
  3. Date: Thu, 17 Oct 2024 15:18:22 -0700
  4. Subject: [PATCH] add mllama support
  5. mllama adds cross-attention layers to the standard llama architecture
  6. it also requires a way to input a new tensor: cross_attention_state
  7. once per generation
  8. cross-attention layers don't change and so they are cached in the
  9. kv cache once per run
  10. remaining is to implement the cross attention mask
  11. ---
  12. include/llama.h | 4 +
  13. src/llama.cpp | 456 ++++++++++++++++++++++++++++++++++++++++++++++--
  14. 2 files changed, 447 insertions(+), 13 deletions(-)
  15. diff --git a/include/llama.h b/include/llama.h
  16. index 7cae1bbe..122e3cf1 100644
  17. --- a/include/llama.h
  18. +++ b/include/llama.h
  19. @@ -423,6 +423,10 @@ extern "C" {
  20. struct llama_model * model,
  21. struct llama_context_params params);
  22. + // TODO (jmorganca): this should most likely be passed in as part of a batch
  23. + // and not set on the context for all batches.
  24. + LLAMA_API void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state);
  25. +
  26. // Frees all allocated memory
  27. LLAMA_API void llama_free(struct llama_context * ctx);
  28. diff --git a/src/llama.cpp b/src/llama.cpp
  29. index 83b80b59..b189a19a 100644
  30. --- a/src/llama.cpp
  31. +++ b/src/llama.cpp
  32. @@ -169,6 +169,7 @@ static std::string format(const char * fmt, ...) {
  33. enum llm_arch {
  34. LLM_ARCH_LLAMA,
  35. + LLM_ARCH_MLLAMA,
  36. LLM_ARCH_FALCON,
  37. LLM_ARCH_BAICHUAN,
  38. LLM_ARCH_GROK,
  39. @@ -223,6 +224,7 @@ enum llm_arch {
  40. static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
  41. { LLM_ARCH_LLAMA, "llama" },
  42. + { LLM_ARCH_MLLAMA, "mllama" },
  43. { LLM_ARCH_FALCON, "falcon" },
  44. { LLM_ARCH_GROK, "grok" },
  45. { LLM_ARCH_GPT2, "gpt2" },
  46. @@ -330,6 +332,7 @@ enum llm_kv {
  47. LLM_KV_ATTENTION_SLIDING_WINDOW,
  48. LLM_KV_ATTENTION_SCALE,
  49. LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
  50. + LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS,
  51. LLM_KV_ROPE_DIMENSION_COUNT,
  52. LLM_KV_ROPE_FREQ_BASE,
  53. @@ -439,6 +442,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
  54. { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
  55. { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
  56. { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection.%d" },
  57. + { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers" },
  58. { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
  59. { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
  60. @@ -613,6 +617,14 @@ enum llm_tensor {
  61. LLM_TENSOR_CLS,
  62. LLM_TENSOR_CLS_OUT,
  63. LLM_TENSOR_BSKCN_TV,
  64. + LLM_TENSOR_CROSS_ATTN_K_NORM,
  65. + LLM_TENSOR_CROSS_ATTN_K_PROJ,
  66. + LLM_TENSOR_CROSS_ATTN_O_PROJ,
  67. + LLM_TENSOR_CROSS_ATTN_Q_NORM,
  68. + LLM_TENSOR_CROSS_ATTN_Q_PROJ,
  69. + LLM_TENSOR_CROSS_ATTN_V_PROJ,
  70. + LLM_TENSOR_CROSS_ATTN_ATTN_GATE,
  71. + LLM_TENSOR_CROSS_ATTN_MLP_GATE,
  72. };
  73. static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
  74. @@ -642,6 +654,40 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
  75. { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
  76. },
  77. },
  78. + {
  79. + LLM_ARCH_MLLAMA,
  80. + {
  81. + { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
  82. + { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
  83. + { LLM_TENSOR_OUTPUT, "output" },
  84. + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
  85. + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
  86. + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
  87. + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
  88. + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
  89. + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
  90. + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
  91. + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
  92. + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
  93. + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
  94. + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
  95. + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
  96. + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
  97. + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
  98. + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
  99. + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
  100. + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
  101. + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
  102. + { LLM_TENSOR_CROSS_ATTN_K_NORM, "blk.%d.cross_attn_k_norm" },
  103. + { LLM_TENSOR_CROSS_ATTN_K_PROJ, "blk.%d.cross_attn_k_proj" },
  104. + { LLM_TENSOR_CROSS_ATTN_O_PROJ, "blk.%d.cross_attn_o_proj" },
  105. + { LLM_TENSOR_CROSS_ATTN_Q_NORM, "blk.%d.cross_attn_q_norm" },
  106. + { LLM_TENSOR_CROSS_ATTN_Q_PROJ, "blk.%d.cross_attn_q_proj" },
  107. + { LLM_TENSOR_CROSS_ATTN_V_PROJ, "blk.%d.cross_attn_v_proj" },
  108. + { LLM_TENSOR_CROSS_ATTN_ATTN_GATE, "blk.%d.cross_attn_attn_gate" },
  109. + { LLM_TENSOR_CROSS_ATTN_MLP_GATE, "blk.%d.cross_attn_mlp_gate" },
  110. + },
  111. + },
  112. {
  113. LLM_ARCH_BAICHUAN,
  114. {
  115. @@ -2390,6 +2436,7 @@ enum e_model {
  116. MODEL_40B,
  117. MODEL_65B,
  118. MODEL_70B,
  119. + MODEL_90B,
  120. MODEL_236B,
  121. MODEL_314B,
  122. MODEL_SMALL,
  123. @@ -2434,6 +2481,7 @@ struct llama_hparams {
  124. std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
  125. std::array<std::array<uint32_t, LLAMA_MAX_LAYERS>, 4> n_bskcn_arr;
  126. + std::array<uint32_t, LLAMA_MAX_LAYERS> cross_attn_layers;
  127. uint32_t n_layer_dense_lead = 0;
  128. uint32_t n_lora_q = 0;
  129. @@ -2502,10 +2550,11 @@ struct llama_hparams {
  130. if (this->n_expert != other.n_expert) return true;
  131. if (this->n_expert_used != other.n_expert_used) return true;
  132. - if (this->n_head_arr != other.n_head_arr) return true;
  133. - if (this->n_head_kv_arr != other.n_head_kv_arr) return true;
  134. - if (this->n_ff_arr != other.n_ff_arr) return true;
  135. - if (this->n_bskcn_arr != other.n_bskcn_arr) return true;
  136. + if (this->n_head_arr != other.n_head_arr) return true;
  137. + if (this->n_head_kv_arr != other.n_head_kv_arr) return true;
  138. + if (this->n_ff_arr != other.n_ff_arr) return true;
  139. + if (this->n_bskcn_arr != other.n_bskcn_arr) return true;
  140. + if (this->cross_attn_layers != other.cross_attn_layers) return true;
  141. if (this->n_rel_attn_bkts != other.n_rel_attn_bkts) return true;
  142. if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true;
  143. @@ -2623,6 +2672,10 @@ struct llama_hparams {
  144. GGML_ABORT("fatal error");
  145. }
  146. +
  147. + bool cross_attention_layer(uint32_t il) const {
  148. + return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
  149. + }
  150. };
  151. static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
  152. @@ -2806,6 +2859,16 @@ struct llama_layer {
  153. struct ggml_tensor * ffn_down_scale;
  154. struct ggml_tensor * bskcn_tv;
  155. +
  156. + // cross attention
  157. + struct ggml_tensor * cross_attn_k_norm;
  158. + struct ggml_tensor * cross_attn_k_proj;
  159. + struct ggml_tensor * cross_attn_o_proj;
  160. + struct ggml_tensor * cross_attn_q_norm;
  161. + struct ggml_tensor * cross_attn_q_proj;
  162. + struct ggml_tensor * cross_attn_v_proj;
  163. + struct ggml_tensor * cross_attn_attn_gate;
  164. + struct ggml_tensor * cross_attn_mlp_gate;
  165. };
  166. // very similar to llama_batch,
  167. @@ -3452,6 +3515,12 @@ struct llama_context {
  168. struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
  169. struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
  170. struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
  171. +
  172. + // TODO (jmorganca): this should most likely be passed in as part of a batch
  173. + // and not set on the context for all batches.
  174. + float * cross_attn_state = nullptr;
  175. + bool cross_attn_state_first_pass = true;
  176. + struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061]
  177. };
  178. struct llama_lora_weight {
  179. @@ -3686,6 +3755,18 @@ static bool llama_kv_cache_init(
  180. cache.v_l.reserve(n_layer);
  181. for (int i = 0; i < (int) n_layer; i++) {
  182. + // for cross attention layers
  183. + if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layer(i)) {
  184. + struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
  185. + ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
  186. + ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
  187. + ggml_format_name(k, "cache_k_l%d", i);
  188. + ggml_format_name(v, "cache_v_l%d", i);
  189. + cache.k_l.push_back(k);
  190. + cache.v_l.push_back(v);
  191. + continue;
  192. + }
  193. +
  194. const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
  195. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
  196. @@ -5460,12 +5541,14 @@ static void llm_load_hparams(
  197. }
  198. // zero-out the per-layer hparams
  199. - std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
  200. - std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
  201. - std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
  202. + std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
  203. + std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
  204. + std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
  205. + std::fill(hparams.cross_attn_layers.begin(), hparams.cross_attn_layers.end(), -1);
  206. - ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer);
  207. - ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer);
  208. + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer);
  209. + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer);
  210. + ml.get_arr(LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, hparams.cross_attn_layers, false);
  211. // n_head_kv is optional, default to n_head
  212. hparams.n_head_kv_arr = hparams.n_head_arr;
  213. @@ -5514,7 +5597,7 @@ static void llm_load_hparams(
  214. ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
  215. - if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
  216. + if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_MLLAMA || model.arch == LLM_ARCH_FALCON) {
  217. if (hparams.n_rot != hparams.n_embd_head_k) {
  218. throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
  219. }
  220. @@ -5554,6 +5637,16 @@ static void llm_load_hparams(
  221. }
  222. }
  223. } break;
  224. + case LLM_ARCH_MLLAMA:
  225. + {
  226. + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
  227. +
  228. + switch (hparams.n_layer) {
  229. + case 40: model.type = e_model::MODEL_11B; break;
  230. + case 100: model.type = e_model::MODEL_90B; break;
  231. + default: model.type = e_model::MODEL_UNKNOWN;
  232. + }
  233. + } break;
  234. case LLM_ARCH_MINICPM:
  235. {
  236. ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
  237. @@ -7249,6 +7342,55 @@ static bool llm_load_tensors(
  238. layer.rope_short = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
  239. }
  240. } break;
  241. + case LLM_ARCH_MLLAMA:
  242. + {
  243. + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab+8});
  244. +
  245. + // output
  246. + {
  247. + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
  248. + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
  249. +
  250. + // if output is NULL, init from the input tok embed
  251. + if (model.output == NULL) {
  252. + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
  253. + }
  254. + }
  255. +
  256. + for (int i = 0; i < n_layer; ++i) {
  257. + ggml_context * ctx_layer = ctx_for_layer(i);
  258. + ggml_context * ctx_split = ctx_for_layer_split(i);
  259. +
  260. + auto & layer = model.layers[i];
  261. +
  262. + if (hparams.cross_attention_layer(i)) {
  263. + layer.cross_attn_k_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_NORM, "weight", i), {128});
  264. + layer.cross_attn_k_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_PROJ, "weight", i), {n_embd, 1024});
  265. + layer.cross_attn_o_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_O_PROJ, "weight", i), {n_embd, n_embd});
  266. + layer.cross_attn_q_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_Q_NORM, "weight", i), {128});
  267. + layer.cross_attn_q_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_Q_PROJ, "weight", i), {n_embd, n_embd});
  268. + layer.cross_attn_v_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_V_PROJ, "weight", i), {n_embd, 1024});
  269. + layer.cross_attn_attn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_ATTN_GATE, i), {1});
  270. + layer.cross_attn_mlp_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_MLP_GATE, i), {1});
  271. + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
  272. + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
  273. + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
  274. + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
  275. + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
  276. + } else {
  277. + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
  278. + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head});
  279. + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
  280. + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
  281. + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
  282. + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
  283. + layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
  284. + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
  285. + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
  286. + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
  287. + }
  288. + }
  289. + } break;
  290. case LLM_ARCH_GROK:
  291. {
  292. if (n_expert == 0) {
  293. @@ -9093,7 +9235,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
  294. if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE &&
  295. model.hparams.n_vocab != model.vocab.id_to_token.size()) {
  296. - throw std::runtime_error("vocab size mismatch");
  297. + LLAMA_LOG_WARN("%s: vocab mismatch %u !- %zu ...\n", __func__, model.hparams.n_vocab, model.vocab.id_to_token.size());
  298. }
  299. if (params.vocab_only) {
  300. @@ -9178,7 +9320,7 @@ static struct ggml_tensor * llm_build_inp_embd(
  301. inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
  302. } else {
  303. - lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
  304. + lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
  305. inpL = lctx.inp_embd;
  306. ggml_set_input(lctx.inp_embd);
  307. }
  308. @@ -9193,6 +9335,22 @@ static struct ggml_tensor * llm_build_inp_embd(
  309. return inpL;
  310. }
  311. +static struct ggml_tensor * llm_build_inp_cross_attn_state(
  312. + struct ggml_context * ctx,
  313. + struct llama_context & lctx,
  314. + const llama_hparams & hparams,
  315. + const llm_build_cb & cb) {
  316. + const int64_t n_embd = hparams.n_embd;
  317. +
  318. + struct ggml_tensor * inpCAS;
  319. + lctx.inp_cross_attn_state = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4);
  320. + cb(lctx.inp_cross_attn_state, "inp_cross_attn_state", -1);
  321. + ggml_set_input(lctx.inp_cross_attn_state);
  322. + inpCAS = lctx.inp_cross_attn_state;
  323. +
  324. + return inpCAS;
  325. +}
  326. +
  327. static void llm_build_kv_store(
  328. struct ggml_context * ctx,
  329. const llama_hparams & hparams,
  330. @@ -10167,6 +10325,7 @@ struct llm_build_context {
  331. lctx.inp_pos_bucket = nullptr;
  332. lctx.inp_embd_enc = nullptr;
  333. lctx.inp_KQ_mask_cross = nullptr;
  334. + lctx.inp_cross_attn_state = nullptr;
  335. }
  336. void free() {
  337. @@ -10754,6 +10913,253 @@ struct llm_build_context {
  338. LLM_NORM_RMS, cb, -1);
  339. cb(cur, "result_norm", -1);
  340. + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
  341. + cb(cur, "result_output", -1);
  342. +
  343. + ggml_build_forward_expand(gf, cur);
  344. +
  345. + return gf;
  346. + }
  347. +
  348. + struct ggml_cgraph * build_mllama() {
  349. + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
  350. +
  351. + // mutable variable, needed during the last layer of the computation to skip unused tokens
  352. + int32_t n_tokens = this->n_tokens;
  353. +
  354. + const int64_t n_embd_head = hparams.n_embd_head_v;
  355. + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
  356. + GGML_ASSERT(n_embd_head == hparams.n_rot);
  357. +
  358. + struct ggml_tensor * cur;
  359. + struct ggml_tensor * inpL;
  360. + struct ggml_tensor * inpCAS;
  361. +
  362. + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
  363. + inpCAS = llm_build_inp_cross_attn_state(ctx0, lctx, hparams, cb);
  364. +
  365. + // inp_pos - contains the positions
  366. + struct ggml_tensor * inp_pos = build_inp_pos();
  367. +
  368. + // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
  369. + struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
  370. +
  371. + for (int il = 0; il < n_layer; ++il) {
  372. + struct ggml_tensor * inpSA = inpL;
  373. +
  374. + // norm
  375. + cur = llm_build_norm(ctx0, inpL, hparams,
  376. + model.layers[il].attn_norm, NULL,
  377. + LLM_NORM_RMS, cb, il);
  378. + cb(cur, "attn_norm", il);
  379. +
  380. + if (hparams.cross_attention_layer(il)) {
  381. + if (!lctx.cross_attn_state) {
  382. + continue;
  383. + }
  384. +
  385. + // cross attention layer
  386. + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_q_proj, cur);
  387. + cb(Qcur, "Qcur", il);
  388. +
  389. + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
  390. + cb(Qcur, "Qcur", il);
  391. +
  392. + Qcur = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
  393. + cb(Qcur, "Qcur", il);
  394. +
  395. + // TODO: is this required?
  396. + Qcur = ggml_cont(ctx0, Qcur);
  397. + cb(Qcur, "Qcur", il);
  398. +
  399. + Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
  400. + cb(Qcur, "Qcur", il);
  401. +
  402. + struct ggml_tensor * Kcur;
  403. + if (lctx.cross_attn_state_first_pass) {
  404. + Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS);
  405. + cb(Kcur, "Kcur", il);
  406. +
  407. + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404);
  408. + cb(Kcur, "Kcur", il);
  409. +
  410. + Kcur = ggml_permute(ctx0, Kcur, 0, 2, 1, 3);
  411. + cb(Kcur, "Kcur", il);
  412. +
  413. + // TODO: is this required?
  414. + Kcur = ggml_cont(ctx0, Kcur);
  415. + cb(Kcur, "Kcur", il);
  416. +
  417. + Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
  418. + cb(Kcur, "Kcur", il);
  419. +
  420. + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il]));
  421. + } else {
  422. + Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]);
  423. + cb(Kcur, "Kcur (view)", il);
  424. + }
  425. +
  426. + struct ggml_tensor * Vcur;
  427. + if (lctx.cross_attn_state_first_pass) {
  428. + Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS);
  429. + cb(Vcur, "Vcur", il);
  430. +
  431. + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, 6404);
  432. + cb(Vcur, "Vcur", il);
  433. +
  434. + Vcur = ggml_permute(ctx0, Vcur, 0, 2, 1, 3);
  435. + cb(Vcur, "Vcur", il);
  436. +
  437. + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il]));
  438. + } else {
  439. + Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]);
  440. + cb(Vcur, "Vcur (view)", il);
  441. + }
  442. +
  443. + struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur);
  444. + cb(kq, "kq", il);
  445. +
  446. + kq = ggml_scale_inplace(ctx0, kq, 1.0f/sqrtf(float(n_embd_head)));
  447. + cb(kq, "kq_scaled", il);
  448. +
  449. + // TODO: apply causal masks
  450. + struct ggml_tensor * kq_soft_max = ggml_soft_max_inplace(ctx0, kq);
  451. + cb(kq_soft_max, "kq_soft_max", il);
  452. +
  453. + Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur));
  454. + cb(Vcur, "Vcur", il);
  455. +
  456. + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, Vcur, kq_soft_max);
  457. + cb(kqv, "kqv", il);
  458. +
  459. + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
  460. + cb(kqv_merged, "kqv_merged", il);
  461. +
  462. + cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
  463. + cb(cur, "kqv_merged_cont", il);
  464. +
  465. + cur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_o_proj, cur);
  466. + cb(cur, "cur", il);
  467. +
  468. + // TODO: do this in place once?
  469. + cur = ggml_mul(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_attn_gate));
  470. +
  471. + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
  472. + cb(ffn_inp, "ffn_inp", il);
  473. +
  474. + // feed-forward network
  475. + cur = llm_build_norm(ctx0, ffn_inp, hparams,
  476. + model.layers[il].ffn_norm, NULL,
  477. + LLM_NORM_RMS, cb, il);
  478. + cb(cur, "ffn_norm", il);
  479. +
  480. + cur = llm_build_ffn(ctx0, lctx, cur,
  481. + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
  482. + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
  483. + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
  484. + NULL,
  485. + LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
  486. + cb(cur, "ffn_out", il);
  487. +
  488. + // TODO: do this inplace once?
  489. + cur = ggml_add_inplace(ctx0, ggml_mul_inplace(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_mlp_gate)), ffn_inp);
  490. + cb(cur, "ffn_out", il);
  491. +
  492. + cur = lctx.cvec.apply_to(ctx0, cur, il);
  493. + cb(cur, "l_out", il);
  494. +
  495. + // input for next layer
  496. + inpL = cur;
  497. + } else {
  498. + // self attention layer
  499. +
  500. + // rope freq factors for llama3; may return nullptr for llama2 and other models
  501. + struct ggml_tensor * rope_factors = build_rope_factors(il);
  502. +
  503. + // compute Q and K and RoPE them
  504. + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
  505. + cb(Qcur, "Qcur", il);
  506. + if (model.layers[il].bq) {
  507. + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
  508. + cb(Qcur, "Qcur", il);
  509. + }
  510. +
  511. + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
  512. + cb(Kcur, "Kcur", il);
  513. + if (model.layers[il].bk) {
  514. + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
  515. + cb(Kcur, "Kcur", il);
  516. + }
  517. +
  518. + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
  519. + cb(Vcur, "Vcur", il);
  520. + if (model.layers[il].bv) {
  521. + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
  522. + cb(Vcur, "Vcur", il);
  523. + }
  524. +
  525. + Qcur = ggml_rope_ext(
  526. + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
  527. + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
  528. + ext_factor, attn_factor, beta_fast, beta_slow
  529. + );
  530. + cb(Qcur, "Qcur", il);
  531. +
  532. + Kcur = ggml_rope_ext(
  533. + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
  534. + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
  535. + ext_factor, attn_factor, beta_fast, beta_slow
  536. + );
  537. + cb(Kcur, "Kcur", il);
  538. +
  539. + cur = llm_build_kv(ctx0, lctx, kv_self, gf,
  540. + model.layers[il].wo, model.layers[il].bo,
  541. + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
  542. +
  543. +
  544. + if (il == n_layer - 1) {
  545. + // skip computing output for unused tokens
  546. + struct ggml_tensor * inp_out_ids = build_inp_out_ids();
  547. + n_tokens = n_outputs;
  548. + cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  549. + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  550. + }
  551. +
  552. + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
  553. + cb(ffn_inp, "ffn_inp", il);
  554. +
  555. + // feed-forward network
  556. + cur = llm_build_norm(ctx0, ffn_inp, hparams,
  557. + model.layers[il].ffn_norm, NULL,
  558. + LLM_NORM_RMS, cb, il);
  559. + cb(cur, "ffn_norm", il);
  560. +
  561. + cur = llm_build_ffn(ctx0, lctx, cur,
  562. + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
  563. + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
  564. + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
  565. + NULL,
  566. + LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
  567. + cb(cur, "ffn_out", il);
  568. +
  569. + cur = ggml_add(ctx0, cur, ffn_inp);
  570. + cb(cur, "ffn_out", il);
  571. +
  572. + cur = lctx.cvec.apply_to(ctx0, cur, il);
  573. + cb(cur, "l_out", il);
  574. +
  575. + // input for next layer
  576. + inpL = cur;
  577. + }
  578. + }
  579. +
  580. + cur = inpL;
  581. +
  582. + cur = llm_build_norm(ctx0, cur, hparams,
  583. + model.output_norm, NULL,
  584. + LLM_NORM_RMS, cb, -1);
  585. + cb(cur, "result_norm", -1);
  586. +
  587. // lm_head
  588. cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
  589. cb(cur, "result_output", -1);
  590. @@ -16501,6 +16907,10 @@ static struct ggml_cgraph * llama_build_graph(
  591. {
  592. result = llm.build_llama();
  593. } break;
  594. + case LLM_ARCH_MLLAMA:
  595. + {
  596. + result = llm.build_mllama();
  597. + } break;
  598. case LLM_ARCH_BAICHUAN:
  599. {
  600. result = llm.build_baichuan();
  601. @@ -16773,6 +17183,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
  602. ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
  603. }
  604. + // TODO (jmorganca): this might copy a lot of data on every request of a
  605. + // single generation even though it doesn't change, so we should
  606. + // find a way to not set this more than one time per image
  607. + if (lctx.inp_cross_attn_state &&
  608. + lctx.inp_cross_attn_state->buffer) {
  609. + ggml_backend_tensor_set(lctx.inp_cross_attn_state, lctx.cross_attn_state, 0, hparams.n_embd * 1601 * 4 * ggml_element_size(lctx.inp_cross_attn_state));
  610. + }
  611. +
  612. if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
  613. GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
  614. const int64_t n_tokens = batch.n_tokens;
  615. @@ -17455,6 +17873,10 @@ static int llama_decode_internal(
  616. llama_set_inputs(lctx, ubatch);
  617. + // TODO: replace with something better to find out if its
  618. + // our first actual pass
  619. + lctx.cross_attn_state_first_pass = false;
  620. +
  621. llama_graph_compute(lctx, gf, n_threads, threadpool);
  622. // update the kv ring buffer
  623. @@ -18648,7 +19070,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
  624. if (llama_model_has_encoder(&model)) {
  625. n_attn_layer *= 3;
  626. }
  627. - GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
  628. + if (qs.n_attention_wv != n_attn_layer) {
  629. + LLAMA_LOG_WARN("%s: n_attention_wv is unexpected, expected: %d, found: %d\n", __func__, n_attn_layer, qs.n_attention_wv);
  630. + }
  631. }
  632. size_t total_size_org = 0;
  633. @@ -19744,6 +20168,11 @@ struct llama_context * llama_new_context_with_model(
  634. return ctx;
  635. }
  636. +void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state) {
  637. + ctx->cross_attn_state_first_pass = true;
  638. + ctx->cross_attn_state = cross_attn_state;
  639. +}
  640. +
  641. void llama_free(struct llama_context * ctx) {
  642. delete ctx;
  643. }
  644. @@ -19814,6 +20243,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
  645. // use what we call a normal RoPE, operating on pairs of consecutive head values
  646. case LLM_ARCH_LLAMA:
  647. + case LLM_ARCH_MLLAMA:
  648. case LLM_ARCH_BAICHUAN:
  649. case LLM_ARCH_STARCODER:
  650. case LLM_ARCH_PLAMO: