// ggml-cpu-impl.h
  1. /**
  2. * llama.cpp - commit ba1cb19cdd0d92e012e0f6e009e0620f854b6afd - do not edit this file
  3. *
  4. * MIT License
  5. *
  6. * Copyright (c) 2023-2024 The ggml authors
  7. *
  8. * Permission is hereby granted, free of charge, to any person obtaining a copy
  9. * of this software and associated documentation files (the "Software"), to deal
  10. * in the Software without restriction, including without limitation the rights
  11. * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  12. * copies of the Software, and to permit persons to whom the Software is
  13. * furnished to do so, subject to the following conditions:
  14. *
  15. * The above copyright notice and this permission notice shall be included in all
  16. * copies or substantial portions of the Software.
  17. *
  18. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  19. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  20. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  21. * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  22. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  23. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  24. * SOFTWARE.
  25. */
  26. #pragma once
  27. // GGML CPU internal header
  28. #include "ggml.h"
  29. #include "ggml-impl.h"
  30. #include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
  31. //#include <stddef.h>
  32. #include <stdbool.h>
  33. #include <string.h> // memcpy
  34. #include <math.h> // fabsf
  35. #ifdef __cplusplus
  36. extern "C" {
  37. #endif
// per-op compute parameters handed to each worker thread
struct ggml_compute_params {
    // ith = thread index, nth = number of threads
    int ith, nth;

    // work buffer for all threads
    size_t wsize;
    void * wdata;

    struct ggml_threadpool * threadpool;
};
// reinterpret helpers for AVX-512 vector types: MSVC accepts the value as-is,
// while GCC/Clang require an explicit vector cast
#if defined(_MSC_VER)
#define m512bh(p) p
#define m512i(p) p
#else
#define m512bh(p) (__m512bh)(p)
#define m512i(p) (__m512i)(p)
#endif
// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
// NOTE(review): defining these reserved __ macros mirrors what GCC/Clang would
// predefine for the same target features
#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
#ifndef __FMA__
#define __FMA__
#endif
#ifndef __F16C__
#define __F16C__
#endif
#endif

// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available
#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__))
#ifndef __SSE3__
#define __SSE3__
#endif
#ifndef __SSSE3__
#define __SSSE3__
#endif
#endif
  71. #if defined(__ARM_FEATURE_SVE)
  72. #include <arm_sve.h>
  73. #include <sys/prctl.h>
  74. #endif
// 16-bit float
// on Arm, we use __fp16
// on x86, we use uint16_t
#if defined(__ARM_NEON)

// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
//
//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
//
#include <arm_neon.h>

#ifdef _MSC_VER

typedef uint16_t ggml_fp16_internal_t;

// NOTE(review): MSVC path builds the uint32x4_t from two packed 64-bit
// halves instead of four 32-bit lanes — presumably because MSVC's NEON
// types cannot take a 4 x 32-bit initializer list; assumes little-endian
// lane packing — confirm against the MSVC ARM intrinsics headers
#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }

#else

typedef __fp16 ggml_fp16_internal_t;

#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }

#endif // _MSC_VER
  91. #if !defined(__aarch64__)
  92. // 32-bit ARM compatibility
  93. // vaddlvq_s16
  94. // vpaddq_s16
  95. // vpaddq_s32
  96. // vaddvq_s32
  97. // vaddvq_f32
  98. // vmaxvq_f32
  99. // vcvtnq_s32_f32
  100. // vzip1_u8
  101. // vzip2_u8
// emulate vaddlvq_s16: widening horizontal add of all 8 int16 lanes to int32
inline static int32_t vaddlvq_s16(int16x8_t v) {
    // two pairwise widening adds reduce v to two int64 partial sums; lanes 0
    // and 2 of the int32 reinterpretation are their low 32 bits (assumes
    // little-endian lane order — TODO confirm)
    int32x4_t v0 = vreinterpretq_s32_s64(vpaddlq_s32(vpaddlq_s16(v)));
    return vgetq_lane_s32(v0, 0) + vgetq_lane_s32(v0, 2);
}
  106. inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
  107. int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
  108. int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
  109. return vcombine_s16(a0, b0);
  110. }
  111. inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
  112. int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
  113. int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
  114. return vcombine_s32(a0, b0);
  115. }
  116. inline static int32_t vaddvq_s32(int32x4_t v) {
  117. return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
  118. }
  119. inline static float vaddvq_f32(float32x4_t v) {
  120. return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
  121. }
  122. inline static float vmaxvq_f32(float32x4_t v) {
  123. return
  124. MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
  125. MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
  126. }
  127. inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
  128. int32x4_t res;
  129. res[0] = roundf(vgetq_lane_f32(v, 0));
  130. res[1] = roundf(vgetq_lane_f32(v, 1));
  131. res[2] = roundf(vgetq_lane_f32(v, 2));
  132. res[3] = roundf(vgetq_lane_f32(v, 3));
  133. return res;
  134. }
  135. inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
  136. uint8x8_t res;
  137. res[0] = a[0]; res[1] = b[0];
  138. res[2] = a[1]; res[3] = b[1];
  139. res[4] = a[2]; res[5] = b[2];
  140. res[6] = a[3]; res[7] = b[3];
  141. return res;
  142. }
  143. inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
  144. uint8x8_t res;
  145. res[0] = a[4]; res[1] = b[4];
  146. res[2] = a[5]; res[3] = b[5];
  147. res[4] = a[6]; res[5] = b[6];
  148. res[6] = a[7]; res[7] = b[7];
  149. return res;
  150. }
  151. // vld1q_s16_x2
  152. // vld1q_u8_x2
  153. // vld1q_u8_x4
  154. // vld1q_s8_x2
  155. // vld1q_s8_x4
  156. // TODO: double-check these work correctly
  157. typedef struct ggml_int16x8x2_t {
  158. int16x8_t val[2];
  159. } ggml_int16x8x2_t;
  160. inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
  161. ggml_int16x8x2_t res;
  162. res.val[0] = vld1q_s16(ptr + 0);
  163. res.val[1] = vld1q_s16(ptr + 8);
  164. return res;
  165. }
  166. typedef struct ggml_uint8x16x2_t {
  167. uint8x16_t val[2];
  168. } ggml_uint8x16x2_t;
  169. inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
  170. ggml_uint8x16x2_t res;
  171. res.val[0] = vld1q_u8(ptr + 0);
  172. res.val[1] = vld1q_u8(ptr + 16);
  173. return res;
  174. }
  175. typedef struct ggml_uint8x16x4_t {
  176. uint8x16_t val[4];
  177. } ggml_uint8x16x4_t;
  178. inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
  179. ggml_uint8x16x4_t res;
  180. res.val[0] = vld1q_u8(ptr + 0);
  181. res.val[1] = vld1q_u8(ptr + 16);
  182. res.val[2] = vld1q_u8(ptr + 32);
  183. res.val[3] = vld1q_u8(ptr + 48);
  184. return res;
  185. }
  186. typedef struct ggml_int8x16x2_t {
  187. int8x16_t val[2];
  188. } ggml_int8x16x2_t;
  189. inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
  190. ggml_int8x16x2_t res;
  191. res.val[0] = vld1q_s8(ptr + 0);
  192. res.val[1] = vld1q_s8(ptr + 16);
  193. return res;
  194. }
  195. typedef struct ggml_int8x16x4_t {
  196. int8x16_t val[4];
  197. } ggml_int8x16x4_t;
  198. inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
  199. ggml_int8x16x4_t res;
  200. res.val[0] = vld1q_s8(ptr + 0);
  201. res.val[1] = vld1q_s8(ptr + 16);
  202. res.val[2] = vld1q_s8(ptr + 32);
  203. res.val[3] = vld1q_s8(ptr + 48);
  204. return res;
  205. }
  206. // NOTE: not tested
  207. inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
  208. int8x16_t res;
  209. res[ 0] = a[b[ 0]];
  210. res[ 1] = a[b[ 1]];
  211. res[ 2] = a[b[ 2]];
  212. res[ 3] = a[b[ 3]];
  213. res[ 4] = a[b[ 4]];
  214. res[ 5] = a[b[ 5]];
  215. res[ 6] = a[b[ 6]];
  216. res[ 7] = a[b[ 7]];
  217. res[ 8] = a[b[ 8]];
  218. res[ 9] = a[b[ 9]];
  219. res[10] = a[b[10]];
  220. res[11] = a[b[11]];
  221. res[12] = a[b[12]];
  222. res[13] = a[b[13]];
  223. res[14] = a[b[14]];
  224. res[15] = a[b[15]];
  225. return res;
  226. }
  227. // NOTE: not tested
  228. inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
  229. uint8x16_t res;
  230. res[ 0] = a[b[ 0]];
  231. res[ 1] = a[b[ 1]];
  232. res[ 2] = a[b[ 2]];
  233. res[ 3] = a[b[ 3]];
  234. res[ 4] = a[b[ 4]];
  235. res[ 5] = a[b[ 5]];
  236. res[ 6] = a[b[ 6]];
  237. res[ 7] = a[b[ 7]];
  238. res[ 8] = a[b[ 8]];
  239. res[ 9] = a[b[ 9]];
  240. res[10] = a[b[10]];
  241. res[11] = a[b[11]];
  242. res[12] = a[b[12]];
  243. res[13] = a[b[13]];
  244. res[14] = a[b[14]];
  245. res[15] = a[b[15]];
  246. return res;
  247. }
#else

// on AArch64 the x2/x4 loads and the table-lookup intrinsics exist natively,
// so the ggml_* names are plain aliases
#define ggml_int16x8x2_t  int16x8x2_t
#define ggml_uint8x16x2_t uint8x16x2_t
#define ggml_uint8x16x4_t uint8x16x4_t
#define ggml_int8x16x2_t  int8x16x2_t
#define ggml_int8x16x4_t  int8x16x4_t

#define ggml_vld1q_s16_x2 vld1q_s16_x2
#define ggml_vld1q_u8_x2  vld1q_u8_x2
#define ggml_vld1q_u8_x4  vld1q_u8_x4
#define ggml_vld1q_s8_x2  vld1q_s8_x2
#define ggml_vld1q_s8_x4  vld1q_s8_x4
#define ggml_vqtbl1q_s8   vqtbl1q_s8
#define ggml_vqtbl1q_u8   vqtbl1q_u8

#endif // !defined(__aarch64__)
  262. #if !defined(__ARM_FEATURE_DOTPROD)
  263. inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
  264. const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
  265. const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
  266. return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
  267. }
  268. #else
  269. #define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
  270. #endif // !defined(__ARM_FEATURE_DOTPROD)
  271. #endif // defined(__ARM_NEON)
  272. #ifdef __wasm_simd128__
  273. #include <wasm_simd128.h>
  274. #else
  275. #ifdef __POWER9_VECTOR__
  276. #include <altivec.h>
  277. #undef bool
  278. #define bool _Bool
  279. #else
  280. #if defined(_MSC_VER) || defined(__MINGW32__)
  281. #include <intrin.h>
  282. #else
  283. #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
  284. #if !defined(__riscv)
  285. #include <immintrin.h>
  286. #endif
  287. #endif
  288. #endif
  289. #endif
  290. #endif
  291. #ifdef __riscv_v_intrinsic
  292. #include <riscv_vector.h>
  293. #endif
  294. #if defined(__loongarch64)
  295. #if defined(__loongarch_asx)
  296. #include <lasxintrin.h>
  297. #endif
  298. #if defined(__loongarch_sx)
  299. #include <lsxintrin.h>
  300. #endif
  301. #endif
#if defined(__loongarch_asx)

// scratch union for bitwise float <-> int32 reinterpretation
typedef union {
    int32_t i;
    float f;
} ft_union;

/* float type data load instructions */
// broadcast a float into all lanes of a 128-bit LSX vector: type-pun the
// bits through ft_union, then use the integer replicate intrinsic
static __m128 __lsx_vreplfr2vr_s(float val) {
    ft_union fi_tmpval = {.f = val};
    return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
}

// same as above, but for a 256-bit LASX vector
static __m256 __lasx_xvreplfr2vr_s(float val) {
    ft_union fi_tmpval = {.f = val};
    return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
}
#endif
  317. // TODO: move to ggml-threading
  318. void ggml_barrier(struct ggml_threadpool * tp);
  319. #ifdef __cplusplus
  320. }
  321. #endif