k_quants.h 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. /**
  2. * llama.cpp - git 3ebb00935f3f0522b75df49c2769ab1774b91380
  3. *
  4. * MIT License
  5. *
  6. * Copyright (c) 2023 Georgi Gerganov
  7. *
  8. * Permission is hereby granted, free of charge, to any person obtaining a copy
  9. * of this software and associated documentation files (the "Software"), to deal
  10. * in the Software without restriction, including without limitation the rights
  11. * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  12. * copies of the Software, and to permit persons to whom the Software is
  13. * furnished to do so, subject to the following conditions:
  14. *
  15. * The above copyright notice and this permission notice shall be included in all
  16. * copies or substantial portions of the Software.
  17. *
  18. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  19. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  20. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  21. * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  22. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  23. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  24. * SOFTWARE.
  25. */
  26. #pragma once
  27. #include "ggml.h"
  28. #include <stdint.h>
  29. #include <assert.h>
  30. #include <stddef.h>
  31. // Super-block size: QK_K is the number of weights stored in one super-block; K_SCALE_SIZE is the byte length of the packed scales/mins array
  32. #ifdef GGML_QKK_64
  33. #define QK_K 64
  34. #define K_SCALE_SIZE 4 // note: the GGML_QKK_64 struct variants in this header size their scales arrays without this macro
  35. #else
  36. #define QK_K 256
  37. #define K_SCALE_SIZE 12 // used by the default block_q4_K and block_q5_K layouts
  38. #endif
  39. #ifndef static_assert // only supply a fallback when <assert.h> did not already define the macro
  40. #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
  41. #define static_assert(cond, msg) _Static_assert(cond, msg) // C11 or later: forward to the language keyword
  42. #else
  43. #define static_assert(cond, msg) struct global_scope_noop_trick // older C: expands to a harmless struct declaration, so the size checks below are silently skipped
  44. #endif
  45. #endif
  46. //
  47. // Super-block quantization structures
  48. //
  49. // 2-bit quantization
  50. // weight is represented as x = a * q + b
  51. // 16 blocks of 16 elements each (for QK_K == 256)
  52. // Effectively 2.625 bits per weight (16 + 64 + 4 = 84 bytes per 256 weights, assuming a 2-byte ggml_fp16_t)
  53. typedef struct {
  54. uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits (one byte per 16-element block)
  55. uint8_t qs[QK_K/4]; // quants, 2 bits each (four per byte)
  56. ggml_fp16_t d; // super-block scale for quantized scales
  57. ggml_fp16_t dmin; // super-block scale for quantized mins
  58. } block_q2_K;
  59. static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); // fails the build if the compiler inserts padding
  60. // 3-bit quantization
  61. // weight is represented as x = a * q
  62. // 16 blocks of 16 elements each (for QK_K == 256)
  63. // Effectively 3.4375 bits per weight (for QK_K == 256)
  64. #ifdef GGML_QKK_64
  65. typedef struct {
  66. uint8_t hmask[QK_K/8]; // quants - high bit
  67. uint8_t qs[QK_K/4]; // quants - low 2 bits
  68. uint8_t scales[2]; // packed block scales - NOTE(review): bit layout is not visible in this header; confirm in the implementation
  69. ggml_fp16_t d; // super-block scale
  70. } block_q3_K;
  71. static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding"); // fails the build if the compiler inserts padding
  72. #else
  73. typedef struct {
  74. uint8_t hmask[QK_K/8]; // quants - high bit
  75. uint8_t qs[QK_K/4]; // quants - low 2 bits
  76. uint8_t scales[12]; // scales, quantized with 6 bits (12 bytes hold 16 six-bit values)
  77. ggml_fp16_t d; // super-block scale
  78. } block_q3_K;
  79. static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding"); // fails the build if the compiler inserts padding
  80. #endif
  81. // 4-bit quantization
  82. // 8 blocks of 32 elements each (for QK_K == 256; the 12 scale bytes hold 8 six-bit scales and 8 six-bit mins)
  83. // weight is represented as x = a * q + b
  84. // Effectively 4.5 bits per weight (for QK_K == 256)
  85. #ifdef GGML_QKK_64
  86. typedef struct {
  87. ggml_fp16_t d[2]; // super-block scales/mins
  88. uint8_t scales[2]; // 4-bit block scales/mins
  89. uint8_t qs[QK_K/2]; // 4-bit quants (two per byte)
  90. } block_q4_K;
  91. static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding"); // fails the build if the compiler inserts padding
  92. #else
  93. typedef struct {
  94. ggml_fp16_t d; // super-block scale for quantized scales
  95. ggml_fp16_t dmin; // super-block scale for quantized mins
  96. uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
  97. uint8_t qs[QK_K/2]; // 4-bit quants (two per byte)
  98. } block_q4_K;
  99. static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding"); // fails the build if the compiler inserts padding
  100. #endif
  101. // 5-bit quantization
  102. // 8 blocks of 32 elements each (for QK_K == 256; the 12 scale bytes hold 8 six-bit scales and 8 six-bit mins)
  103. // weight is represented as x = a * q + b
  104. // Effectively 5.5 bits per weight (for QK_K == 256)
  105. #ifdef GGML_QKK_64
  106. typedef struct {
  107. ggml_fp16_t d; // super-block scale
  108. int8_t scales[QK_K/16]; // 8-bit block scales (one per 16 elements)
  109. uint8_t qh[QK_K/8]; // quants, high bit
  110. uint8_t qs[QK_K/2]; // quants, low 4 bits
  111. } block_q5_K;
  112. static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding"); // fails the build if the compiler inserts padding
  113. #else
  114. typedef struct {
  115. ggml_fp16_t d; // super-block scale for quantized scales
  116. ggml_fp16_t dmin; // super-block scale for quantized mins
  117. uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
  118. uint8_t qh[QK_K/8]; // quants, high bit
  119. uint8_t qs[QK_K/2]; // quants, low 4 bits
  120. } block_q5_K;
  121. static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); // fails the build if the compiler inserts padding
  122. #endif
  123. // 6-bit quantization
  124. // weight is represented as x = a * q
  125. // 16 blocks of 16 elements each (for QK_K == 256)
  126. // Effectively 6.5625 bits per weight (for QK_K == 256)
  127. typedef struct {
  128. uint8_t ql[QK_K/2]; // quants, lower 4 bits
  129. uint8_t qh[QK_K/4]; // quants, upper 2 bits
  130. int8_t scales[QK_K/16]; // scales, quantized with 8 bits (one per 16 elements)
  131. ggml_fp16_t d; // super-block scale
  132. } block_q6_K;
  133. static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding"); // fails the build if the compiler inserts padding
  134. // 8-bit super-block; this is only used for intermediate quantization and dot products
  135. typedef struct {
  136. float d; // delta (a full float here, unlike the fp16 scales of the other block types)
  137. int8_t qs[QK_K]; // quants, one signed byte per weight
  138. int16_t bsums[QK_K/16]; // sum of quants in groups of 16
  139. } block_q8_K;
  140. static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding"); // fails the build if the compiler inserts padding
  141. // Quantization: reads floats from x and writes quantized blocks to y. NOTE(review): k is presumably the number of input floats (a multiple of QK_K) - confirm in the implementation
  142. void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k); // _reference variants take a typed block pointer as destination
  143. void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k);
  144. void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k);
  145. void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k);
  146. void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
  147. void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
  148. void quantize_row_q2_K(const float * restrict x, void * restrict y, int k); // these variants take an untyped destination; presumably it must point at the matching block_qN_K array - confirm
  149. void quantize_row_q3_K(const float * restrict x, void * restrict y, int k);
  150. void quantize_row_q4_K(const float * restrict x, void * restrict y, int k);
  151. void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);
  152. void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
  153. void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
  154. // Dequantization: reads quantized blocks from x and writes floats to y. NOTE(review): k is presumably the number of output floats (a multiple of QK_K) - confirm in the implementation
  155. void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k);
  156. void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k);
  157. void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k);
  158. void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k);
  159. void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k);
  160. void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
  161. // Dot product: s is the only writable argument, so the result is delivered through it. NOTE(review): vx/vy are presumably block_qN_K and block_q8_K arrays covering n elements - confirm in the implementation
  162. void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  163. void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  164. void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  165. void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  166. void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  167. // Quantization with histogram collection: quantizes src into dst and updates hist. NOTE(review): the roles of n and k and the meaning of the size_t result (presumably bytes written) are not visible here - confirm in the implementation
  168. size_t ggml_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist);
  169. size_t ggml_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist);
  170. size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
  171. size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
  172. size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);