0009-mllama.patch 35 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794
  1. From c2db1ad0fc86de189959b628021a970511e9c6f9 Mon Sep 17 00:00:00 2001
  2. From: jmorganca <jmorganca@gmail.com>
  3. Date: Tue, 24 Sep 2024 11:53:40 -0700
  4. Subject: [PATCH] add mllama support
  5. mllama adds cross-attention layers to the standard llama architecture
  6. it also requires a way to input a new tensor: cross_attention_state
  7. once per generation
  8. cross-attention layers don't change and so they are cached in the
  9. kv cache once per run
  10. remaining is to implement the cross attention mask
  11. ---
  12. include/llama.h | 5 +
  13. src/llama.cpp | 514 ++++++++++++++++++++++++++++++++++++++++++++++--
  14. 2 files changed, 499 insertions(+), 20 deletions(-)
  15. diff --git a/include/llama.h b/include/llama.h
  16. index bfc37e88..94ce82a4 100644
  17. --- a/include/llama.h
  18. +++ b/include/llama.h
  19. @@ -449,6 +449,11 @@ extern "C" {
  20. struct llama_model * model,
  21. struct llama_context_params params);
  22. + // TODO (jmorganca): this should most likely be passed in as part of a batch
  23. + // and not set on the context for all batches.
  24. + LLAMA_API void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state);
  25. + LLAMA_API void llama_reset_cross_attn_state(struct llama_context * ctx);
  26. +
  27. // Frees all allocated memory
  28. LLAMA_API void llama_free(struct llama_context * ctx);
  29. diff --git a/src/llama.cpp b/src/llama.cpp
  30. index b7771f53..75bbc226 100644
  31. --- a/src/llama.cpp
  32. +++ b/src/llama.cpp
  33. @@ -170,6 +170,7 @@ static std::string format(const char * fmt, ...) {
  34. enum llm_arch {
  35. LLM_ARCH_LLAMA,
  36. + LLM_ARCH_MLLAMA,
  37. LLM_ARCH_FALCON,
  38. LLM_ARCH_BAICHUAN,
  39. LLM_ARCH_GROK,
  40. @@ -219,6 +220,7 @@ enum llm_arch {
  41. static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
  42. { LLM_ARCH_LLAMA, "llama" },
  43. + { LLM_ARCH_MLLAMA, "mllama" },
  44. { LLM_ARCH_FALCON, "falcon" },
  45. { LLM_ARCH_GROK, "grok" },
  46. { LLM_ARCH_GPT2, "gpt2" },
  47. @@ -317,6 +319,7 @@ enum llm_kv {
  48. LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
  49. LLM_KV_ATTENTION_SLIDING_WINDOW,
  50. LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
  51. + LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS,
  52. LLM_KV_ROPE_DIMENSION_COUNT,
  53. LLM_KV_ROPE_FREQ_BASE,
  54. @@ -422,6 +425,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
  55. { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
  56. { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
  57. { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection.%d" },
  58. + { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers" },
  59. { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
  60. { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
  61. @@ -594,6 +598,14 @@ enum llm_tensor {
  62. LLM_TENSOR_ENC_FFN_UP,
  63. LLM_TENSOR_ENC_OUTPUT_NORM,
  64. LLM_TENSOR_BSKCN_TV,
  65. + LLM_TENSOR_CROSS_ATTN_K_NORM,
  66. + LLM_TENSOR_CROSS_ATTN_K_PROJ,
  67. + LLM_TENSOR_CROSS_ATTN_O_PROJ,
  68. + LLM_TENSOR_CROSS_ATTN_Q_NORM,
  69. + LLM_TENSOR_CROSS_ATTN_Q_PROJ,
  70. + LLM_TENSOR_CROSS_ATTN_V_PROJ,
  71. + LLM_TENSOR_CROSS_ATTN_ATTN_GATE,
  72. + LLM_TENSOR_CROSS_ATTN_MLP_GATE,
  73. };
  74. static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
  75. @@ -623,6 +635,40 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
  76. { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
  77. },
  78. },
  79. + {
  80. + LLM_ARCH_MLLAMA,
  81. + {
  82. + { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
  83. + { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
  84. + { LLM_TENSOR_OUTPUT, "output" },
  85. + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
  86. + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
  87. + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
  88. + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
  89. + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
  90. + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
  91. + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
  92. + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
  93. + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
  94. + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
  95. + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
  96. + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
  97. + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
  98. + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
  99. + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
  100. + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
  101. + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
  102. + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
  103. + { LLM_TENSOR_CROSS_ATTN_K_NORM, "blk.%d.cross_attn_k_norm" },
  104. + { LLM_TENSOR_CROSS_ATTN_K_PROJ, "blk.%d.cross_attn_k_proj" },
  105. + { LLM_TENSOR_CROSS_ATTN_O_PROJ, "blk.%d.cross_attn_o_proj" },
  106. + { LLM_TENSOR_CROSS_ATTN_Q_NORM, "blk.%d.cross_attn_q_norm" },
  107. + { LLM_TENSOR_CROSS_ATTN_Q_PROJ, "blk.%d.cross_attn_q_proj" },
  108. + { LLM_TENSOR_CROSS_ATTN_V_PROJ, "blk.%d.cross_attn_v_proj" },
  109. + { LLM_TENSOR_CROSS_ATTN_ATTN_GATE, "blk.%d.cross_attn_attn_gate" },
  110. + { LLM_TENSOR_CROSS_ATTN_MLP_GATE, "blk.%d.cross_attn_mlp_gate" },
  111. + },
  112. + },
  113. {
  114. LLM_ARCH_BAICHUAN,
  115. {
  116. @@ -1449,6 +1495,8 @@ static llm_arch llm_arch_from_string(const std::string & name) {
  117. return LLM_ARCH_UNKNOWN;
  118. }
  119. +
  120. +
  121. // helper to handle gguf constants
  122. // usage:
  123. //
  124. @@ -2267,6 +2315,7 @@ enum e_model {
  125. MODEL_40B,
  126. MODEL_65B,
  127. MODEL_70B,
  128. + MODEL_90B,
  129. MODEL_236B,
  130. MODEL_314B,
  131. MODEL_SMALL,
  132. @@ -2309,6 +2358,7 @@ struct llama_hparams {
  133. std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
  134. std::array<std::array<uint32_t, LLAMA_MAX_LAYERS>, 4> n_bskcn_arr;
  135. + std::array<uint32_t, LLAMA_MAX_LAYERS> cross_attn_layers;
  136. uint32_t n_layer_dense_lead = 0;
  137. uint32_t n_lora_q = 0;
  138. @@ -2372,10 +2422,11 @@ struct llama_hparams {
  139. if (this->n_expert != other.n_expert) return true;
  140. if (this->n_expert_used != other.n_expert_used) return true;
  141. - if (this->n_head_arr != other.n_head_arr) return true;
  142. - if (this->n_head_kv_arr != other.n_head_kv_arr) return true;
  143. - if (this->n_ff_arr != other.n_ff_arr) return true;
  144. - if (this->n_bskcn_arr != other.n_bskcn_arr) return true;
  145. + if (this->n_head_arr != other.n_head_arr) return true;
  146. + if (this->n_head_kv_arr != other.n_head_kv_arr) return true;
  147. + if (this->n_ff_arr != other.n_ff_arr) return true;
  148. + if (this->n_bskcn_arr != other.n_bskcn_arr) return true;
  149. + if (this->cross_attn_layers != other.cross_attn_layers) return true;
  150. if (this->n_rel_attn_bkts != other.n_rel_attn_bkts) return true;
  151. if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true;
  152. @@ -2490,6 +2541,10 @@ struct llama_hparams {
  153. GGML_ABORT("fatal error");
  154. }
  155. +
  156. + bool cross_attention_layer(uint32_t il) const {
  157. + return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
  158. + }
  159. };
  160. static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
  161. @@ -2672,6 +2727,16 @@ struct llama_layer {
  162. struct ggml_tensor * ffn_down_scale;
  163. struct ggml_tensor * bskcn_tv;
  164. +
  165. + // cross attention
  166. + struct ggml_tensor * cross_attn_k_norm;
  167. + struct ggml_tensor * cross_attn_k_proj;
  168. + struct ggml_tensor * cross_attn_o_proj;
  169. + struct ggml_tensor * cross_attn_q_norm;
  170. + struct ggml_tensor * cross_attn_q_proj;
  171. + struct ggml_tensor * cross_attn_v_proj;
  172. + struct ggml_tensor * cross_attn_attn_gate;
  173. + struct ggml_tensor * cross_attn_mlp_gate;
  174. };
  175. // very similar to llama_batch,
  176. @@ -2684,12 +2749,12 @@ struct llama_ubatch {
  177. uint32_t n_seq_tokens; // tokens per sequence
  178. uint32_t n_seqs;
  179. - llama_token * token; // [n_tokens]
  180. - float * embd; // [n_embd, n_tokens]
  181. - llama_pos * pos; // [n_tokens]
  182. - int32_t * n_seq_id; // [n_seqs]
  183. - llama_seq_id ** seq_id; // [n_seqs]
  184. - int8_t * output; // [n_tokens]
  185. + llama_token * token; // [n_tokens]
  186. + float * embd; // [n_embd, n_tokens]
  187. + llama_pos * pos; // [n_tokens]
  188. + int32_t * n_seq_id; // [n_seqs]
  189. + llama_seq_id ** seq_id; // [n_seqs]
  190. + int8_t * output; // [n_tokens]
  191. };
  192. struct llama_kv_cell {
  193. @@ -3268,6 +3333,10 @@ struct llama_context {
  194. // host buffer for the model output (logits and embeddings)
  195. ggml_backend_buffer_t buf_output = nullptr;
  196. + // TODO (jmorganca): this should most likely be passed in as part of a batch
  197. + // and not set on the context for all batches.
  198. + float * cross_attn_state = nullptr;
  199. +
  200. // decode output (2-dimensional array: [n_outputs][n_vocab])
  201. size_t logits_size = 0; // capacity (of floats) for logits
  202. float * logits = nullptr;
  203. @@ -3317,6 +3386,11 @@ struct llama_context {
  204. struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
  205. struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
  206. struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
  207. +
  208. + // TODO (jmorganca): this should most likely be passed in via
  209. + // the input. Ideally we remove this state from llama_context
  210. + bool cross_attn_state_first_pass = true;
  211. + struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061]
  212. };
  213. struct llama_lora_weight {
  214. @@ -3543,6 +3617,18 @@ static bool llama_kv_cache_init(
  215. cache.v_l.reserve(n_layer);
  216. for (int i = 0; i < (int) n_layer; i++) {
  217. + // for cross attention layers
  218. + if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layer(i)) {
  219. + struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
  220. + ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
  221. + ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
  222. + ggml_format_name(k, "cache_k_l%d", i);
  223. + ggml_format_name(v, "cache_v_l%d", i);
  224. + cache.k_l.push_back(k);
  225. + cache.v_l.push_back(v);
  226. + continue;
  227. + }
  228. +
  229. const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
  230. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
  231. @@ -5312,12 +5398,14 @@ static void llm_load_hparams(
  232. }
  233. // zero-out the per-layer hparams
  234. - std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
  235. - std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
  236. - std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
  237. + std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
  238. + std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
  239. + std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
  240. + std::fill(hparams.cross_attn_layers.begin(), hparams.cross_attn_layers.end(), -1);
  241. - ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer);
  242. - ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer);
  243. + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer);
  244. + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer);
  245. + ml.get_arr(LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, hparams.cross_attn_layers, false);
  246. // n_head_kv is optional, default to n_head
  247. hparams.n_head_kv_arr = hparams.n_head_arr;
  248. @@ -5366,7 +5454,7 @@ static void llm_load_hparams(
  249. ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
  250. - if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
  251. + if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_MLLAMA || model.arch == LLM_ARCH_FALCON) {
  252. if (hparams.n_rot != hparams.n_embd_head_k) {
  253. throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
  254. }
  255. @@ -5404,6 +5492,16 @@ static void llm_load_hparams(
  256. }
  257. }
  258. } break;
  259. + case LLM_ARCH_MLLAMA:
  260. + {
  261. + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
  262. +
  263. + switch (hparams.n_layer) {
  264. + case 40: model.type = e_model::MODEL_11B; break;
  265. + case 100: model.type = e_model::MODEL_90B; break;
  266. + default: model.type = e_model::MODEL_UNKNOWN;
  267. + }
  268. + } break;
  269. case LLM_ARCH_MINICPM:
  270. {
  271. ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
  272. @@ -6918,6 +7016,55 @@ static bool llm_load_tensors(
  273. }
  274. }
  275. } break;
  276. + case LLM_ARCH_MLLAMA:
  277. + {
  278. + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab+8});
  279. +
  280. + // output
  281. + {
  282. + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
  283. + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
  284. +
  285. + // if output is NULL, init from the input tok embed
  286. + if (model.output == NULL) {
  287. + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
  288. + }
  289. + }
  290. +
  291. + for (int i = 0; i < n_layer; ++i) {
  292. + ggml_context * ctx_layer = ctx_for_layer(i);
  293. + ggml_context * ctx_split = ctx_for_layer_split(i);
  294. +
  295. + auto & layer = model.layers[i];
  296. +
  297. + if (hparams.cross_attention_layer(i)) {
  298. + layer.cross_attn_k_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_NORM, "weight", i), {128});
  299. + layer.cross_attn_k_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_PROJ, "weight", i), {n_embd, 1024});
  300. + layer.cross_attn_o_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_O_PROJ, "weight", i), {n_embd, n_embd});
  301. + layer.cross_attn_q_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_Q_NORM, "weight", i), {128});
  302. + layer.cross_attn_q_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_Q_PROJ, "weight", i), {n_embd, n_embd});
  303. + layer.cross_attn_v_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_V_PROJ, "weight", i), {n_embd, 1024});
  304. + layer.cross_attn_attn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_ATTN_GATE, i), {1});
  305. + layer.cross_attn_mlp_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_MLP_GATE, i), {1});
  306. + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
  307. + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
  308. + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
  309. + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
  310. + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
  311. + } else {
  312. + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
  313. + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head});
  314. + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
  315. + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
  316. + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
  317. + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
  318. + layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
  319. + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
  320. + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
  321. + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
  322. + }
  323. + }
  324. + } break;
  325. case LLM_ARCH_GROK:
  326. {
  327. if (n_expert == 0) {
  328. @@ -8678,7 +8825,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
  329. if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE &&
  330. model.hparams.n_vocab != model.vocab.id_to_token.size()) {
  331. - throw std::runtime_error("vocab size mismatch");
  332. + LLAMA_LOG_WARN("%s: vocab mismatch %u !- %zu ...\n", __func__, model.hparams.n_vocab, model.vocab.id_to_token.size());
  333. }
  334. if (params.vocab_only) {
  335. @@ -8754,7 +8901,6 @@ static struct ggml_tensor * llm_build_inp_embd(
  336. if (batch.token) {
  337. lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens);
  338. - cb(lctx.inp_tokens, "inp_tokens", -1);
  339. ggml_set_input(lctx.inp_tokens);
  340. inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
  341. @@ -8769,6 +8915,22 @@ static struct ggml_tensor * llm_build_inp_embd(
  342. return inpL;
  343. }
  344. +static struct ggml_tensor * llm_build_inp_cross_attn_state(
  345. + struct ggml_context * ctx,
  346. + struct llama_context & lctx,
  347. + const llama_hparams & hparams,
  348. + const llm_build_cb & cb) {
  349. + const int64_t n_embd = hparams.n_embd;
  350. +
  351. + struct ggml_tensor * inpCAS;
  352. + lctx.inp_cross_attn_state = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4);
  353. + cb(lctx.inp_cross_attn_state, "inp_cross_attn_state", -1);
  354. + ggml_set_input(lctx.inp_cross_attn_state);
  355. + inpCAS = lctx.inp_cross_attn_state;
  356. +
  357. + return inpCAS;
  358. +}
  359. +
  360. static void llm_build_kv_store(
  361. struct ggml_context * ctx,
  362. const llama_hparams & hparams,
  363. @@ -8790,6 +8952,7 @@ static void llm_build_kv_store(
  364. struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa)*kv_head);
  365. cb(k_cache_view, "k_cache_view", il);
  366. + cb(k_cur, "k_cur", il);
  367. // note: storing RoPE-ed version of K in the KV cache
  368. ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
  369. @@ -9625,6 +9788,40 @@ static struct ggml_tensor * llm_build_rwkv6_channel_mix(
  370. return ggml_mul(ctx, r, llm_build_lora_mm(lctx, ctx, layer->channel_mix_value, k));
  371. }
  372. +
  373. +static void show_tensor(std::string name, ggml_tensor *t) {
  374. + LLAMA_LOG_INFO("%s [%lld, %lld]\n", name.c_str(), t->ne[0], t->ne[1]);
  375. +
  376. + int cols = int(t->ne[0]);
  377. + int rows = int(t->ne[1]);
  378. +
  379. + for(int r=0; r<3; r++) {
  380. + for(int c=0; c<3; c++) {
  381. + float v = ggml_get_f32_nd(t, c, r, 0, 0);
  382. + LLAMA_LOG_INFO("%11.8f ", v);
  383. + }
  384. + LLAMA_LOG_INFO("... ");
  385. + for(int c=0; c<3; c++) {
  386. + float v = ggml_get_f32_nd(t, cols-3+c, r, 0, 0);
  387. + LLAMA_LOG_INFO("%11.8f ", v);
  388. + }
  389. + LLAMA_LOG_INFO("\n");
  390. + }
  391. + LLAMA_LOG_INFO(" ...\n");
  392. + for(int r=0; r<3; r++) {
  393. + for(int c=0; c<3; c++) {
  394. + float v = ggml_get_f32_nd(t, c, rows-3+r, 0, 0);
  395. + LLAMA_LOG_INFO("%11.8f ", v);
  396. + }
  397. + LLAMA_LOG_INFO("... ");
  398. + for(int c=0; c<3; c++) {
  399. + float v = ggml_get_f32_nd(t, cols-3+c, rows-3+r, 0, 0);
  400. + LLAMA_LOG_INFO("%11.8f ", v);
  401. + }
  402. + LLAMA_LOG_INFO("\n");
  403. + }
  404. +}
  405. +
  406. struct llm_build_context {
  407. const llama_model & model;
  408. llama_context & lctx;
  409. @@ -9743,6 +9940,7 @@ struct llm_build_context {
  410. lctx.inp_pos_bucket = nullptr;
  411. lctx.inp_embd_enc = nullptr;
  412. lctx.inp_KQ_mask_cross = nullptr;
  413. + lctx.inp_cross_attn_state = nullptr;
  414. }
  415. void free() {
  416. @@ -10158,6 +10356,253 @@ struct llm_build_context {
  417. LLM_NORM_RMS, cb, -1);
  418. cb(cur, "result_norm", -1);
  419. + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
  420. + cb(cur, "result_output", -1);
  421. +
  422. + ggml_build_forward_expand(gf, cur);
  423. +
  424. + return gf;
  425. + }
  426. +
  427. + struct ggml_cgraph * build_mllama() {
  428. + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
  429. +
  430. + // mutable variable, needed during the last layer of the computation to skip unused tokens
  431. + int32_t n_tokens = this->n_tokens;
  432. +
  433. + const int64_t n_embd_head = hparams.n_embd_head_v;
  434. + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
  435. + GGML_ASSERT(n_embd_head == hparams.n_rot);
  436. +
  437. + struct ggml_tensor * cur;
  438. + struct ggml_tensor * inpL;
  439. + struct ggml_tensor * inpCAS;
  440. +
  441. + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
  442. + inpCAS = llm_build_inp_cross_attn_state(ctx0, lctx, hparams, cb);
  443. +
  444. + // inp_pos - contains the positions
  445. + struct ggml_tensor * inp_pos = build_inp_pos();
  446. +
  447. + // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
  448. + struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
  449. +
  450. + for (int il = 0; il < n_layer; ++il) {
  451. + struct ggml_tensor * inpSA = inpL;
  452. +
  453. + // norm
  454. + cur = llm_build_norm(ctx0, inpL, hparams,
  455. + model.layers[il].attn_norm, NULL,
  456. + LLM_NORM_RMS, cb, il);
  457. + cb(cur, "attn_norm", il);
  458. +
  459. + if (hparams.cross_attention_layer(il)) {
  460. + if (!lctx.cross_attn_state) {
  461. + continue;
  462. + }
  463. +
  464. + // cross attention layer
  465. + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_q_proj, cur);
  466. + cb(Qcur, "Qcur", il);
  467. +
  468. + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
  469. + cb(Qcur, "Qcur", il);
  470. +
  471. + Qcur = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
  472. + cb(Qcur, "Qcur", il);
  473. +
  474. + // TODO: is this required?
  475. + Qcur = ggml_cont(ctx0, Qcur);
  476. + cb(Qcur, "Qcur", il);
  477. +
  478. + Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
  479. + cb(Qcur, "Qcur", il);
  480. +
  481. + struct ggml_tensor * Kcur;
  482. + if (lctx.cross_attn_state_first_pass) {
  483. + Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS);
  484. + cb(Kcur, "Kcur", il);
  485. +
  486. + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404);
  487. + cb(Kcur, "Kcur", il);
  488. +
  489. + Kcur = ggml_permute(ctx0, Kcur, 0, 2, 1, 3);
  490. + cb(Kcur, "Kcur", il);
  491. +
  492. + // TODO: is this required?
  493. + Kcur = ggml_cont(ctx0, Kcur);
  494. + cb(Kcur, "Kcur", il);
  495. +
  496. + Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
  497. + cb(Kcur, "Kcur", il);
  498. +
  499. + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il]));
  500. + } else {
  501. + Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]);
  502. + cb(Kcur, "Kcur (view)", il);
  503. + }
  504. +
  505. + struct ggml_tensor * Vcur;
  506. + if (lctx.cross_attn_state_first_pass) {
  507. + Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS);
  508. + cb(Vcur, "Vcur", il);
  509. +
  510. + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, 6404);
  511. + cb(Vcur, "Vcur", il);
  512. +
  513. + Vcur = ggml_permute(ctx0, Vcur, 0, 2, 1, 3);
  514. + cb(Vcur, "Vcur", il);
  515. +
  516. + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il]));
  517. + } else {
  518. + Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]);
  519. + cb(Vcur, "Vcur (view)", il);
  520. + }
  521. +
  522. + struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur);
  523. + cb(kq, "kq", il);
  524. +
  525. + kq = ggml_scale_inplace(ctx0, kq, 1.0f/sqrtf(float(n_embd_head)));
  526. + cb(kq, "kq_scaled", il);
  527. +
  528. + // TODO: apply causal masks
  529. + struct ggml_tensor * kq_soft_max = ggml_soft_max_inplace(ctx0, kq);
  530. + cb(kq_soft_max, "kq_soft_max", il);
  531. +
  532. + Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur));
  533. + cb(Vcur, "Vcur", il);
  534. +
  535. + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, Vcur, kq_soft_max);
  536. + cb(kqv, "kqv", il);
  537. +
  538. + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
  539. + cb(kqv_merged, "kqv_merged", il);
  540. +
  541. + cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
  542. + cb(cur, "kqv_merged_cont", il);
  543. +
  544. + cur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_o_proj, cur);
  545. + cb(cur, "cur", il);
  546. +
  547. + // TODO: do this in place once?
  548. + cur = ggml_mul(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_attn_gate));
  549. +
  550. + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
  551. + cb(ffn_inp, "ffn_inp", il);
  552. +
  553. + // feed-forward network
  554. + cur = llm_build_norm(ctx0, ffn_inp, hparams,
  555. + model.layers[il].ffn_norm, NULL,
  556. + LLM_NORM_RMS, cb, il);
  557. + cb(cur, "ffn_norm", il);
  558. +
  559. + cur = llm_build_ffn(ctx0, lctx, cur,
  560. + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
  561. + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
  562. + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
  563. + NULL,
  564. + LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
  565. + cb(cur, "ffn_out", il);
  566. +
  567. + // TODO: do this inplace once?
  568. + cur = ggml_add_inplace(ctx0, ggml_mul_inplace(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_mlp_gate)), ffn_inp);
  569. + cb(cur, "ffn_out", il);
  570. +
  571. + cur = lctx.cvec.apply_to(ctx0, cur, il);
  572. + cb(cur, "l_out", il);
  573. +
  574. + // input for next layer
  575. + inpL = cur;
  576. + } else {
  577. + // self attention layer
  578. +
  579. + // rope freq factors for llama3; may return nullptr for llama2 and other models
  580. + struct ggml_tensor * rope_factors = build_rope_factors(il);
  581. +
  582. + // compute Q and K and RoPE them
  583. + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
  584. + cb(Qcur, "Qcur", il);
  585. + if (model.layers[il].bq) {
  586. + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
  587. + cb(Qcur, "Qcur", il);
  588. + }
  589. +
  590. + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
  591. + cb(Kcur, "Kcur", il);
  592. + if (model.layers[il].bk) {
  593. + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
  594. + cb(Kcur, "Kcur", il);
  595. + }
  596. +
  597. + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
  598. + cb(Vcur, "Vcur", il);
  599. + if (model.layers[il].bv) {
  600. + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
  601. + cb(Vcur, "Vcur", il);
  602. + }
  603. +
  604. + Qcur = ggml_rope_ext(
  605. + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
  606. + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
  607. + ext_factor, attn_factor, beta_fast, beta_slow
  608. + );
  609. + cb(Qcur, "Qcur", il);
  610. +
  611. + Kcur = ggml_rope_ext(
  612. + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
  613. + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
  614. + ext_factor, attn_factor, beta_fast, beta_slow
  615. + );
  616. + cb(Kcur, "Kcur", il);
  617. +
  618. + cur = llm_build_kv(ctx0, lctx, kv_self, gf,
  619. + model.layers[il].wo, model.layers[il].bo,
  620. + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
  621. +
  622. +
  623. + if (il == n_layer - 1) {
  624. + // skip computing output for unused tokens
  625. + struct ggml_tensor * inp_out_ids = build_inp_out_ids();
  626. + n_tokens = n_outputs;
  627. + cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  628. + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  629. + }
  630. +
  631. + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
  632. + cb(ffn_inp, "ffn_inp", il);
  633. +
  634. + // feed-forward network
  635. + cur = llm_build_norm(ctx0, ffn_inp, hparams,
  636. + model.layers[il].ffn_norm, NULL,
  637. + LLM_NORM_RMS, cb, il);
  638. + cb(cur, "ffn_norm", il);
  639. +
  640. + cur = llm_build_ffn(ctx0, lctx, cur,
  641. + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
  642. + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
  643. + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
  644. + NULL,
  645. + LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
  646. + cb(cur, "ffn_out", il);
  647. +
  648. + cur = ggml_add(ctx0, cur, ffn_inp);
  649. + cb(cur, "ffn_out", il);
  650. +
  651. + cur = lctx.cvec.apply_to(ctx0, cur, il);
  652. + cb(cur, "l_out", il);
  653. +
  654. + // input for next layer
  655. + inpL = cur;
  656. + }
  657. + }
  658. +
  659. + cur = inpL;
  660. +
  661. + cur = llm_build_norm(ctx0, cur, hparams,
  662. + model.output_norm, NULL,
  663. + LLM_NORM_RMS, cb, -1);
  664. + cb(cur, "result_norm", -1);
  665. +
  666. // lm_head
  667. cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
  668. cb(cur, "result_output", -1);
  669. @@ -15493,6 +15938,10 @@ static struct ggml_cgraph * llama_build_graph(
  670. {
  671. result = llm.build_llama();
  672. } break;
  673. + case LLM_ARCH_MLLAMA:
  674. + {
  675. + result = llm.build_mllama();
  676. + } break;
  677. case LLM_ARCH_BAICHUAN:
  678. {
  679. result = llm.build_baichuan();
  680. @@ -15736,7 +16185,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
  681. if (batch.token) {
  682. const int64_t n_tokens = batch.n_tokens;
  683. -
  684. ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
  685. }
  686. @@ -16123,6 +16571,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
  687. }
  688. }
  689. }
  690. +
  691. + // TODO (jmorganca): this might copy a lot of data on every request of a
  692. + // single generation even though it doesn't change, so we should
  693. + // find a way to not set this more than one time per image
  694. + if (lctx.cross_attn_state && lctx.inp_cross_attn_state->buffer) {
  695. + ggml_backend_tensor_set(lctx.inp_cross_attn_state, lctx.cross_attn_state, 0, hparams.n_embd * 1601 * 4 * ggml_element_size(lctx.inp_cross_attn_state));
  696. + }
  697. }
  698. // Make sure enough space is available for outputs.
  699. @@ -16430,6 +16885,10 @@ static int llama_decode_internal(
  700. llama_set_inputs(lctx, ubatch);
  701. + // TODO: replace with something better to find out if its
  702. + // our first actual pass
  703. + lctx.cross_attn_state_first_pass = false;
  704. +
  705. llama_graph_compute(lctx, gf, n_threads, threadpool);
  706. // update the kv ring buffer
  707. @@ -17586,7 +18045,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
  708. if (llama_model_has_encoder(&model)) {
  709. n_attn_layer *= 3;
  710. }
  711. - GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
  712. + if (qs.n_attention_wv != n_attn_layer) {
  713. + LLAMA_LOG_WARN("%s: n_attention_wv is unexpected, expected: %d, found: %d\n", __func__, n_attn_layer, qs.n_attention_wv);
  714. + }
  715. }
  716. size_t total_size_org = 0;
  717. @@ -18681,6 +19142,18 @@ struct llama_context * llama_new_context_with_model(
  718. return ctx;
  719. }
  720. +void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state) {
  721. + ctx->cross_attn_state = cross_attn_state;
  722. +}
  723. +
  724. +void llama_reset_cross_attn_state(struct llama_context * ctx) {
  725. + ctx->cross_attn_state_first_pass = true;
  726. + if (ctx->cross_attn_state) {
  727. + free(ctx->cross_attn_state);
  728. + ctx->cross_attn_state = nullptr;
  729. + }
  730. +}
  731. +
  732. void llama_free(struct llama_context * ctx) {
  733. delete ctx;
  734. }
  735. @@ -18731,6 +19204,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
  736. // use what we call a normal RoPE, operating on pairs of consecutive head values
  737. case LLM_ARCH_LLAMA:
  738. + case LLM_ARCH_MLLAMA:
  739. case LLM_ARCH_BAICHUAN:
  740. case LLM_ARCH_STARCODER:
  741. case LLM_ARCH_PLAMO:
  742. --
  743. 2.39.3 (Apple Git-146)