0005-solar-pro.patch 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
  2. From: Michael Yang <mxyng@pm.me>
  3. Date: Mon, 16 Sep 2024 15:53:16 -0700
  4. Subject: [PATCH] solar-pro
  5. solar-pro introduces block skip connections where blocks are connected
  6. to other, non-sequential blocks with a scale multiple
  7. this change adds 4 new keys to store the skip connections and one new
  8. tensor to store the scalar. the scalar is implemented a 1-dimensional
  9. tensor with 2 elements dervied from the model's bskcn_tv configuration.
  10. in general, the values are (bskcn_tv, 1 - bskcn_tv)
  11. ---
  12. src/llama-arch.cpp | 21 +++++
  13. src/llama-arch.h | 3 +
  14. src/llama-hparams.cpp | 8 ++
  15. src/llama-hparams.h | 5 ++
  16. src/llama-model-loader.cpp | 1 +
  17. src/llama-model.cpp | 44 +++++++++++
  18. src/llama-model.h | 3 +
  19. src/llama.cpp | 152 ++++++++++++++++++++++++++++++++++++-
  20. 8 files changed, 236 insertions(+), 1 deletion(-)
  21. diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
  22. index 97a1e7e5..a1e0ebcc 100644
  23. --- a/src/llama-arch.cpp
  24. +++ b/src/llama-arch.cpp
  25. @@ -61,6 +61,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
  26. { LLM_ARCH_GRANITE, "granite" },
  27. { LLM_ARCH_GRANITE_MOE, "granitemoe" },
  28. { LLM_ARCH_CHAMELEON, "chameleon" },
  29. + { LLM_ARCH_SOLAR, "solar" },
  30. { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
  31. { LLM_ARCH_UNKNOWN, "(unknown)" },
  32. };
  33. @@ -125,6 +126,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
  34. { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
  35. { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
  36. { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
  37. + { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" },
  38. { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
  39. { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
  40. @@ -1271,6 +1273,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
  41. { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
  42. },
  43. },
  44. + {
  45. + LLM_ARCH_SOLAR,
  46. + {
  47. + { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
  48. + { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
  49. + { LLM_TENSOR_OUTPUT, "output" },
  50. + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
  51. + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
  52. + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
  53. + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
  54. + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
  55. + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
  56. + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
  57. + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
  58. + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
  59. + { LLM_TENSOR_BSKCN_TV, "bskcn_tv" },
  60. + },
  61. + },
  62. {
  63. LLM_ARCH_WAVTOKENIZER_DEC,
  64. {
  65. @@ -1429,6 +1449,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
  66. {LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
  67. // this tensor is loaded for T5, but never used
  68. {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
  69. + {LLM_TENSOR_BSKCN_TV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
  70. {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},
  71. {LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
  72. {LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
  73. diff --git a/src/llama-arch.h b/src/llama-arch.h
  74. index 122fdceb..77919578 100644
  75. --- a/src/llama-arch.h
  76. +++ b/src/llama-arch.h
  77. @@ -65,6 +65,7 @@ enum llm_arch {
  78. LLM_ARCH_GRANITE,
  79. LLM_ARCH_GRANITE_MOE,
  80. LLM_ARCH_CHAMELEON,
  81. + LLM_ARCH_SOLAR,
  82. LLM_ARCH_WAVTOKENIZER_DEC,
  83. LLM_ARCH_UNKNOWN,
  84. };
  85. @@ -129,6 +130,7 @@ enum llm_kv {
  86. LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
  87. LLM_KV_ATTENTION_SLIDING_WINDOW,
  88. LLM_KV_ATTENTION_SCALE,
  89. + LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
  90. LLM_KV_ROPE_DIMENSION_COUNT,
  91. LLM_KV_ROPE_DIMENSION_SECTIONS,
  92. @@ -311,6 +313,7 @@ enum llm_tensor {
  93. LLM_TENSOR_ENC_OUTPUT_NORM,
  94. LLM_TENSOR_CLS,
  95. LLM_TENSOR_CLS_OUT,
  96. + LLM_TENSOR_BSKCN_TV,
  97. LLM_TENSOR_CONV1D,
  98. LLM_TENSOR_CONVNEXT_DW,
  99. LLM_TENSOR_CONVNEXT_NORM,
  100. diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
  101. index ea87b295..f3955de9 100644
  102. --- a/src/llama-hparams.cpp
  103. +++ b/src/llama-hparams.cpp
  104. @@ -69,3 +69,11 @@ uint32_t llama_hparams::n_embd_v_s() const {
  105. // corresponds to Mamba's ssm_states size
  106. return ssm_d_state * ssm_d_inner;
  107. }
  108. +
  109. +bool llama_hparams::n_bskcn(uint32_t n, uint32_t il) const {
  110. + if (il < n_layer) {
  111. + return n_bskcn_arr[n][il] > 0;
  112. + }
  113. +
  114. + GGML_ABORT("fatal error");
  115. +}
  116. \ No newline at end of file
  117. diff --git a/src/llama-hparams.h b/src/llama-hparams.h
  118. index 1fe45410..1bdcdfd5 100644
  119. --- a/src/llama-hparams.h
  120. +++ b/src/llama-hparams.h
  121. @@ -50,6 +50,8 @@ struct llama_hparams {
  122. std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
  123. std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
  124. + std::array<std::array<uint32_t, LLAMA_MAX_LAYERS>, 4> n_bskcn_arr = {};
  125. +
  126. uint32_t n_layer_dense_lead = 0;
  127. uint32_t n_lora_q = 0;
  128. uint32_t n_lora_kv = 0;
  129. @@ -133,6 +135,9 @@ struct llama_hparams {
  130. // dimension of the recurrent state embeddings
  131. uint32_t n_embd_v_s() const;
  132. +
  133. + // Block skip connection
  134. + bool n_bskcn(uint32_t n, uint32_t il) const;
  135. };
  136. static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
  137. diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
  138. index 05d58ad9..1252aca1 100644
  139. --- a/src/llama-model-loader.cpp
  140. +++ b/src/llama-model-loader.cpp
  141. @@ -439,6 +439,7 @@ namespace GGUFMeta {
  142. // TODO: this is not very clever - figure out something better
  143. template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required);
  144. template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required);
  145. + template bool llama_model_loader::get_key_or_arr<uint32_t>(const std::string & key, std::array<uint32_t, 512> & result, uint32_t n, bool required);
  146. llama_model_loader::llama_model_loader(
  147. const std::string & fname,
  148. diff --git a/src/llama-model.cpp b/src/llama-model.cpp
  149. index 36a0a009..ad1315c6 100644
  150. --- a/src/llama-model.cpp
  151. +++ b/src/llama-model.cpp
  152. @@ -1238,6 +1238,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
  153. default: type = LLM_TYPE_UNKNOWN;
  154. }
  155. } break;
  156. + case LLM_ARCH_SOLAR:
  157. + {
  158. + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
  159. + for (size_t i = 0; i < hparams.n_bskcn_arr.max_size(); ++i) {
  160. + auto & bskcn = hparams.n_bskcn_arr[i];
  161. + bskcn.fill(0);
  162. + auto kv = LLM_KV(arch);
  163. + ml.get_key_or_arr(format((kv(LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION) + ".%d").c_str(), i), bskcn, hparams.n_layer, false);
  164. + }
  165. +
  166. + switch (hparams.n_layer) {
  167. + case 64: type = LLM_TYPE_22B; break;
  168. + default: type = LLM_TYPE_UNKNOWN;
  169. + }
  170. + } break;
  171. case LLM_ARCH_WAVTOKENIZER_DEC:
  172. {
  173. ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
  174. @@ -3316,6 +3331,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
  175. layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
  176. + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
  177. + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
  178. + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
  179. + }
  180. + } break;
  181. + case LLM_ARCH_SOLAR:
  182. + {
  183. + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
  184. +
  185. + // output
  186. + {
  187. + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
  188. + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
  189. + }
  190. +
  191. + for (int i = 0; i < n_layer; ++i) {
  192. + auto & layer = layers[i];
  193. +
  194. + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
  195. +
  196. + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
  197. + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
  198. + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
  199. + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
  200. +
  201. + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
  202. +
  203. + layer.bskcn_tv = create_tensor(tn(LLM_TENSOR_BSKCN_TV, "weight", i), {2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
  204. layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
  205. layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
  206. layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
  207. @@ -3900,6 +3943,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
  208. case LLM_ARCH_GRANITE:
  209. case LLM_ARCH_GRANITE_MOE:
  210. case LLM_ARCH_CHAMELEON:
  211. + case LLM_ARCH_SOLAR:
  212. return LLAMA_ROPE_TYPE_NORM;
  213. // the pairs of head values are offset by n_rot/2
  214. diff --git a/src/llama-model.h b/src/llama-model.h
  215. index a7c30444..1afb0024 100644
  216. --- a/src/llama-model.h
  217. +++ b/src/llama-model.h
  218. @@ -55,6 +55,7 @@ enum llm_type {
  219. LLM_TYPE_15B,
  220. LLM_TYPE_16B,
  221. LLM_TYPE_20B,
  222. + LLM_TYPE_22B,
  223. LLM_TYPE_30B,
  224. LLM_TYPE_32B,
  225. LLM_TYPE_34B,
  226. @@ -281,6 +282,8 @@ struct llama_layer {
  227. struct ggml_tensor * ffn_up_scale = nullptr;
  228. struct ggml_tensor * ffn_down_scale = nullptr;
  229. + struct ggml_tensor * bskcn_tv = nullptr;
  230. +
  231. struct llama_layer_posnet posnet;
  232. struct llama_layer_convnext convnext;
  233. diff --git a/src/llama.cpp b/src/llama.cpp
  234. index ac85bfed..6d320ea4 100644
  235. --- a/src/llama.cpp
  236. +++ b/src/llama.cpp
  237. @@ -7953,9 +7953,155 @@ struct llm_build_context {
  238. cb(img_logits, "img_logits", -1);
  239. cur = ggml_set_1d(ctx0, cur, img_logits, ggml_element_size(cur) * img_token_start_idx);
  240. cb(cur, "result_output", -1);
  241. -
  242. ggml_build_forward_expand(gf, cur);
  243. + return gf;
  244. + }
  245. +
  246. + ggml_cgraph * build_solar() {
  247. + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
  248. +
  249. + // mutable variable, needed during the last layer of the computation to skip unused tokens
  250. + int32_t n_tokens = this->n_tokens;
  251. +
  252. + const int64_t n_embd_head = hparams.n_embd_head_v;
  253. + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
  254. + GGML_ASSERT(n_embd_head == hparams.n_rot);
  255. +
  256. + struct ggml_tensor * cur;
  257. + struct ggml_tensor * inpL;
  258. +
  259. + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
  260. +
  261. + // inp_pos - contains the positions
  262. + struct ggml_tensor * inp_pos = build_inp_pos();
  263. +
  264. + // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
  265. + struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
  266. +
  267. + struct ggml_tensor * bskcn_1;
  268. + struct ggml_tensor * bskcn_2;
  269. +
  270. + for (int il = 0; il < n_layer; ++il) {
  271. + struct ggml_tensor * inpSA = inpL;
  272. +
  273. + if (hparams.n_bskcn(0, il)) {
  274. + bskcn_1 = inpSA;
  275. + }
  276. +
  277. + if (hparams.n_bskcn(1, il)) {
  278. + bskcn_2 = inpSA;
  279. + }
  280. +
  281. + if (hparams.n_bskcn(2, il)) {
  282. + inpSA = ggml_add(
  283. + ctx0,
  284. + ggml_mul(ctx0, bskcn_1, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, 0)),
  285. + ggml_mul(ctx0, inpSA, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, ggml_element_size(model.layers[il].bskcn_tv))));
  286. + }
  287. +
  288. + if (hparams.n_bskcn(3, il)) {
  289. + inpSA = ggml_add(
  290. + ctx0,
  291. + ggml_mul(ctx0, bskcn_2, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, 0)),
  292. + ggml_mul(ctx0, inpSA, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, ggml_element_size(model.layers[il].bskcn_tv))));
  293. + }
  294. + // norm
  295. + cur = llm_build_norm(ctx0, inpL, hparams,
  296. + model.layers[il].attn_norm, NULL,
  297. + LLM_NORM_RMS, cb, il);
  298. + cb(cur, "attn_norm", il);
  299. +
  300. + // self-attention
  301. + {
  302. + // rope freq factors for llama3; may return nullptr for llama2 and other models
  303. + struct ggml_tensor * rope_factors = build_rope_factors(il);
  304. +
  305. + // compute Q and K and RoPE them
  306. + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
  307. + cb(Qcur, "Qcur", il);
  308. + if (model.layers[il].bq) {
  309. + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
  310. + cb(Qcur, "Qcur", il);
  311. + }
  312. +
  313. + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
  314. + cb(Kcur, "Kcur", il);
  315. + if (model.layers[il].bk) {
  316. + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
  317. + cb(Kcur, "Kcur", il);
  318. + }
  319. +
  320. + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
  321. + cb(Vcur, "Vcur", il);
  322. + if (model.layers[il].bv) {
  323. + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
  324. + cb(Vcur, "Vcur", il);
  325. + }
  326. +
  327. + Qcur = ggml_rope_ext(
  328. + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
  329. + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
  330. + ext_factor, attn_factor, beta_fast, beta_slow
  331. + );
  332. + cb(Qcur, "Qcur", il);
  333. +
  334. + Kcur = ggml_rope_ext(
  335. + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
  336. + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
  337. + ext_factor, attn_factor, beta_fast, beta_slow
  338. + );
  339. + cb(Kcur, "Kcur", il);
  340. +
  341. + cur = llm_build_kv(ctx0, lctx, kv_self, gf,
  342. + model.layers[il].wo, model.layers[il].bo,
  343. + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
  344. + }
  345. +
  346. + if (il == n_layer - 1) {
  347. + // skip computing output for unused tokens
  348. + struct ggml_tensor * inp_out_ids = build_inp_out_ids();
  349. + n_tokens = n_outputs;
  350. + cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  351. + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  352. + }
  353. +
  354. + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
  355. + cb(ffn_inp, "ffn_inp", il);
  356. +
  357. + // feed-forward network
  358. + cur = llm_build_norm(ctx0, ffn_inp, hparams,
  359. + model.layers[il].ffn_norm, NULL,
  360. + LLM_NORM_RMS, cb, il);
  361. + cb(cur, "ffn_norm", il);
  362. +
  363. + cur = llm_build_ffn(ctx0, lctx, cur,
  364. + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
  365. + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
  366. + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
  367. + NULL,
  368. + LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
  369. + cb(cur, "ffn_out", il);
  370. +
  371. + cur = ggml_add(ctx0, cur, ffn_inp);
  372. + cb(cur, "ffn_out", il);
  373. +
  374. + cur = lctx.cvec.apply_to(ctx0, cur, il);
  375. + cb(cur, "l_out", il);
  376. +
  377. + // input for next layer
  378. + inpL = cur;
  379. + }
  380. +
  381. + cur = inpL;
  382. + cur = llm_build_norm(ctx0, cur, hparams,
  383. + model.output_norm, NULL,
  384. + LLM_NORM_RMS, cb, -1);
  385. + cb(cur, "result_norm", -1);
  386. + // lm_head
  387. + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
  388. + cb(cur, "result_output", -1);
  389. + ggml_build_forward_expand(gf, cur);
  390. return gf;
  391. }
  392. @@ -8398,6 +8544,10 @@ static struct ggml_cgraph * llama_build_graph(
  393. {
  394. result = llm.build_chameleon();
  395. } break;
  396. + case LLM_ARCH_SOLAR:
  397. + {
  398. + result = llm.build_solar();
  399. + } break;
  400. case LLM_ARCH_WAVTOKENIZER_DEC:
  401. {
  402. result = llm.build_wavtokenizer_dec();