10-params.diff (1.1 KB)
diff --git a/src/llama.cpp b/src/llama.cpp
index a207451f..fba6b175 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -4969,6 +4969,7 @@ static void llm_load_hparams(
                 hparams.attn_soft_cap = true;
                 switch (hparams.n_layer) {
+                    case 26: model.type = e_model::MODEL_2B; break;
                     case 42: model.type = e_model::MODEL_9B; break;
                     case 46: model.type = e_model::MODEL_27B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
@@ -11736,6 +11737,7 @@ struct llm_build_context {
                 // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
                 switch (model.type) {
+                    case e_model::MODEL_2B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); break;
                     case e_model::MODEL_9B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); break;
                     case e_model::MODEL_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
                     default: GGML_ABORT("fatal error");