/**
 * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
 *
 * MIT License
 *
 * Copyright (c) 2023-2024 The ggml authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "common.cuh"

#include <cstdint>

static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
    const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment

    int x32  = x16[2*i32 + 0] <<  0;
    x32     |= x16[2*i32 + 1] << 16;

    return x32;
}

static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32) {
    return ((const int *) x)[i32]; // assume at least 4 byte alignment
}
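
// Note on the helpers above: both return the i32-th 32-bit word of a quant
// block, but get_int_b2 goes through two 16-bit loads because its argument is
// only guaranteed to be 2-byte aligned. The packed words are consumed by
// ggml_cuda_dp4a (declared in common.cuh), which, like CUDA's __dp4a, multiplies
// the four packed int8 lanes of its operands pairwise and accumulates the
// results into an int32.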

// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q

#define VDR_Q4_0_Q8_1_MMVQ 2
#define VDR_Q4_0_Q8_1_MMQ  4
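
// Convention assumed throughout this file: the activations arrive quantized as
// q8_1 blocks whose half2 ds stores {d8, s8}, where d8 is the block scale and
// s8 is the sum of the values in the block (equivalently, d8 times the sum of
// its quants). The s8 component lets the offset/minimum terms of the weight
// formats below be folded into the result without a second pass over the quants.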

template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl(
    const int * v, const int * u, const float & d4, const half2 & ds8) {
    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;

        // SIMD dot product of quantized values
        sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi);
        sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
    }

    const float2 ds8f = __half22float2(ds8);

    // second part effectively subtracts 8 from each quant value
    return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y);
}
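
// Derivation of the q4_0 formula: each weight is d4*(q - 8), so
//   sum d4*(q - 8) * d8*u = d4 * (d8*sum(q*u) - 8 * d8*sum(u))
//                         = d4 * (sumi*ds8f.x - 8*ds8f.y * vdr/QI4_0).
// ds8f.y covers the whole q8_1 block while a single call only touches vdr of
// the QI4_0 integers of the q4_0 block, hence the vdr/QI4_0 scaling: the
// threads sharing the block each subtract their share of the offset term.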

#define VDR_Q4_1_Q8_1_MMVQ 2
#define VDR_Q4_1_Q8_1_MMQ  4

template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl(
    const int * v, const int * u, const half2 & dm4, const half2 & ds8) {
    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;

        // SIMD dot product of quantized values
        sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi);
        sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
    }

#ifdef GGML_CUDA_F16
    const float2 tmp = __half22float2(__hmul2(dm4, ds8));
    const float d4d8 = tmp.x;
    const float m4s8 = tmp.y;
#else
    const float2 dm4f = __half22float2(dm4);
    const float2 ds8f = __half22float2(ds8);
    const float d4d8 = dm4f.x * ds8f.x;
    const float m4s8 = dm4f.y * ds8f.y;
#endif // GGML_CUDA_F16

    // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
    return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
}
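
// Same derivation for q4_1, whose weights are d4*q + m4 with dm4 = {d4, m4}:
//   sum (d4*q + m4) * d8*u = d4d8 * sum(q*u) + m4 * d8*sum(u).
// m4s8 = m4*s8 spans the whole q8_1 block, while one call reads vdr*QR4_1 of
// its QI8_1 integers, hence the division by QI8_1/(vdr * QR4_1) noted above.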

#define VDR_Q5_0_Q8_1_MMVQ 2
#define VDR_Q5_0_Q8_1_MMQ  4

template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl(
    const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) {
    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
        vi0    |= (vh[i] <<  4) & 0x00000010; // 0 ->  4
        vi0    |= (vh[i] << 11) & 0x00001000; // 1 -> 12
        vi0    |= (vh[i] << 18) & 0x00100000; // 2 -> 20
        vi0    |= (vh[i] << 25) & 0x10000000; // 3 -> 28
        sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values

        int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
        vi1    |= (vh[i] >> 12) & 0x00000010; // 16 ->  4
        vi1    |= (vh[i] >>  5) & 0x00001000; // 17 -> 12
        vi1    |= (vh[i] <<  2) & 0x00100000; // 18 -> 20
        vi1    |= (vh[i] <<  9) & 0x10000000; // 19 -> 28
        sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
    }

    const float2 ds8f = __half22float2(ds8);

    // second part effectively subtracts 16 from each quant value
    return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y);
}
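
// Example of the bit scatter above: q5_0 keeps the 5th (high) bit of the quants
// in a separate word qh, and byte lane b of vi0 needs its qh bit at bit position
// 8*b + 4. For lane 1 that means (vh[i] << 11) moves bit 1 to bit 12, and the
// mask 0x00001000 keeps exactly that bit -- the "1 -> 12" comment. vi1 does the
// same for qh bits 16..19, shifting right when the source bit already sits
// above its destination.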

#define VDR_Q5_1_Q8_1_MMVQ 2
#define VDR_Q5_1_Q8_1_MMQ  4

template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl(
    const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) {
    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
        vi0    |= (vh[i] <<  4) & 0x00000010; // 0 ->  4
        vi0    |= (vh[i] << 11) & 0x00001000; // 1 -> 12
        vi0    |= (vh[i] << 18) & 0x00100000; // 2 -> 20
        vi0    |= (vh[i] << 25) & 0x10000000; // 3 -> 28
        sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values

        int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
        vi1    |= (vh[i] >> 12) & 0x00000010; // 16 ->  4
        vi1    |= (vh[i] >>  5) & 0x00001000; // 17 -> 12
        vi1    |= (vh[i] <<  2) & 0x00100000; // 18 -> 20
        vi1    |= (vh[i] <<  9) & 0x10000000; // 19 -> 28
        sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
    }

#ifdef GGML_CUDA_F16
    const float2 tmp = __half22float2(__hmul2(dm5, ds8));
    const float d5d8 = tmp.x;
    const float m5s8 = tmp.y;
#else
    const float2 dm5f = __half22float2(dm5);
    const float2 ds8f = __half22float2(ds8);
    const float d5d8 = dm5f.x * ds8f.x;
    const float m5s8 = dm5f.y * ds8f.y;
#endif // GGML_CUDA_F16

    // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
    return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
}

#define VDR_Q8_0_Q8_1_MMVQ 2
#define VDR_Q8_0_Q8_1_MMQ  8

template <typename T, int vdr> static __device__ __forceinline__ T vec_dot_q8_0_q8_1_impl(
    const int * v, const int * u, const T & d8_0, const T & d8_1) {
    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        // SIMD dot product of quantized values
        sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
    }

    return d8_0*d8_1 * ((T) sumi);
}

template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl(
    const int * v, const int * u, const half2 & dm8, const half2 & ds8) {
    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        // SIMD dot product of quantized values
        sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
    }

#ifdef GGML_CUDA_F16
    const float2 tmp = __half22float2(__hmul2(dm8, ds8));
    const float d8d8 = tmp.x;
    const float m8s8 = tmp.y;
#else
    const float2 dm8f = __half22float2(dm8);
    const float2 ds8f = __half22float2(ds8);
    const float d8d8 = dm8f.x * ds8f.x;
    const float m8s8 = dm8f.y * ds8f.y;
#endif // GGML_CUDA_F16

    // scale second part of sum by QI8_1 / vdr to compensate for multiple threads adding it
    return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
}

template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_impl(
    const int * v, const int * u, const float * d8_0, const float & d8_1) {
    float sumf = 0.0f;

#pragma unroll
    for (int i0 = 0; i0 < vdr; i0 += QI8_0/2) {
        int sumi = 0;

#pragma unroll
        for (int i = i0; i < i0 + QI8_0/2; ++i) {
            // SIMD dot product of quantized values
            sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
        }

        sumf += d8_0[i0/(QI8_0/2)]*sumi;
    }

    return d8_1*sumf;
}

#define VDR_Q2_K_Q8_1_MMVQ 1
#define VDR_Q2_K_Q8_1_MMQ  4

// contiguous v/x values
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
    const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
    const half2 & dm2, const float * __restrict__ d8) {
    float sumf_d = 0.0f;
    float sumf_m = 0.0f;

#pragma unroll
    for (int i = 0; i < QR2_K; ++i) {
        const int sc = scales[2*i];

        const int vi = (v >> (2*i)) & 0x03030303;

        sumf_d += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product

        // fill int with 4x m
        int m = sc >> 4;
        m |= m <<  8;
        m |= m << 16;
        sumf_m += d8[i] * ggml_cuda_dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values
    }

    const float2 dm2f = __half22float2(dm2);

    return dm2f.x*sumf_d - dm2f.y*sumf_m;
}
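
// q2_K scale bytes pack a 4-bit scale in the low nibble and a 4-bit minimum in
// the high nibble; a weight decodes as dm2.x*sc*q - dm2.y*m. The "fill int with
// 4x m" step broadcasts m into all four byte lanes so that a single dp4a
// against u[i] yields m times the sum of the four packed q8_1 values.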

// contiguous v/x + u/y values
template <int ns8>
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
    const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8, const half2 * s8) {
    float sumf    = 0.0f;
    float sumf_d8 = 0.0f;

#pragma unroll
    for (int i0 = 0; i0 < QR2_K*VDR_Q2_K_Q8_1_MMQ; i0 += QI8_1) {
        const float2 dm2f0 = __half22float2(dm2[i0/(QI8_1/2) + 0]);
        int sumi_d0 = 0;

        const float2 dm2f1 = __half22float2(dm2[i0/(QI8_1/2) + 1]);
        int sumi_d1 = 0;

#pragma unroll
        for (int i = i0; i < i0 + QI8_1/2; ++i) {
            sumi_d0 = ggml_cuda_dp4a(v[i], u[i], sumi_d0);
        }
        sumf_d8 += dm2f0.x * sumi_d0;

#pragma unroll
        for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
            sumi_d1 = ggml_cuda_dp4a(v[i], u[i], sumi_d1);
        }
        sumf_d8 += dm2f1.x * sumi_d1;

        if (i0/QI8_1 < ns8) {
            const float2 s8f = __half22float2(s8[i0/QI8_1]);
            sumf -= dm2f0.y*s8f.x;
            sumf -= dm2f1.y*s8f.y;
        } else {
            int sumi_m0 = 0;
#pragma unroll
            for (int i = i0; i < i0 + QI8_1/2; ++i) {
                sumi_m0 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m0);
            }
            sumf_d8 -= dm2f0.y * sumi_m0;

            int sumi_m1 = 0;
#pragma unroll
            for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
                sumi_m1 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m1);
            }
            sumf_d8 -= dm2f1.y * sumi_m1;
        }
    }

    return sumf + d8*sumf_d8;
}

#define VDR_Q3_K_Q8_1_MMVQ 1
#define VDR_Q3_K_Q8_1_MMQ  2

// contiguous v/x values
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
    const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales,
    const int & scale_offset, const float & d3, const float * __restrict__ d8) {
    float sumf = 0.0f;

#pragma unroll
    for (int i = 0; i < QR3_K; ++i) {
        const int isc = scale_offset + 2*i;

        const int isc_low = isc % (QK_K/32);
        const int sc_shift_low = 4 * (isc / (QK_K/32));
        const int sc_low = (scales[isc_low] >> sc_shift_low) & 0xF;

        const int isc_high = isc % (QK_K/64);
        const int sc_shift_high = 2 * (isc / (QK_K/64));
        const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;

        const int sc = (sc_low | sc_high) - 32;

        const int vil = (vl >> (2*i)) & 0x03030303;

        const int vih = ((vh >> i) << 2) & 0x04040404;

        const int vi = __vsubss4(vil, vih);

        sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product
    }

    return d3 * sumf;
}

// contiguous v/x + u/y values
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
    const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
    const float & d3, const float & d8) {
    int sumi = 0;

#pragma unroll
    for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
        int sumi_sc = 0;

#pragma unroll
        for (int i = i0; i < i0 + QI8_1/2; ++i) {
            sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product
        }

        sumi += sumi_sc * scales[i0 / (QI8_1/2)];
    }

    return d3*d8 * sumi;
}

#define VDR_Q4_K_Q8_1_MMVQ 2
#define VDR_Q4_K_Q8_1_MMQ  8

// contiguous v/x values
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
    const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) {
    float sumf_d = 0.0f;
    float sumf_m = 0.0f;

#pragma unroll
    for (int i = 0; i < QR4_K; ++i) {
        const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F;
        const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F;

        const int dot1 = ggml_cuda_dp4a(v1i, u[2*i+1], ggml_cuda_dp4a(v0i, u[2*i+0], 0)); // SIMD dot product
        const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+1], ggml_cuda_dp4a(0x01010101, u[2*i+0], 0)); // sum of u

        sumf_d += d8[i] * (dot1 * sc[i]);
        sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values
    }

    const float2 dm4f = __half22float2(dm4);

    return dm4f.x*sumf_d - dm4f.y*sumf_m;
}

// contiguous v/x + u/y values
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
    const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
    float sumf_d = 0.0f;
    float sumf_m = 0.0f;

#pragma unroll
    for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) {
        int sumi_d = 0;

#pragma unroll
        for (int j = 0; j < QI8_1; ++j) {
            sumi_d = ggml_cuda_dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product
        }

        const float2 ds8f = __half22float2(ds8[i]);

        sumf_d += ds8f.x * (sc[i] * sumi_d);
        sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val
    }

    const float2 dm4f = __half22float2(dm4);

    return dm4f.x*sumf_d - dm4f.y*sumf_m;
}

#define VDR_Q5_K_Q8_1_MMVQ 2
#define VDR_Q5_K_Q8_1_MMQ  8

// contiguous v/x values
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
    const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc,
    const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) {
    float sumf_d = 0.0f;
    float sumf_m = 0.0f;

#pragma unroll
    for (int i = 0; i < QR5_K; ++i) {
        const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F;
        const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F;

        const int vh0i = ((vh[0] >> i) << 4) & 0x10101010;
        const int vh1i = ((vh[1] >> i) << 4) & 0x10101010;

        const int v0i = vl0i | vh0i;
        const int v1i = vl1i | vh1i;

        const int dot1 = ggml_cuda_dp4a(v0i, u[2*i+0], ggml_cuda_dp4a(v1i, u[2*i+1], 0)); // SIMD dot product
        const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+0], ggml_cuda_dp4a(0x01010101, u[2*i+1], 0)); // sum of u

        sumf_d += d8[i] * (dot1 * sc[i]);
        sumf_m += d8[i] * (dot2 * m[i]);
    }

    const float2 dm5f = __half22float2(dm5);

    return dm5f.x*sumf_d - dm5f.y*sumf_m;
}

// contiguous v/x + u/y values
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
    const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
    float sumf_d = 0.0f;
    float sumf_m = 0.0f;

#pragma unroll
    for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) {
        int sumi_d = 0;

#pragma unroll
        for (int j = 0; j < QI8_1; ++j) {
            sumi_d = ggml_cuda_dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product
        }

        const float2 ds8f = __half22float2(ds8[i]);

        sumf_d += ds8f.x * (sc[i] * sumi_d);
        sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q5_K min val
    }

    const float2 dm4f = __half22float2(dm4);

    return dm4f.x*sumf_d - dm4f.y*sumf_m;
}

#define VDR_Q6_K_Q8_1_MMVQ 1
#define VDR_Q6_K_Q8_1_MMQ  8

// contiguous v/x values
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
    const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales,
    const float & d, const float * __restrict__ d8) {
    float sumf = 0.0f;

#pragma unroll
    for (int i = 0; i < QR6_K; ++i) {
        const int sc = scales[4*i];

        const int vil = (vl >> (4*i)) & 0x0F0F0F0F;

        const int vih = ((vh >> (4*i)) << 4) & 0x30303030;

        const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32

        sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product
    }

    return d*sumf;
}

// contiguous v/x + u/y values
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
    const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
    const float & d6, const float * __restrict__ d8) {
    float sumf_d = 0.0f;

    const int      sc_packed = get_int_b4(sc, 0);
    const int8_t * sc_reg    = (const int8_t *) &sc_packed;

#pragma unroll
    for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
        int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale

#pragma unroll
        for (int i = i0; i < i0 + 2; ++i) {
            sumi_d.x = ggml_cuda_dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product
            sumi_d.x = ggml_cuda_dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product

            sumi_d.y = ggml_cuda_dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product
            sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
        }

        sumf_d += d8[i0/4] * (sc_reg[i0/2+0]*sumi_d.x + sc_reg[i0/2+1]*sumi_d.y);
    }

    return d6 * sumf_d;
}

static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
    const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq + kbx;

    int v[VDR_Q4_0_Q8_1_MMVQ];
    int u[2*VDR_Q4_0_Q8_1_MMVQ];

#pragma unroll
    for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) {
        v[i]     = get_int_b2(bq4_0->qs, iqs + i);
        u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);
        u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI4_0);
    }

    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);
}
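
// The vec_dot_*_q8_1 entry points from here on share one calling convention:
// vbq points to the quantized weight blocks, bq8_1 to the matching q8_1
// activation blocks, kbx selects the weight block, and iqs is the calling
// thread's integer offset inside it. The MMVQ (mul_mat_vec_q) kernels call
// these per thread and sum the returned partial dot products.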

static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
    const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq + kbx;

    int v[VDR_Q4_1_Q8_1_MMVQ];
    int u[2*VDR_Q4_1_Q8_1_MMVQ];

#pragma unroll
    for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) {
        v[i]     = get_int_b4(bq4_1->qs, iqs + i);
        u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);
        u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI4_1);
    }

    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);
}

static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
    const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq + kbx;

    int vl[VDR_Q5_0_Q8_1_MMVQ];
    int vh[VDR_Q5_0_Q8_1_MMVQ];
    int  u[2*VDR_Q5_0_Q8_1_MMVQ];

#pragma unroll
    for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) {
        vl[i]    = get_int_b2(bq5_0->qs, iqs + i);
        vh[i]    = get_int_b2(bq5_0->qh, 0) >> (4 * (iqs + i));
        u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);
        u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI5_0);
    }

    return vec_dot_q5_0_q8_1_impl<VDR_Q5_0_Q8_1_MMVQ>(vl, vh, u, bq5_0->d, bq8_1->ds);
}

static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
    const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq + kbx;

    int vl[VDR_Q5_1_Q8_1_MMVQ];
    int vh[VDR_Q5_1_Q8_1_MMVQ];
    int  u[2*VDR_Q5_1_Q8_1_MMVQ];

#pragma unroll
    for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) {
        vl[i]    = get_int_b4(bq5_1->qs, iqs + i);
        vh[i]    = get_int_b4(bq5_1->qh, 0) >> (4 * (iqs + i));
        u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);
        u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI5_1);
    }

    return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, bq5_1->dm, bq8_1->ds);
}

static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
    const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq + kbx;

    int v[VDR_Q8_0_Q8_1_MMVQ];
    int u[VDR_Q8_0_Q8_1_MMVQ];

#pragma unroll
    for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) {
        v[i] = get_int_b2(bq8_0->qs, iqs + i);
        u[i] = get_int_b4(bq8_1->qs, iqs + i);
    }

    return vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, __low2half(bq8_1->ds));
}

static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
    const block_q2_K * bq2_K = (const block_q2_K *) vbq + kbx;

    const int bq8_offset = QR2_K * (iqs / QI8_1);
    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);

    const uint8_t * scales = bq2_K->scales + scale_offset;

    const int v = get_int_b4(bq2_K->qs, iqs);

    int    u[QR2_K];
    float d8[QR2_K];

#pragma unroll
    for (int i = 0; i < QR2_K; ++i) {
        u[i]  = get_int_b4(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
        d8[i] = __low2float(bq8_1[bq8_offset + i].ds);
    }

    return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
}

static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
    const block_q3_K * bq3_K = (const block_q3_K *) vbq + kbx;

    const int bq8_offset = QR3_K * (iqs / (QI3_K/2));
    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);

    const float d = bq3_K->d;

    const int vl = get_int_b2(bq3_K->qs, iqs);

    // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
    const int vh = ~get_int_b2(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset;

    int    u[QR3_K];
    float d8[QR3_K];

#pragma unroll
    for (int i = 0; i < QR3_K; ++i) {
        u[i]  = get_int_b4(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
        d8[i] = __low2float(bq8_1[bq8_offset + i].ds);
    }

    return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
}

static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
    const block_q4_K * bq4_K = (const block_q4_K *) vbq + kbx;

    int    v[2];
    int    u[2*QR4_K];
    float d8[QR4_K];

    // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6
    const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2));

    // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
    // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
    // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
    // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108

    const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
    v[0] = q4[0];
    v[1] = q4[4];

    const uint16_t * scales = (const uint16_t *)bq4_K->scales;
    uint16_t aux[2];
    const int j = bq8_offset/2;
    if (j < 2) {
        aux[0] = scales[j+0] & 0x3f3f;
        aux[1] = scales[j+2] & 0x3f3f;
    } else {
        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
    }
    const uint8_t * sc = (const uint8_t *)aux;
    const uint8_t * m  = sc + 2;

    for (int i = 0; i < QR4_K; ++i) {
        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
        d8[i] = __low2float(bq8i->ds);

        const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
        u[2*i+0] = q8[0];
        u[2*i+1] = q8[4];
    }

    return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
}
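
// Scale/min unpack above: q4_K stores eight 6-bit scales and eight 6-bit minima
// in 12 bytes, read here through a uint16_t view. For the first half of the
// sub-blocks (j < 2) the values sit directly in the low 6 bits of each byte
// (mask 0x3f3f); for the second half they are stitched together from a 4-bit
// low part (0x0f0f) and the spare top 2 bits of the earlier entries (0xc0c0).
// Afterwards sc points at the two scales and m at the two minima relevant to
// this bq8_offset.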

static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
    const block_q5_K * bq5_K = (const block_q5_K *) vbq + kbx;

    int   vl[2];
    int   vh[2];
    int    u[2*QR5_K];
    float d8[QR5_K];

    const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2));
    const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
    const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4));

    vl[0] = ql[0];
    vl[1] = ql[4];

    vh[0] = qh[0] >> bq8_offset;
    vh[1] = qh[4] >> bq8_offset;

    const uint16_t * scales = (const uint16_t *)bq5_K->scales;
    uint16_t aux[2];
    const int j = bq8_offset/2;
    if (j < 2) {
        aux[0] = scales[j+0] & 0x3f3f;
        aux[1] = scales[j+2] & 0x3f3f;
    } else {
        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
    }
    const uint8_t * sc = (const uint8_t *)aux;
    const uint8_t * m  = sc + 2;

#pragma unroll
    for (int i = 0; i < QR5_K; ++i) {
        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
        d8[i] = __low2float(bq8i->ds);

        const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
        u[2*i+0] = q8[0];
        u[2*i+1] = q8[4];
    }

    return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8);
}

static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
    const block_q6_K * bq6_K = (const block_q6_K *) vbq + kbx;

    const int bq8_offset   = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4);
    const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8);
    const int vh_shift     = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4));

    const int vl = get_int_b2(bq6_K->ql, iqs);
    const int vh = get_int_b2(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift;

    const int8_t * scales = bq6_K->scales + scale_offset;

    int    u[QR6_K];
    float d8[QR6_K];

#pragma unroll
    for (int i = 0; i < QR6_K; ++i) {
        u[i]  = get_int_b4(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
        d8[i] = __low2float(bq8_1[bq8_offset + 2*i].ds);
    }

    return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
}

#define VDR_IQ2_XXS_Q8_1_MMVQ 2
#define VDR_IQ2_XXS_Q8_1_MMQ  2

static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq + kbx;

    const int q2 = get_int_b2(bq2->qs, iqs);
    const uint8_t * aux8 = (const uint8_t *) &q2;
    const uint32_t aux32 = get_int_b2(bq2->qs, iqs + 1);

    int sumi = 0;
#pragma unroll
    for (int k0 = 0; k0 < 8; k0 += 2) {
        const int * grid_pos = (const int *) (iq2xxs_grid + aux8[k0/2]);
        const int signs_packed = ksigns_iq2xs[(aux32 >> (7*k0/2)) & 0x7F];

        const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
        const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
        const int u0 = get_int_b4(bq8_1[iqs/2].qs, k0 + 0);
        sumi = ggml_cuda_dp4a(grid0, u0, sumi);

        const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
        const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
        const int u1 = get_int_b4(bq8_1[iqs/2].qs, k0 + 1);
        sumi = ggml_cuda_dp4a(grid1, u1, sumi);
    }

    const int ls = aux32 >> 28;
    sumi = (ls*sumi + sumi/2)/4;

    const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
    return d * sumi;
}
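
// Two tricks used above and in the other iq kernels: (1) signs are applied
// without multiplies -- __vcmpne4 expands each packed sign bit into a 0x00/0xFF
// byte mask, and (grid ^ mask) - mask (via __vsub4) negates exactly the masked
// int8 lanes by two's complement; (2) the 4-bit scale ls carries an implicit
// +0.5, so sumi * (ls + 0.5) / 4 is evaluated in integer arithmetic as
// (ls*sumi + sumi/2)/4.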

#define VDR_IQ2_XS_Q8_1_MMVQ 2
#define VDR_IQ2_XS_Q8_1_MMQ  2

static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
    const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq + kbx;

    const int2 q2_packed = make_int2(get_int_b2(bq2->qs, iqs + 0), get_int_b2(bq2->qs, iqs + 1));
    const uint16_t * q2 = (const uint16_t *) &q2_packed;

    const int ls0 = bq2->scales[iqs/2] & 0x0F;
    const int ls1 = bq2->scales[iqs/2] >> 4;

    int sumi0 = 0;
    int sumi1 = 0;
#pragma unroll
    for (int l0 = 0; l0 < 8; l0 += 2) {
        const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l0/2] & 0x000001FF));
        const uint32_t * signs    = (const uint32_t *)(ksigns64   + (q2[l0/2] >> 9));

        const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
        const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);

        const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
        const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);

        if (l0 < 4) {
            sumi0 = ggml_cuda_dp4a(grid_l, u0, sumi0);
            sumi0 = ggml_cuda_dp4a(grid_h, u1, sumi0);
        } else {
            sumi1 = ggml_cuda_dp4a(grid_l, u0, sumi1);
            sumi1 = ggml_cuda_dp4a(grid_h, u1, sumi1);
        }
    }
    const int sumi = (sumi0*ls0 + sumi1*ls1 + (sumi0 + sumi1)/2)/4;

    const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
    return d * sumi;
}

#define VDR_IQ2_S_Q8_1_MMVQ 2
#define VDR_IQ2_S_Q8_1_MMQ  2

static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
    const block_iq2_s * bq2 = (const block_iq2_s *) vbq + kbx;

    const int qs_packed = get_int_b2(bq2->qs, iqs/2);
    const uint8_t * qs = (const uint8_t *) &qs_packed;

    const int qh = bq2->qh[iqs/2];

    const int signs_packed_32 = get_int_b2(bq2->qs, QK_K/32 + iqs/2);
    const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32;

    const int ls0 = bq2->scales[iqs/2] & 0x0F;
    const int ls1 = bq2->scales[iqs/2] >> 4;

    int sumi0 = 0;
    int sumi1 = 0;
#pragma unroll
    for (int l0 = 0; l0 < 8; l0 += 2) {
        const int * grid_pos = (const int *)(iq2s_grid + (qs[l0/2] | ((qh << (8-l0)) & 0x300)));

        const int signs0 = __vcmpne4(((signs_packed_8[l0/2] & 0x03) << 7) | ((signs_packed_8[l0/2] & 0x0C) << 21), 0x00000000);
        const int signs1 = __vcmpne4(((signs_packed_8[l0/2] & 0x30) << 3) | ((signs_packed_8[l0/2] & 0xC0) << 17), 0x00000000);

        const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
        const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);

        const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
        const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);

        if (l0 < 4) {
            sumi0 = ggml_cuda_dp4a(grid_l, u0, sumi0);
            sumi0 = ggml_cuda_dp4a(grid_h, u1, sumi0);
        } else {
            sumi1 = ggml_cuda_dp4a(grid_l, u0, sumi1);
            sumi1 = ggml_cuda_dp4a(grid_h, u1, sumi1);
        }
    }
    const int sumi = (sumi0*ls0 + sumi1*ls1 + (sumi0 + sumi1)/2)/4;

    const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
    return d * sumi;
}

#define VDR_IQ3_XXS_Q8_1_MMVQ 2
#define VDR_IQ3_XXS_Q8_1_MMQ  2

static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
    const block_iq3_xxs * bq3 = (const block_iq3_xxs *) vbq + kbx;

    const int2 q3_packed = make_int2(get_int_b2(bq3->qs, iqs), get_int_b2(bq3->qs, iqs+1));
    const uint8_t * q3 = (const uint8_t *) &q3_packed;
    const uint32_t aux32 = get_int_b2(bq3->qs, QK_K/16 + iqs/2);

    int sumi = 0;
#pragma unroll
    for (int l0 = 0; l0 < 8; l0 += 2) {
        const int2 grid_pos = make_int2(iq3xxs_grid[q3[l0 + 0]], iq3xxs_grid[q3[l0 + 1]]);

        const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l0/2)) & 0x7F));

        const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
        const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);

        const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
        const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);

        sumi = ggml_cuda_dp4a(grid_l, u0, sumi);
        sumi = ggml_cuda_dp4a(grid_h, u1, sumi);
    }

    const int ls = aux32 >> 28;
    sumi = (ls*sumi + sumi/2)/2;

    const float d = __half2float(bq3->d) * __low2float(bq8_1[iqs/2].ds);
    return d * sumi;
}

#define VDR_IQ3_S_Q8_1_MMVQ 2
#define VDR_IQ3_S_Q8_1_MMQ  2

// TODO: don't use lookup table for signs
static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
    const block_iq3_s * bq3 = (const block_iq3_s *) vbq + kbx;

    const int2 qs_packed = make_int2(get_int_b2(bq3->qs, iqs + 0), get_int_b2(bq3->qs, iqs + 1));
    const uint8_t * qs = (const uint8_t *) &qs_packed;

    const int qh = bq3->qh[iqs/2];

    const int signs_packed_32 = get_int_b2(bq3->signs, iqs/2);
    const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32;

    int sumi = 0;
#pragma unroll
    for (int l0 = 0; l0 < 8; l0 += 2) {
        const int2 grid_pos = make_int2(
            iq3s_grid[qs[l0 + 0] | ((qh << (8 - l0)) & 0x100)],
            iq3s_grid[qs[l0 + 1] | ((qh << (7 - l0)) & 0x100)]);

        const int signs0 = __vcmpne4(((signs_packed_8[l0/2] & 0x03) << 7) | ((signs_packed_8[l0/2] & 0x0C) << 21), 0x00000000);
        const int signs1 = __vcmpne4(((signs_packed_8[l0/2] & 0x30) << 3) | ((signs_packed_8[l0/2] & 0xC0) << 17), 0x00000000);

        const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
        const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);

        const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
        const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);

        sumi = ggml_cuda_dp4a(grid_l, u0, sumi);
        sumi = ggml_cuda_dp4a(grid_h, u1, sumi);
    }

    sumi *= 1 + 2*((bq3->scales[iqs/4] >> ((iqs << 1) & 0x04)) & 0x0F);

    const float d = __half2float(bq3->d) * __low2float(bq8_1[iqs/2].ds);
    return d * sumi;
}

#define VDR_IQ1_S_Q8_1_MMVQ 1
#define VDR_IQ1_S_Q8_1_MMQ  1

static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
    const block_iq1_s * bq1 = (const block_iq1_s *) vbq + kbx;

    const int qs_packed = get_int_b2(bq1->qs, iqs);
    const uint8_t * qs = (const uint8_t *) &qs_packed;

    const int qh = bq1->qh[iqs];

    int sumi = 0;
#pragma unroll
    for (int l0 = 0; l0 < 8; l0 += 2) {
        const int grid = iq1s_grid_gpu[qs[l0/2] | (((qh >> 3*(l0/2)) & 0x07) << 8)];

        const int grid0 = (grid >> 0) & 0x0F0F0F0F;
        const int grid1 = (grid >> 4) & 0x0F0F0F0F;

        const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);
        const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);

        sumi = ggml_cuda_dp4a(grid0, u0, sumi);
        sumi = ggml_cuda_dp4a(grid1, u1, sumi);
    }

    const float  d1q   = __half2float(bq1->d) * (((qh >> 11) & 0x0E) + 1);
    const float  delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
    const float2 ds    = __half22float2(bq8_1[iqs].ds);
    return d1q * (ds.x*sumi + ds.y*delta);
}
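
// iq1_s decode above: each group of 8 weights is a grid entry indexed by one qs
// byte plus 3 bits from qh; bits 12..14 of qh hold a 3-bit scale s, which
// ((qh >> 11) & 0x0E) + 1 turns into 2*s + 1; and the top bit of qh selects the
// sign of the uniform IQ1S_DELTA shift (delta = -1 +/- IQ1S_DELTA). Because the
// shift is the same for every weight, ds.y (d8 times the sum of the q8_1
// quants) folds it into the result without another pass over u.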

#define VDR_IQ1_M_Q8_1_MMVQ 1
#define VDR_IQ1_M_Q8_1_MMQ  1

static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
    const block_iq1_m * bq1 = (const block_iq1_m *) vbq + kbx;

    const int qs_packed = get_int_b4(bq1->qs, iqs);
    const uint8_t * qs = (const uint8_t *) &qs_packed;

    int   sumi[2] = {0};
    float sumf[2] = {0.0f};
#pragma unroll
    for (int l0 = 0; l0 < 8; l0 += 2) {
        const int qhl = bq1->qh[2*iqs + l0/4] >> (4 * ((l0/2) % 2));

        const int grid = iq1s_grid_gpu[qs[l0/2] | ((qhl & 0x07) << 8)];

        const int grid0 = (grid >> 0) & 0x0F0F0F0F;
        const int grid1 = (grid >> 4) & 0x0F0F0F0F;

        const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);
        const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);

        sumi[l0/4] = ggml_cuda_dp4a(grid0, u0, sumi[l0/4]);
        sumi[l0/4] = ggml_cuda_dp4a(grid1, u1, sumi[l0/4]);

        const float delta = -1.0f + IQ1M_DELTA - (qhl & 0x08) * (2.0f*IQ1M_DELTA/0x08);
        int sumy = 0;
        sumy = ggml_cuda_dp4a(u0, 0x01010101, sumy);
        sumy = ggml_cuda_dp4a(u1, 0x01010101, sumy);
        sumf[l0/4] += delta*sumy;
    }

    const uint16_t * sc = (const uint16_t *) bq1->scales;

    iq1m_scale_t scale;
    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00F0) | ((sc[2] >> 4) & 0x0F00) | (sc[3] & 0xF000);
    const float d = __half2float(scale.f16) * __low2float(bq8_1[iqs].ds);

    const int tmp = sc[iqs/2] >> (6*(iqs%2));
    const int sc0 = 2*((tmp >> 0) & 0x07) + 1;
    const int sc1 = 2*((tmp >> 3) & 0x07) + 1;
    return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
}

static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
    const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
    const int8_t * q0_8 = (const int8_t *) &q0_32;
    const char4 val0_8 = make_char4(
        kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]);

    const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
    const int8_t * q1_8 = (const int8_t *) &q1_32;
    const char4 val1_8 = make_char4(
        kvalues_iq4nl[q1_8[0]], kvalues_iq4nl[q1_8[1]], kvalues_iq4nl[q1_8[2]], kvalues_iq4nl[q1_8[3]]);

    return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
}
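
// get_int_from_table_16 maps the eight packed 4-bit indices of one int through
// the non-linear iq4_nl codebook kvalues_iq4nl (16 signed int8 levels) and
// returns two ints holding 4 decoded values each, so the iq4 dot products below
// can use ggml_cuda_dp4a exactly like the linear formats.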

#define VDR_IQ4_NL_Q8_1_MMVQ 2
#define VDR_IQ4_NL_Q8_1_MMQ  4

static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
    const block_iq4_nl * bq4 = (const block_iq4_nl *) vbq + kbx;

    const int * q8 = (const int *) bq8_1->qs + iqs;

    int sumi = 0;
#pragma unroll
    for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
        const int aux_q4 = get_int_b2(bq4->qs, iqs + l);
        const int2 v = get_int_from_table_16(aux_q4);

        sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
        sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
    }

    const float d = __half2float(bq4->d) * __low2float(bq8_1->ds);
    return d * sumi;
}

#define VDR_IQ4_XS_Q8_1_MMVQ 4
#define VDR_IQ4_XS_Q8_1_MMQ  4

static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
    const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq + kbx;

    int sumi = 0;
#pragma unroll
    for (int j = 0; j < 4; ++j) {
        const int aux_q4 = get_int_b4(bq4->qs, iqs + j);
        const int2 v = get_int_from_table_16(aux_q4);

        const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0);
        const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4);

        sumi = ggml_cuda_dp4a(v.x, u0, sumi);
        sumi = ggml_cuda_dp4a(v.y, u1, sumi);
    }

    const int ls = ((bq4->scales_l[iqs/8] >> (iqs & 0x04)) & 0x0F) | (((bq4->scales_h >> (iqs/2)) & 0x03) << 4);
    sumi *= ls - 32;

    const float d = __half2float(bq4->d) * __low2float(bq8_1[iqs/4].ds);
    return d * sumi;
}