ggml-aarch64.c 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. /**
  2. * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
  3. *
  4. * MIT License
  5. *
  6. * Copyright (c) 2023-2024 The ggml authors
  7. *
  8. * Permission is hereby granted, free of charge, to any person obtaining a copy
  9. * of this software and associated documentation files (the "Software"), to deal
  10. * in the Software without restriction, including without limitation the rights
  11. * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  12. * copies of the Software, and to permit persons to whom the Software is
  13. * furnished to do so, subject to the following conditions:
  14. *
  15. * The above copyright notice and this permission notice shall be included in all
  16. * copies or substantial portions of the Software.
  17. *
  18. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  19. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  20. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  21. * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  22. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  23. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  24. * SOFTWARE.
  25. */
  26. #define GGML_COMMON_DECL_C
  27. #include "ggml-common.h"
  28. #include "ggml-aarch64.h"
  29. #include "ggml-impl.h"
  30. #include "ggml-quants.h"
  31. #include <assert.h>
  32. #define UNUSED GGML_UNUSED
  33. static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
  34. block_q4_0x4 out;
  35. for (int i = 0; i < 4; i++) {
  36. out.d[i] = in[i].d;
  37. }
  38. const int end = QK4_0 * 2 / blck_size_interleave;
  39. if (blck_size_interleave == 8) {
  40. const uint64_t xor_mask = 0x8888888888888888ULL;
  41. for (int i = 0; i < end; ++i) {
  42. int src_id = i % 4;
  43. int src_offset = (i / 4) * blck_size_interleave;
  44. int dst_offset = i * blck_size_interleave;
  45. uint64_t elems;
  46. // Using memcpy to avoid unaligned memory accesses
  47. memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
  48. elems ^= xor_mask;
  49. memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
  50. }
  51. } else if (blck_size_interleave == 4) {
  52. const uint32_t xor_mask = 0x88888888;
  53. for (int i = 0; i < end; ++i) {
  54. int src_id = i % 4;
  55. int src_offset = (i / 4) * blck_size_interleave;
  56. int dst_offset = i * blck_size_interleave;
  57. uint32_t elems;
  58. memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint32_t));
  59. elems ^= xor_mask;
  60. memcpy(&out.qs[dst_offset], &elems, sizeof(uint32_t));
  61. }
  62. } else {
  63. GGML_ASSERT(false);
  64. }
  65. return out;
  66. }
  67. // interleave 8 block_q4_0s in blocks of blck_size_interleave
  68. // returns an interleaved block_q4_0x8
  69. // in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks
  70. // first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave
  71. static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave) {
  72. block_q4_0x8 out;
  73. for (int i = 0; i < 8; i++) {
  74. out.d[i] = in[i].d;
  75. }
  76. const int end = QK4_0 * 4 / blck_size_interleave;
  77. const uint64_t xor_mask = 0x8888888888888888ULL;
  78. for (int i = 0; i < end; ++i) {
  79. int src_id = i % 8;
  80. int src_offset = (i / 8) * blck_size_interleave;
  81. int dst_offset = i * blck_size_interleave;
  82. uint64_t elems;
  83. memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
  84. elems ^= xor_mask;
  85. memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
  86. }
  87. return out;
  88. }
  89. static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int blck_size_interleave) {
  90. assert(n_per_row % QK4_0 == 0);
  91. const int nb = n_per_row / QK4_0;
  92. void * out_ptr = NULL;
  93. if (nrows_interleaved == 8) {
  94. out_ptr = (block_q4_0x8 *) dst;
  95. }
  96. else if (nrows_interleaved == 4) {
  97. out_ptr = (block_q4_0x4 *) dst;
  98. }
  99. assert(nrows_interleaved <= 8);
  100. block_q4_0 dst_tmp[8];
  101. for (int b = 0; b < (nrow * n_per_row); b += nrows_interleaved * n_per_row) {
  102. for (int64_t x = 0; x < nb; x++) {
  103. for (int i = 0; i < nrows_interleaved; i++ ) {
  104. quantize_row_q4_0_ref(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0);
  105. }
  106. if (nrows_interleaved == 8) {
  107. *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave);
  108. out_ptr = (block_q4_0x8 *) out_ptr + 1;
  109. }
  110. else if (nrows_interleaved == 4) {
  111. *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave);
  112. out_ptr = (block_q4_0x4 *) out_ptr + 1;
  113. }
  114. }
  115. }
  116. return ((nrow * n_per_row) / QK4_0 * sizeof(block_q4_0));
  117. }
  118. size_t quantize_q4_0_4x4(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
  119. UNUSED(quant_weights);
  120. return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4);
  121. }
  122. size_t quantize_q4_0_4x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
  123. UNUSED(quant_weights);
  124. return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8);
  125. }
  126. size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
  127. UNUSED(quant_weights);
  128. return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8);
  129. }