0021-gemma3-quantization.patch

From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Patrick Devine <patrick@infrahq.com>
Date: Fri, 14 Mar 2025 16:33:23 -0700
Subject: [PATCH] gemma3 quantization

---
 src/llama-arch.cpp  | 19 +++++++++++++++++++
 src/llama-arch.h    |  1 +
 src/llama-model.cpp |  7 +++++++
 src/llama-quant.cpp |  9 +++++++++
 4 files changed, 36 insertions(+)

diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index b6f20286..b443fcd3 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -37,6 +37,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_MINICPM3,        "minicpm3"     },
     { LLM_ARCH_GEMMA,           "gemma"        },
     { LLM_ARCH_GEMMA2,          "gemma2"       },
+    { LLM_ARCH_GEMMA3,          "gemma3"       },
     { LLM_ARCH_STARCODER2,      "starcoder2"   },
     { LLM_ARCH_MAMBA,           "mamba"        },
     { LLM_ARCH_XVERSE,          "xverse"       },
@@ -804,6 +805,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_POST_NORM,  "blk.%d.post_ffw_norm" },
         },
     },
+    {
+        LLM_ARCH_GEMMA3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,     "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,    "output_norm" },
+            { LLM_TENSOR_ATTN_NORM,      "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,         "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,         "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,         "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_NORM,       "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,       "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,       "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,         "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_POST_NORM,  "blk.%d.post_ffw_norm" },
+        },
+    },
     {
         LLM_ARCH_STARCODER2,
         {
diff --git a/src/llama-arch.h b/src/llama-arch.h
index ec742224..aad92a5d 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -41,6 +41,7 @@ enum llm_arch {
     LLM_ARCH_MINICPM3,
     LLM_ARCH_GEMMA,
     LLM_ARCH_GEMMA2,
+    LLM_ARCH_GEMMA3,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
     LLM_ARCH_XVERSE,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index ab1a07d1..70183041 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -878,6 +878,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_GEMMA3:
+            {
+            } break;
         case LLM_ARCH_STARCODER2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -2537,6 +2540,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
                     }
                 } break;
+            case LLM_ARCH_GEMMA3:
+                {
+                } break;
             case LLM_ARCH_STARCODER2:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4029,6 +4035,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
         case LLM_ARCH_PHIMOE:
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_GEMMA2:
+        case LLM_ARCH_GEMMA3:
         case LLM_ARCH_STARCODER2:
         case LLM_ARCH_OPENELM:
         case LLM_ARCH_GPTNEOX:
diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index 6eb1da08..d2f3a510 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -737,6 +737,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         // This used to be a regex, but <regex> has an extreme cost to compile times.
         bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
 
+        // don't quantize vision stuff
+        quantize &= name.find("v.blk.") == std::string::npos;
+
+        quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos;
+        quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos;
+        quantize &= name.find("v.patch_embedding.weight") == std::string::npos;
+        quantize &= name.find("v.position_embedding.weight") == std::string::npos;
+        quantize &= name.find("v.post_layernorm.weight") == std::string::npos;
+
         // quantize only 2D and 3D tensors (experts)
         quantize &= (ggml_n_dims(tensor) >= 2);
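
Annotation (not part of the patch above): the llama-quant.cpp hunk decides what to quantize purely from tensor names, so the predicate can be read in isolation. Below is a minimal standalone sketch of that logic, assuming a hypothetical should_quantize() helper; in the patch itself these checks run inline in llama_model_quantize_impl and are combined with further conditions such as the ggml_n_dims(tensor) >= 2 check.

// should_quantize.cpp -- standalone sketch (hypothetical helper that mirrors
// the name checks added in llama-quant.cpp above; not part of llama.cpp)
#include <cstdio>
#include <string>

static bool should_quantize(const std::string & name) {
    // ends with 'weight'? (same expression the upstream code uses)
    bool quantize = name.rfind("weight") == name.size() - 6;

    // don't quantize vision stuff: Gemma 3 vision-tower ("v.*") and multimodal
    // projector ("mm.*") tensors are left at their original precision
    quantize &= name.find("v.blk.") == std::string::npos;
    quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos;
    quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos;
    quantize &= name.find("v.patch_embedding.weight") == std::string::npos;
    quantize &= name.find("v.position_embedding.weight") == std::string::npos;
    quantize &= name.find("v.post_layernorm.weight") == std::string::npos;

    return quantize;
}

int main() {
    printf("%d\n", should_quantize("blk.0.attn_q.weight"));    // 1: text-model weight
    printf("%d\n", should_quantize("v.blk.0.attn_q.weight"));  // 0: vision tower block
    printf("%d\n", should_quantize("v.post_layernorm.weight"));// 0: explicitly excluded
    printf("%d\n", should_quantize("blk.0.attn_q.bias"));      // 0: not a weight tensor
    return 0;
}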