// llama-model.h

/**
 * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
 *
 * MIT License
 *
 * Copyright (c) 2023-2024 The ggml authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#pragma once

#include "llama.h"
#include "llama-arch.h"
#include "llama-hparams.h"
#include "llama-vocab.h"
#include "llama-mmap.h"

#include "ggml-cpp.h"

#include <map>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

// available models
// TODO: this enum does not follow the enum naming convention
enum llm_type {
    MODEL_UNKNOWN,
    MODEL_14M,
    MODEL_17M,
    MODEL_22M,
    MODEL_33M,
    MODEL_60M,
    MODEL_70M,
    MODEL_80M,
    MODEL_109M,
    MODEL_137M,
    MODEL_160M,
    MODEL_220M,
    MODEL_250M,
    MODEL_270M,
    MODEL_335M,
    MODEL_410M,
    MODEL_450M,
    MODEL_770M,
    MODEL_780M,
    MODEL_0_5B,
    MODEL_1B,
    MODEL_1_3B,
    MODEL_1_4B,
    MODEL_1_5B,
    MODEL_1_6B,
    MODEL_2B,
    MODEL_2_8B,
    MODEL_3B,
    MODEL_4B,
    MODEL_6B,
    MODEL_6_9B,
    MODEL_7B,
    MODEL_8B,
    MODEL_9B,
    MODEL_11B,
    MODEL_12B,
    MODEL_13B,
    MODEL_14B,
    MODEL_15B,
    MODEL_16B,
    MODEL_20B,
    MODEL_22B,
    MODEL_30B,
    MODEL_32B,
    MODEL_34B,
    MODEL_35B,
    MODEL_40B,
    MODEL_65B,
    MODEL_70B,
    MODEL_90B,
    MODEL_236B,
    MODEL_314B,
    MODEL_671B,
    MODEL_SMALL,
    MODEL_MEDIUM,
    MODEL_LARGE,
    MODEL_XL,
    MODEL_A1_7B,
    MODEL_A2_7B,
    MODEL_8x7B,
    MODEL_8x22B,
    MODEL_16x12B,
    MODEL_10B_128x3_66B,
    MODEL_57B_A14B,
    MODEL_27B,
};

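// tensors of a single PosNet block (residual conv pair, self-attention, final norm);
// judging by the fields, these appear to be used only by the WavTokenizer-style audio
// decoder models, and stay nullptr for ordinary text architectures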
struct llama_layer_posnet {
    // resnet
    struct ggml_tensor * norm1 = nullptr;
    struct ggml_tensor * norm1_b = nullptr;
    struct ggml_tensor * conv1 = nullptr;
    struct ggml_tensor * conv1_b = nullptr;
    struct ggml_tensor * norm2 = nullptr;
    struct ggml_tensor * norm2_b = nullptr;
    struct ggml_tensor * conv2 = nullptr;
    struct ggml_tensor * conv2_b = nullptr;

    // attention
    struct ggml_tensor * attn_norm = nullptr;
    struct ggml_tensor * attn_norm_b = nullptr;
    struct ggml_tensor * attn_q = nullptr;
    struct ggml_tensor * attn_q_b = nullptr;
    struct ggml_tensor * attn_k = nullptr;
    struct ggml_tensor * attn_k_b = nullptr;
    struct ggml_tensor * attn_v = nullptr;
    struct ggml_tensor * attn_v_b = nullptr;
    struct ggml_tensor * attn_o = nullptr;
    struct ggml_tensor * attn_o_b = nullptr;

    // normalize
    struct ggml_tensor * norm = nullptr;
    struct ggml_tensor * norm_b = nullptr;
};

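// tensors of a single ConvNeXt block (depthwise conv, norm, two pointwise projections,
// gamma scale); like llama_layer_posnet above, this appears to be specific to the
// audio decoder models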
struct llama_layer_convnext {
    struct ggml_tensor * dw = nullptr;
    struct ggml_tensor * dw_b = nullptr;
    struct ggml_tensor * norm = nullptr;
    struct ggml_tensor * norm_b = nullptr;
    struct ggml_tensor * pw1 = nullptr;
    struct ggml_tensor * pw1_b = nullptr;
    struct ggml_tensor * pw2 = nullptr;
    struct ggml_tensor * pw2_b = nullptr;
    struct ggml_tensor * gamma = nullptr;
};

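// weights of one model layer; all supported architectures (attention, MoE, Mamba/SSM,
// RWKV, audio) share this struct, and tensors that a given architecture does not use
// are simply left as nullptr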
struct llama_layer {
    // normalization
    struct ggml_tensor * attn_norm = nullptr;
    struct ggml_tensor * attn_norm_b = nullptr;
    struct ggml_tensor * attn_norm_2 = nullptr;
    struct ggml_tensor * attn_norm_2_b = nullptr;
    struct ggml_tensor * attn_q_norm = nullptr;
    struct ggml_tensor * attn_q_norm_b = nullptr;
    struct ggml_tensor * attn_k_norm = nullptr;
    struct ggml_tensor * attn_k_norm_b = nullptr;
    struct ggml_tensor * attn_out_norm = nullptr;
    struct ggml_tensor * attn_out_norm_b = nullptr;
    struct ggml_tensor * attn_q_a_norm = nullptr;
    struct ggml_tensor * attn_kv_a_norm = nullptr;
    struct ggml_tensor * attn_sub_norm = nullptr;
    struct ggml_tensor * attn_post_norm = nullptr;
    struct ggml_tensor * ffn_sub_norm = nullptr;
    struct ggml_tensor * attn_norm_cross = nullptr;
    struct ggml_tensor * attn_norm_enc = nullptr;

    // attention
    struct ggml_tensor * wq = nullptr;
    struct ggml_tensor * wk = nullptr;
    struct ggml_tensor * wv = nullptr;
    struct ggml_tensor * wo = nullptr;
    struct ggml_tensor * wqkv = nullptr;
    struct ggml_tensor * wq_a = nullptr;
    struct ggml_tensor * wq_b = nullptr;
    struct ggml_tensor * wkv_a_mqa = nullptr;
    struct ggml_tensor * wkv_b = nullptr;
    struct ggml_tensor * wq_cross = nullptr;
    struct ggml_tensor * wk_cross = nullptr;
    struct ggml_tensor * wv_cross = nullptr;
    struct ggml_tensor * wo_cross = nullptr;
    struct ggml_tensor * wq_enc = nullptr;
    struct ggml_tensor * wk_enc = nullptr;
    struct ggml_tensor * wv_enc = nullptr;
    struct ggml_tensor * wo_enc = nullptr;

    // attention bias
    struct ggml_tensor * bq = nullptr;
    struct ggml_tensor * bk = nullptr;
    struct ggml_tensor * bv = nullptr;
    struct ggml_tensor * bo = nullptr;
    struct ggml_tensor * bqkv = nullptr;

    // relative position bias
    struct ggml_tensor * attn_rel_b = nullptr;
    struct ggml_tensor * attn_rel_b_enc = nullptr;
    struct ggml_tensor * attn_rel_b_cross = nullptr;

    // normalization
    struct ggml_tensor * ffn_norm = nullptr;
    struct ggml_tensor * ffn_norm_b = nullptr;
    struct ggml_tensor * ffn_post_norm = nullptr;
    struct ggml_tensor * layer_out_norm = nullptr;
    struct ggml_tensor * layer_out_norm_b = nullptr;
    struct ggml_tensor * ffn_norm_exps = nullptr;
    struct ggml_tensor * ffn_norm_enc = nullptr;

    // ff
    struct ggml_tensor * ffn_gate = nullptr; // w1
    struct ggml_tensor * ffn_down = nullptr; // w2
    struct ggml_tensor * ffn_up = nullptr;   // w3
    struct ggml_tensor * ffn_gate_enc = nullptr;
    struct ggml_tensor * ffn_down_enc = nullptr;
    struct ggml_tensor * ffn_up_enc = nullptr;

    // ff MoE
    struct ggml_tensor * ffn_gate_inp = nullptr;
    struct ggml_tensor * ffn_gate_exps = nullptr;
    struct ggml_tensor * ffn_down_exps = nullptr;
    struct ggml_tensor * ffn_up_exps = nullptr;

    // ff shared expert (shexp)
    struct ggml_tensor * ffn_gate_inp_shexp = nullptr;
    struct ggml_tensor * ffn_gate_shexp = nullptr;
    struct ggml_tensor * ffn_down_shexp = nullptr;
    struct ggml_tensor * ffn_up_shexp = nullptr;

    // ff bias
    struct ggml_tensor * ffn_gate_b = nullptr;
    struct ggml_tensor * ffn_down_b = nullptr; // b2
    struct ggml_tensor * ffn_up_b = nullptr;   // b3
    struct ggml_tensor * ffn_act = nullptr;
    struct ggml_tensor * ffn_exp_probs_b = nullptr;

    // mamba proj
    struct ggml_tensor * ssm_in = nullptr;
    struct ggml_tensor * ssm_x = nullptr;
    struct ggml_tensor * ssm_dt = nullptr;
    struct ggml_tensor * ssm_out = nullptr;

    // mamba
    struct ggml_tensor * ssm_conv1d = nullptr;
    struct ggml_tensor * ssm_a = nullptr;
    struct ggml_tensor * ssm_d = nullptr;

    // mamba bias
    struct ggml_tensor * ssm_conv1d_b = nullptr;
    struct ggml_tensor * ssm_dt_b = nullptr;

    // rwkv
    struct ggml_tensor * time_mix_w1 = nullptr;
    struct ggml_tensor * time_mix_w2 = nullptr;
    struct ggml_tensor * time_mix_lerp_x = nullptr;
    struct ggml_tensor * time_mix_lerp_w = nullptr;
    struct ggml_tensor * time_mix_lerp_k = nullptr;
    struct ggml_tensor * time_mix_lerp_v = nullptr;
    struct ggml_tensor * time_mix_lerp_r = nullptr;
    struct ggml_tensor * time_mix_lerp_g = nullptr;
    struct ggml_tensor * time_mix_first = nullptr;
    struct ggml_tensor * time_mix_decay = nullptr;
    struct ggml_tensor * time_mix_decay_w1 = nullptr;
    struct ggml_tensor * time_mix_decay_w2 = nullptr;
    struct ggml_tensor * time_mix_key = nullptr;
    struct ggml_tensor * time_mix_value = nullptr;
    struct ggml_tensor * time_mix_receptance = nullptr;
    struct ggml_tensor * time_mix_gate = nullptr;
    struct ggml_tensor * time_mix_ln = nullptr;
    struct ggml_tensor * time_mix_ln_b = nullptr;
    struct ggml_tensor * time_mix_output = nullptr;
    struct ggml_tensor * channel_mix_lerp_k = nullptr;
    struct ggml_tensor * channel_mix_lerp_r = nullptr;
    struct ggml_tensor * channel_mix_key = nullptr;
    struct ggml_tensor * channel_mix_receptance = nullptr;
    struct ggml_tensor * channel_mix_value = nullptr;

    // long rope factors
    struct ggml_tensor * rope_long = nullptr;
    struct ggml_tensor * rope_short = nullptr;
    struct ggml_tensor * rope_freqs = nullptr;

    // bitnet scale
    struct ggml_tensor * wq_scale = nullptr;
    struct ggml_tensor * wk_scale = nullptr;
    struct ggml_tensor * wv_scale = nullptr;
    struct ggml_tensor * wo_scale = nullptr;
    struct ggml_tensor * ffn_gate_scale = nullptr;
    struct ggml_tensor * ffn_up_scale = nullptr;
    struct ggml_tensor * ffn_down_scale = nullptr;
    struct ggml_tensor * bskcn_tv = nullptr;

    // cross attention
    struct ggml_tensor * cross_attn_k_norm = nullptr;
    struct ggml_tensor * cross_attn_k_proj = nullptr;
    struct ggml_tensor * cross_attn_o_proj = nullptr;
    struct ggml_tensor * cross_attn_q_norm = nullptr;
    struct ggml_tensor * cross_attn_q_proj = nullptr;
    struct ggml_tensor * cross_attn_v_proj = nullptr;
    struct ggml_tensor * cross_attn_attn_gate = nullptr;
    struct ggml_tensor * cross_attn_mlp_gate = nullptr;

    struct llama_layer_posnet posnet;
    struct llama_layer_convnext convnext;
};

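// a loaded model: its hyperparameters, vocabulary and global tensors (embeddings,
// output head, ...), one llama_layer per layer, plus the device/buffer bookkeeping
// and mmap/mlock objects that own the underlying tensor data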
struct llama_model {
    llm_type type = MODEL_UNKNOWN;
    llm_arch arch = LLM_ARCH_UNKNOWN;

    llama_ftype ftype = LLAMA_FTYPE_ALL_F32;

    std::string name = "n/a";

    llama_hparams hparams = {};
    llama_vocab vocab;

    struct ggml_tensor * tok_embd = nullptr;
    struct ggml_tensor * type_embd = nullptr;
    struct ggml_tensor * pos_embd = nullptr;
    struct ggml_tensor * tok_norm = nullptr;
    struct ggml_tensor * tok_norm_b = nullptr;
    struct ggml_tensor * output_norm = nullptr;
    struct ggml_tensor * output_norm_b = nullptr;
    struct ggml_tensor * output = nullptr;
    struct ggml_tensor * output_b = nullptr;
    struct ggml_tensor * output_norm_enc = nullptr;

    // classifier
    struct ggml_tensor * cls = nullptr;
    struct ggml_tensor * cls_b = nullptr;
    struct ggml_tensor * cls_out = nullptr;
    struct ggml_tensor * cls_out_b = nullptr;

    struct ggml_tensor * conv1d = nullptr;
    struct ggml_tensor * conv1d_b = nullptr;

    std::vector<llama_layer> layers;

    // gguf metadata
    std::unordered_map<std::string, std::string> gguf_kv;

    llama_split_mode split_mode;
    int main_gpu;
    int n_gpu_layers;

    std::vector<std::string> rpc_servers;

    // list of devices used in this model
    std::vector<ggml_backend_dev_t> devices;

    // lists of buffer types used for each layer
    using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>;
    buft_list_t cpu_buft_list;
    std::map<ggml_backend_dev_t, buft_list_t> gpu_buft_list;

    struct layer_dev {
        ggml_backend_dev_t dev;
        buft_list_t * buft_list;
    };

    layer_dev dev_input = {};
    layer_dev dev_output = {};
    std::vector<layer_dev> dev_layer;

    // contexts where the model tensors metadata is stored
    std::vector<ggml_context_ptr> ctxs;

    // the model memory buffers for the tensor data
    std::vector<ggml_backend_buffer_ptr> bufs;

    // model memory mapped files
    llama_mmaps mappings;

    // objects representing data potentially being locked in memory
    llama_mlocks mlock_bufs;
    llama_mlocks mlock_mmaps;

    // for quantize-stats only
    std::vector<std::pair<std::string, struct ggml_tensor *>> tensors_by_name;

    int64_t t_load_us = 0;
    int64_t t_start_us = 0;

    // total number of parameters in the model
    uint64_t n_elements = 0;

    // total size of all the tensors in the model in bytes
    size_t n_bytes = 0;
};

const char * llm_type_name(llm_type type);

std::string llama_model_arch_name (const llama_model & model);
std::string llama_model_type_name (const llama_model & model);
std::string llama_model_ftype_name(const llama_model & model);

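// probes whether `dev` can execute the op built by `fn` when the op's source tensors
// live in a buffer of type `buft`; `fn` is given a scratch no-alloc ggml context and
// must return the op tensor to test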
template<typename F>
bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) {
    ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead()*8,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };

    ggml_context_ptr ctx { ggml_init(params) };
    if (!ctx) {
        throw std::runtime_error("failed to create ggml context");
    }

    // dummy zero-size buffer: the sources only need to report a buffer of this type,
    // no tensor data is ever allocated or written (the context is no_alloc)
    ggml_backend_buffer_ptr buf { ggml_backend_buft_alloc_buffer(buft, 0) };

    // build the test op and pretend its sources live in a buffer of this type
    ggml_tensor * op_tensor = fn(ctx.get());
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        if (op_tensor->src[i] != nullptr) {
            op_tensor->src[i]->buffer = buf.get();
        }
    }

    bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor);

    return op_supported;
}

template<typename F>
ggml_backend_buffer_type_t select_buft(const llama_model::buft_list_t & buft_list, const F & fn) {
    for (const auto & cur : buft_list) {
        ggml_backend_dev_t cur_dev = cur.first;
        ggml_backend_buffer_type_t cur_buft = cur.second;
        if (buft_supported(cur_buft, cur_dev, fn)) {
            return cur_buft;
        }
    }

    throw std::runtime_error("no suitable buffer type found");
}

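// Illustrative sketch (not part of the original header) of how select_buft can be
// used: probe each (device, buffer type) pair with a small test op - here an f32 add
// sized to n_embd - and take the first buffer type whose device supports it.
//
//   ggml_backend_buffer_type_t buft = select_buft(
//       *model.dev_layer.at(il).buft_list,
//       [&](ggml_context * ctx) {
//           ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_embd);
//           ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_embd);
//           return ggml_add(ctx, a, b);
//       });
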
// used by llama_adapter_cvec
ggml_backend_buffer_type_t llama_model_select_buft(const llama_model & model, int il);

// used by llama_adapter_lora
struct ggml_tensor * llama_model_get_tensor(const struct llama_model & model, const char * name);

size_t llama_model_max_nodes(const llama_model & model);

struct llama_model_loader;

// TODO: become llama_model methods
void llm_load_stats     (llama_model_loader & ml, llama_model & model);
void llm_load_arch      (llama_model_loader & ml, llama_model & model);
void llm_load_hparams   (llama_model_loader & ml, llama_model & model);
void llm_load_vocab     (llama_model_loader & ml, llama_model & model);
void llm_load_print_meta(llama_model_loader & ml, llama_model & model);