瀏覽代碼

update llama.cpp

Michael Yang 1 年之前
父節點
當前提交
18ffeeec45
共有 16 個文件被更改,包括 2037 次插入917 次删除
  1. 354 108
      llama/ggml-cuda.cu
  2. 1 1
      llama/ggml-cuda.h
  3. 8 1
      llama/ggml-metal.h
  4. 214 82
      llama/ggml-metal.m
  5. 289 210
      llama/ggml-metal.metal
  6. 1 1
      llama/ggml-mpi.c
  7. 1 1
      llama/ggml-mpi.h
  8. 1 1
      llama/ggml-opencl.cpp
  9. 1 1
      llama/ggml-opencl.h
  10. 173 354
      llama/ggml.c
  11. 87 21
      llama/ggml.h
  12. 328 4
      llama/k_quants.c
  13. 1 1
      llama/k_quants.h
  14. 1 1
      llama/llama-util.h
  15. 513 120
      llama/llama.cpp
  16. 64 10
      llama/llama.h

+ 354 - 108
llama/ggml-cuda.cu

@@ -1,5 +1,5 @@
 /**
 /**
- * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
+ * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b
  *
  *
  * MIT License
  * MIT License
  *
  *
@@ -246,7 +246,7 @@ typedef struct {
 static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
 static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
 
 
 #define WARP_SIZE 32
 #define WARP_SIZE 32
-#define MATRIX_ROW_PADDING 256 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
+#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
 
 #define CUDA_ADD_BLOCK_SIZE 256
 #define CUDA_ADD_BLOCK_SIZE 256
 #define CUDA_MUL_BLOCK_SIZE 256
 #define CUDA_MUL_BLOCK_SIZE 256
@@ -358,12 +358,10 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     }
     }
 }
 }
 
 
-static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
+static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
     const int tid = threadIdx.x;
 
 
-    const float eps = 1e-6f;
-
     float tmp = 0.0f; // partial sum for thread in warp
     float tmp = 0.0f; // partial sum for thread in warp
 
 
     for (int col = tid; col < ncols; col += WARP_SIZE) {
     for (int col = tid; col < ncols; col += WARP_SIZE) {
@@ -961,12 +959,18 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
     uint16_t aux[4];
     uint16_t aux[4];
     const uint8_t * sc = (const uint8_t *)aux;
     const uint8_t * sc = (const uint8_t *)aux;
 
 
+#if K_QUANTS_PER_ITERATION == 2
+    uint32_t q32[4];
+    const uint8_t * q4 = (const uint8_t *)q32;
+#else
+    uint16_t q16[4];
+    const uint8_t * q4 = (const uint8_t *)q16;
+#endif
+
     float tmp = 0; // partial sum for thread in warp
     float tmp = 0; // partial sum for thread in warp
 
 
     for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
     for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
 
 
-        const uint8_t * q1 = x[i].qs + q_offset;
-        const uint8_t * q2 = q1 + 64;
         const float   * y1 = yy + i*QK_K + y_offset;
         const float   * y1 = yy + i*QK_K + y_offset;
         const float   * y2 = y1 + 128;
         const float   * y2 = y1 + 128;
 
 
@@ -979,14 +983,41 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
         aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
         aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
         aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
         aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
 
 
+#if K_QUANTS_PER_ITERATION == 2
+        const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
+        const uint32_t * q2 = q1 + 16;
+
+        q32[0] = q1[0] & 0x0f0f0f0f;
+        q32[1] = q1[0] & 0xf0f0f0f0;
+        q32[2] = q2[0] & 0x0f0f0f0f;
+        q32[3] = q2[0] & 0xf0f0f0f0;
+
         float4 s = {0.f, 0.f, 0.f, 0.f};
         float4 s = {0.f, 0.f, 0.f, 0.f};
         float smin = 0;
         float smin = 0;
-        for (int l = 0; l < n; ++l) {
-            s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4);
-            s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4);
+        for (int l = 0; l < 4; ++l) {
+            s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4];
+            s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12];
             smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
             smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
         }
         }
-        tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
+        tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
+#else
+        const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
+        const uint16_t * q2 = q1 + 32;
+
+        q16[0] = q1[0] & 0x0f0f;
+        q16[1] = q1[0] & 0xf0f0;
+        q16[2] = q2[0] & 0x0f0f;
+        q16[3] = q2[0] & 0xf0f0;
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < 2; ++l) {
+            s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
+            s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
+            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+        }
+        tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
+#endif
 
 
     }
     }
 #else
 #else
@@ -1066,10 +1097,12 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
     uint16_t aux[4];
     uint16_t aux[4];
     const uint8_t * sc = (const uint8_t *)aux;
     const uint8_t * sc = (const uint8_t *)aux;
 
 
+    uint16_t q16[8];
+    const uint8_t * q4 = (const uint8_t *)q16;
+
     for (int i = ix; i < num_blocks_per_row; i += 2) {
     for (int i = ix; i < num_blocks_per_row; i += 2) {
 
 
         const uint8_t * ql1 = x[i].qs + q_offset;
         const uint8_t * ql1 = x[i].qs + q_offset;
-        const uint8_t * ql2 = ql1 + 64;
         const uint8_t * qh  = x[i].qh + l0;
         const uint8_t * qh  = x[i].qh + l0;
         const float   * y1  = yy + i*QK_K + y_offset;
         const float   * y1  = yy + i*QK_K + y_offset;
         const float   * y2  = y1 + 128;
         const float   * y2  = y1 + 128;
@@ -1085,15 +1118,25 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
 
 
         float4 sum = {0.f, 0.f, 0.f, 0.f};
         float4 sum = {0.f, 0.f, 0.f, 0.f};
         float smin = 0;
         float smin = 0;
+        const uint16_t * q1 = (const uint16_t *)ql1;
+        const uint16_t * q2 = q1 + 32;
+        q16[0] = q1[0] & 0x0f0f;
+        q16[1] = q1[8] & 0x0f0f;
+        q16[2] = (q1[0] >> 4) & 0x0f0f;
+        q16[3] = (q1[8] >> 4) & 0x0f0f;
+        q16[4] = q2[0] & 0x0f0f;
+        q16[5] = q2[8] & 0x0f0f;
+        q16[6] = (q2[0] >> 4) & 0x0f0f;
+        q16[7] = (q2[8] >> 4) & 0x0f0f;
         for (int l = 0; l < n; ++l) {
         for (int l = 0; l < n; ++l) {
-            sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
-                   + y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0));
-            sum.y += y1[l+32] * ((ql1[l+ 0] >>  4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
-                   + y1[l+48] * ((ql1[l+16] >>  4) + (qh[l+16] & (hm1 << 1) ? 16 : 0));
-            sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
-                   + y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0));
-            sum.w += y2[l+32] * ((ql2[l+ 0] >>  4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
-                   + y2[l+48] * ((ql2[l+16] >>  4) + (qh[l+16] & (hm2 << 1) ? 16 : 0));
+            sum.x += y1[l+ 0] * (q4[l +0] + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
+                   + y1[l+16] * (q4[l +2] + (qh[l+16] & (hm1 << 0) ? 16 : 0));
+            sum.y += y1[l+32] * (q4[l +4] + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
+                   + y1[l+48] * (q4[l +6] + (qh[l+16] & (hm1 << 1) ? 16 : 0));
+            sum.z += y2[l+ 0] * (q4[l +8] + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
+                   + y2[l+16] * (q4[l+10] + (qh[l+16] & (hm2 << 0) ? 16 : 0));
+            sum.w += y2[l+32] * (q4[l+12] + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
+                   + y2[l+48] * (q4[l+14] + (qh[l+16] & (hm2 << 1) ? 16 : 0));
             smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
             smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
                   + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
                   + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
         }
         }
@@ -1547,33 +1590,95 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q4_K * bq4_K = (const block_q4_K *) vbq;
     const block_q4_K * bq4_K = (const block_q4_K *) vbq;
 
 
-    const int bq8_offset = QR4_K * (iqs / QI8_1);
-
     float sumf_d = 0.0f;
     float sumf_d = 0.0f;
     float sumf_m = 0.0f;
     float sumf_m = 0.0f;
 
 
+#ifndef GGML_QKK_64
+
+    // iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6
+    const int bq8_offset = QR4_K * (iqs / (QI8_1/2));
+
     const float    d = bq4_K->d;
     const float    d = bq4_K->d;
     const float dmin = bq4_K->dmin;
     const float dmin = bq4_K->dmin;
 
 
-    const int v = *((int *) &bq4_K->qs[sizeof(int) * iqs]);
+    // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
+    // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
+    // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
+    // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
+
+    const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * (iqs%4));
+    const int v1 = q4[0];
+    const int v2 = q4[4];
+
+    const uint16_t * scales = (const uint16_t *)bq4_K->scales;
+    uint16_t aux[2];
+    const int j = bq8_offset/2;
+    if (j < 2) {
+        aux[0] = scales[j+0] & 0x3f3f;
+        aux[1] = scales[j+2] & 0x3f3f;
+    } else {
+        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+    }
+    const uint8_t * sc = (const uint8_t *)aux;
+    const uint8_t * m  = sc + 2;
 
 
     for (int i = 0; i < QR4_K; ++i) {
     for (int i = 0; i < QR4_K; ++i) {
-        const int isc = bq8_offset + i;
-
-        uint8_t sc, m;
-        get_scale_min_k4(isc, bq4_K->scales, sc, m);
 
 
         const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
         const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
-        const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
         const float d8i = bq8i->d;
         const float d8i = bq8i->d;
+        const int * q8 = (const int *)bq8i->qs + (iqs%4);
+        const int ui1 = q8[0];
+        const int ui2 = q8[4];
 
 
-        const int vi = (v >> (4*i)) & 0x0F0F0F0F;
+        const int vi1 = (v1 >> (4*i)) & 0x0F0F0F0F;
+        const int vi2 = (v2 >> (4*i)) & 0x0F0F0F0F;
 
 
-        sumf_d += d8i * (__dp4a(vi,         ui, 0) * sc); // SIMD dot product
-        sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m);  // multiply constant part of q4_K with sum of q8_1 values
+        const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); // SIMD dot product
+        const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
+
+        sumf_d += d8i * (dot1 * sc[i]);
+        sumf_m += d8i * (dot2 * m[i]);  // multiply constant part of q4_K with sum of q8_1 values
     }
     }
 
 
     return d*sumf_d - dmin*sumf_m;
     return d*sumf_d - dmin*sumf_m;
+
+#else
+
+    uint16_t aux16[2];
+    const uint8_t * s = (const uint8_t *)aux16;
+
+    const uint16_t * a = (const uint16_t *)bq4_K->scales;
+    aux16[0] = a[0] & 0x0f0f;
+    aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+    const float dall = bq4_K->d[0];
+    const float dmin = bq4_K->d[1];
+
+    const float d8_1 = bq8_1[0].d;
+    const float d8_2 = bq8_1[1].d;
+
+    const int ui1 = *((const int *)bq8_1[0].qs + iqs);
+    const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4);
+    const int ui3 = *((const int *)bq8_1[1].qs + iqs);
+    const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4);
+
+    const int * q4 = (const int *)bq4_K->qs + iqs;
+    const int v1 = q4[0];
+    const int v2 = q4[4];
+
+    const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0));
+    const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0));
+    const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
+    const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0));
+
+    sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]);
+    sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]);
+
+    return dall * sumf_d - dmin * sumf_m;
+
+#endif
+
 #else
 #else
     return 0.0f; // only to satisfy the compiler
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@@ -1585,7 +1690,11 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q5_K * bq5_K = (const block_q5_K *) vbq;
     const block_q5_K * bq5_K = (const block_q5_K *) vbq;
 
 
-    const int bq8_offset = QR5_K * (iqs / QI8_1);
+#ifndef GGML_QKK_64
+
+    const int bq8_offset = QR5_K * (iqs / (QI8_1/2));
+    const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * (iqs%4));
+    const int * qh = (const int *)(bq5_K->qh + 4 * (iqs%4));
 
 
     float sumf_d = 0.0f;
     float sumf_d = 0.0f;
     float sumf_m = 0.0f;
     float sumf_m = 0.0f;
@@ -1593,31 +1702,87 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
     const float    d = bq5_K->d;
     const float    d = bq5_K->d;
     const float dmin = bq5_K->dmin;
     const float dmin = bq5_K->dmin;
 
 
-    const int vl = *((int *) &bq5_K->qs[sizeof(int) * iqs]);
+    const int vl1 = ql[0];
+    const int vl2 = ql[4];
 
 
-    const int vh = (*((int *) &bq5_K->qh[sizeof(int) * (iqs % (QI5_K/4))])) >> bq8_offset;
+    const int vh1 = qh[0] >> bq8_offset;
+    const int vh2 = qh[4] >> bq8_offset;
 
 
-    for (int i = 0; i < QR5_K; ++i) {
-        const int isc = bq8_offset + i;
+    const uint16_t * scales = (const uint16_t *)bq5_K->scales;
+    uint16_t aux[2];
+    const int j = bq8_offset/2;
+    if (j < 2) {
+        aux[0] = scales[j+0] & 0x3f3f;
+        aux[1] = scales[j+2] & 0x3f3f;
+    } else {
+        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+    }
+    const uint8_t * sc = (const uint8_t *)aux;
+    const uint8_t * m  = sc + 2;
 
 
-        uint8_t sc, m;
-        get_scale_min_k4(isc, bq5_K->scales, sc, m);
+    for (int i = 0; i < QR5_K; ++i) {
 
 
         const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
         const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
-        const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
         const float d8i = bq8i->d;
         const float d8i = bq8i->d;
+        const int * q8 = (const int *)bq8i->qs + (iqs%4);
+        const int ui1 = q8[0];
+        const int ui2 = q8[4];
 
 
-        const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
+        const int vil1 = (vl1 >> (4*i)) & 0x0F0F0F0F;
+        const int vil2 = (vl2 >> (4*i)) & 0x0F0F0F0F;
 
 
-        const int vih = ((vh >> i) << 4) & 0x10101010;
+        const int vih1 = ((vh1 >> i) << 4) & 0x10101010;
+        const int vih2 = ((vh2 >> i) << 4) & 0x10101010;
 
 
-        const int vi = vil | vih;
+        const int vi1 = vil1 | vih1;
+        const int vi2 = vil2 | vih2;
+
+        const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); // SIMD dot product
+        const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
+
+        sumf_d += d8i * (dot1 * sc[i]);
+        sumf_m += d8i * (dot2 * m[i]);
 
 
-        sumf_d += d8i * (__dp4a(vi,         ui, 0) * sc); // SIMD dot product
-        sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m);  // multiply constant part of q5_K with sum of q8_1 values
     }
     }
 
 
     return d*sumf_d - dmin*sumf_m;
     return d*sumf_d - dmin*sumf_m;
+
+#else
+
+    const int8_t * s = bq5_K->scales;
+
+    const float d = bq5_K->d;
+
+    const float d8_1 = bq8_1[0].d;
+    const float d8_2 = bq8_1[1].d;
+
+    const int ui1 = *((const int *)bq8_1[0].qs + iqs);
+    const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4);
+    const int ui3 = *((const int *)bq8_1[1].qs + iqs);
+    const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4);
+
+    const int * ql = (const int *)bq5_K->qs + iqs;
+    const int vl1 = ql[0];
+    const int vl2 = ql[4];
+
+    const int step = 4 * iqs; // 0, 4, 8, 12
+    const int im = step/8; // = 0 for iqs = 0, 1, = 1 for iqs = 2, 3
+    const int in = step%8; // 0, 4, 0, 4
+    const int vh = (*((const int *)(bq5_K->qh + in))) >> im;
+
+    const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f);
+    const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f);
+    const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f);
+    const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f);
+
+    const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1])
+                       + d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]);
+
+    return d * sumf_d;
+
+#endif
+
 #else
 #else
     return 0.0f; // only to satisfy the compiler
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@@ -1771,11 +1936,15 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
     }
     }
 }
 }
 
 
-static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
+static __global__ void mul_mat_p021_f16_f32(
+    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) {
+
     const half * x = (const half *) vx;
     const half * x = (const half *) vx;
 
 
     const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
     const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
     const int channel = blockDim.z*blockIdx.z + threadIdx.z;
     const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+    const int channel_x = channel / (nchannels_y / nchannels_x);
 
 
     const int nrows_y = ncols_x;
     const int nrows_y = ncols_x;
     const int nrows_dst = nrows_x;
     const int nrows_dst = nrows_x;
@@ -1791,7 +1960,7 @@ static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const
         }
         }
 
 
         // x is transposed and permuted
         // x is transposed and permuted
-        const int ix = row_x*nchannels_x*ncols_x + channel*ncols_x + col_x;
+        const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
         const float xi = __half2float(x[ix]);
         const float xi = __half2float(x[ix]);
 
 
         const int row_y = col_x;
         const int row_y = col_x;
@@ -1819,12 +1988,13 @@ static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const
 
 
 static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
 static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
-    const int row_stride_x, const int channel_stride_x) {
+    const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) {
 
 
     const half * x = (const half *) vx;
     const half * x = (const half *) vx;
 
 
     const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
     const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
     const int channel = blockDim.z*blockIdx.z + threadIdx.z;
     const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+    const int channel_x = channel / channel_x_divisor;
 
 
     const int nrows_y = ncols_x;
     const int nrows_y = ncols_x;
     const int nrows_dst = nrows_x;
     const int nrows_dst = nrows_x;
@@ -1841,7 +2011,7 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
             break;
             break;
         }
         }
 
 
-        const int ix = channel*channel_stride_x + row_x*row_stride_x + col_x;
+        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
         const float xi = __half2float(x[ix]);
         const float xi = __half2float(x[ix]);
 
 
         const int row_y = col_x;
         const int row_y = col_x;
@@ -2053,10 +2223,10 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
     norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
     norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
 }
 }
 
 
-static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     const dim3 block_dims(WARP_SIZE, 1, 1);
     const dim3 block_dims(WARP_SIZE, 1, 1);
-    rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+    rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
 }
 }
 
 
 static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) {
 static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) {
@@ -2285,7 +2455,10 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI4_K, block_q4_K, vec_dot_q4_K_q8_1>
+    // Note: we use QI4_K/2 instead of QI4_K to make the dot product template require 4 groups of quants to be processed per
+    //       kernel call instead of 2. This results in a better perfmance because the cost of computing the k-quant scales
+    //       is better amortized.
+    mul_mat_vec_q<QK_K, QI4_K/2, block_q4_K, vec_dot_q4_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 }
 
 
@@ -2294,7 +2467,10 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI5_K, block_q5_K, vec_dot_q5_K_q8_1>
+    // Note: we use QI5_K/2 instead of QI5_K to make the dot product template require 4 groups of quants to be processed per
+    //       kernel call instead of 2. This results in a better perfmance because the cost of computing the k-quant scales
+    //       is better amortized.
+    mul_mat_vec_q<QK_K, QI5_K/2, block_q5_K, vec_dot_q5_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 }
 
 
@@ -2350,20 +2526,23 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
     }
     }
 }
 }
 
 
-static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) {
-    const dim3 block_nums(1, nrows_x, nchannels_x);
+static void ggml_mul_mat_p021_f16_f32_cuda(
+    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
+    const int nchannels_x, const int nchannels_y, cudaStream_t stream) {
+
+    const dim3 block_nums(1, nrows_x, nchannels_y);
     const dim3 block_dims(WARP_SIZE, 1, 1);
     const dim3 block_dims(WARP_SIZE, 1, 1);
-    mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x);
+    mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y);
 }
 }
 
 
 static void ggml_mul_mat_vec_nc_f16_f32_cuda(
 static void ggml_mul_mat_vec_nc_f16_f32_cuda(
     const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
     const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
-    const int nchannels_x, const int channel_stride_x, cudaStream_t stream) {
+    const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) {
 
 
-    const dim3 block_nums(1, nrows_x, nchannels_x);
+    const dim3 block_nums(1, nrows_x, nchannels_y);
     const dim3 block_dims(WARP_SIZE, 1, 1);
     const dim3 block_dims(WARP_SIZE, 1, 1);
     mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
     mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
-        (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x);
+        (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x);
 }
 }
 
 
 static void ggml_cpy_f32_f32_cuda(
 static void ggml_cpy_f32_f32_cuda(
@@ -2449,20 +2628,53 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
     scoped_spin_lock lock(g_cuda_pool_lock);
     int id;
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
     CUDA_CHECK(cudaGetDevice(&id));
-
+#ifdef DEBUG_CUDA_MALLOC
+    int nnz = 0;
+    size_t max_size = 0, tot_size = 0;
+#endif
+    size_t best_diff = 1ull << 36;
+    int ibest = -1;
     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
         cuda_buffer& b = g_cuda_buffer_pool[id][i];
         cuda_buffer& b = g_cuda_buffer_pool[id][i];
-        if (b.size >= size && b.ptr != nullptr) {
-            void * ptr = b.ptr;
-            *actual_size = b.size;
-            b.ptr = nullptr;
-            b.size = 0;
-            return ptr;
+        if (b.ptr != nullptr) {
+#ifdef DEBUG_CUDA_MALLOC
+            ++nnz;
+            tot_size += b.size;
+            if (b.size > max_size) max_size = b.size;
+#endif
+            if (b.size >= size) {
+                size_t diff = b.size - size;
+                if (diff < best_diff) {
+                    best_diff = diff;
+                    ibest = i;
+                    if (!best_diff) {
+                        void * ptr = b.ptr;
+                        *actual_size = b.size;
+                        b.ptr = nullptr;
+                        b.size = 0;
+                        return ptr;
+                    }
+                }
+            }
         }
         }
     }
     }
+    if (ibest >= 0) {
+        cuda_buffer& b = g_cuda_buffer_pool[id][ibest];
+        void * ptr = b.ptr;
+        *actual_size = b.size;
+        b.ptr = nullptr;
+        b.size = 0;
+        return ptr;
+    }
+#ifdef DEBUG_CUDA_MALLOC
+    fprintf(stderr, "%s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n", __func__, nnz,
+            (uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024));
+#endif
     void * ptr;
     void * ptr;
-    CUDA_CHECK(cudaMalloc((void **) &ptr, size));
-    *actual_size = size;
+    size_t look_ahead_size = (size_t) (1.05 * size);
+    look_ahead_size = 256 * ((look_ahead_size + 255)/256);
+    CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
+    *actual_size = look_ahead_size;
     return ptr;
     return ptr;
 }
 }
 
 
@@ -2490,7 +2702,9 @@ static size_t g_scratch_offset = 0;
 
 
 static int g_device_count = -1;
 static int g_device_count = -1;
 static int g_main_device = 0;
 static int g_main_device = 0;
+#ifndef GGML_CUDA_FORCE_DMMV
 static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
 static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
+#endif
 static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
 static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
 
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
@@ -2513,7 +2727,9 @@ void ggml_init_cublas() {
             g_tensor_split[id] = total_vram;
             g_tensor_split[id] = total_vram;
             total_vram += prop.totalGlobalMem;
             total_vram += prop.totalGlobalMem;
 
 
+#ifndef GGML_CUDA_FORCE_DMMV
             g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
             g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
+#endif
         }
         }
         for (int id = 0; id < g_device_count; ++id) {
         for (int id = 0; id < g_device_count; ++id) {
             g_tensor_split[id] /= total_vram;
             g_tensor_split[id] /= total_vram;
@@ -2538,6 +2754,9 @@ void ggml_init_cublas() {
 }
 }
 
 
 void ggml_cuda_set_tensor_split(const float * tensor_split) {
 void ggml_cuda_set_tensor_split(const float * tensor_split) {
+    if (tensor_split == nullptr) {
+        return;
+    }
     bool all_zero = true;
     bool all_zero = true;
     for (int i = 0; i < g_device_count; ++i) {
     for (int i = 0; i < g_device_count; ++i) {
         if (tensor_split[i] != 0.0f) {
         if (tensor_split[i] != 0.0f) {
@@ -2678,6 +2897,7 @@ inline void ggml_cuda_op_mul(
     (void) dst;
     (void) dst;
     (void) src0_ddq_i;
     (void) src0_ddq_i;
     (void) i02;
     (void) i02;
+    (void) i1;
 }
 }
 
 
 inline void ggml_cuda_op_gelu(
 inline void ggml_cuda_op_gelu(
@@ -2757,8 +2977,11 @@ inline void ggml_cuda_op_rms_norm(
     const int64_t ne00 = src0->ne[0];
     const int64_t ne00 = src0->ne[0];
     const int64_t i01_diff = i01_high - i01_low;
     const int64_t i01_diff = i01_high - i01_low;
 
 
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
     // compute
     // compute
-    rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+    rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, eps, cudaStream_main);
 
 
     (void) src1;
     (void) src1;
     (void) dst;
     (void) dst;
@@ -2805,8 +3028,8 @@ inline void ggml_cuda_op_mul_mat_vec(
 #endif
 #endif
 
 
     if (use_mul_mat_vec_q) {
     if (use_mul_mat_vec_q) {
-        int64_t padded_row_size = ne00 + MATRIX_ROW_PADDING - 1;
-        padded_row_size -= padded_row_size % MATRIX_ROW_PADDING;
+        const int64_t padded_row_size = ne00 % MATRIX_ROW_PADDING == 0 ?
+            ne00 : ne00 - ne00 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
         size_t as;
         size_t as;
         void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*sizeof(block_q8_1)/QK8_1, &as);
         void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*sizeof(block_q8_1)/QK8_1, &as);
         quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, padded_row_size, cudaStream_main);
         quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, padded_row_size, cudaStream_main);
@@ -2973,13 +3196,18 @@ inline void ggml_cuda_op_rope(
     const int64_t ne00 = src0->ne[0];
     const int64_t ne00 = src0->ne[0];
     const int64_t i01_diff = i01_high - i01_low;
     const int64_t i01_diff = i01_high - i01_low;
 
 
-    const int n_past = ((int32_t *) src1->data)[0];
-    const int n_dims = ((int32_t *) src1->data)[1];
-    const int mode   = ((int32_t *) src1->data)[2];
-    const int n_ctx  = ((int32_t *) src1->data)[3];
+    const int n_past = ((int32_t *) dst->op_params)[0];
+    const int n_dims = ((int32_t *) dst->op_params)[1];
+    const int mode   = ((int32_t *) dst->op_params)[2];
+    const int n_ctx  = ((int32_t *) dst->op_params)[3];
+    // RoPE alteration for extended context
 
 
-    const float theta_scale = powf(10000.0, -2.0f/n_dims);
-    const float p = ((mode & 1) == 0 ? n_past + i02 : i02);
+    float freq_base, freq_scale;
+    memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
+    memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale;
 
 
     bool is_glm = mode & 4;
     bool is_glm = mode & 4;
 
 
@@ -2992,6 +3220,7 @@ inline void ggml_cuda_op_rope(
         rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
         rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
     }
     }
 
 
+    (void) src1;
     (void) dst;
     (void) dst;
     (void) src0_ddq_i;
     (void) src0_ddq_i;
     (void) src1_ddf_i;
     (void) src1_ddf_i;
@@ -3010,11 +3239,12 @@ inline void ggml_cuda_op_diag_mask_inf(
     const int64_t ne01 = src0->ne[1];
     const int64_t ne01 = src0->ne[1];
     const int64_t i01_diff = i01_high - i01_low;
     const int64_t i01_diff = i01_high - i01_low;
 
 
-    const int n_past = ((int32_t *) src1->data)[0];
+    const int n_past = ((int32_t *) dst->op_params)[0];
 
 
     // compute
     // compute
     diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
     diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
 
 
+    (void) src1;
     (void) dst;
     (void) dst;
     (void) src0_ddq_i;
     (void) src0_ddq_i;
     (void) src1_ddf_i;
     (void) src1_ddf_i;
@@ -3082,6 +3312,9 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     const int64_t ne11 = use_src1 ? src1->ne[1] : 1;
     const int64_t ne11 = use_src1 ? src1->ne[1] : 1;
     const int64_t ne12 = use_src1 ? src1->ne[2] : 1;
     const int64_t ne12 = use_src1 ? src1->ne[2] : 1;
     const int64_t ne13 = use_src1 ? src1->ne[3] : 1;
     const int64_t ne13 = use_src1 ? src1->ne[3] : 1;
+    const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
+
+    GGML_ASSERT(ne03 == ne13);
 
 
     const int64_t ne0 = dst->ne[0];
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];
     const int64_t ne1 = dst->ne[1];
@@ -3093,12 +3326,19 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
     GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
 
 
     // strides for iteration over dims 3 and 2
     // strides for iteration over dims 3 and 2
-    const int64_t num_iters = flatten_rows ? 1 : ne02 * ne03;
-    const int64_t stride_mod = flatten_rows ? ne02 * ne03 : 1;
+    const int64_t num_iters_0 = ne02 >= ne12 ? ne02*ne03 : ne12*ne13;
+    const int64_t num_iters = flatten_rows ? 1 : num_iters_0;
+    const int64_t stride_mod = flatten_rows ? num_iters_0 : 1;
     const int64_t src0_stride = ne00 * ne01 * stride_mod;
     const int64_t src0_stride = ne00 * ne01 * stride_mod;
     const int64_t src1_stride = ne10 * ne11 * stride_mod;
     const int64_t src1_stride = ne10 * ne11 * stride_mod;
     const int64_t dst_stride = ne0 * ne1 * stride_mod;
     const int64_t dst_stride = ne0 * ne1 * stride_mod;
 
 
+    const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;
+    const int64_t i03_max = flatten_rows ? 1 : ne03;
+    const int64_t i02_max = flatten_rows ? 1 : (ne02 >= ne12 ? ne02 : ne12);
+    const int64_t i02_divisor = ne02 >= ne12 ? 1 : ne12 / ne02;
+    GGML_ASSERT(!(flatten_rows && ne02 < ne12));
+
     const size_t src0_ts = ggml_type_size(src0->type);
     const size_t src0_ts = ggml_type_size(src0->type);
     const size_t src0_bs = ggml_blck_size(src0->type);
     const size_t src0_bs = ggml_blck_size(src0->type);
 
 
@@ -3115,6 +3355,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
         dst->op == GGML_OP_SCALE || dst->op == GGML_OP_DIAG_MASK_INF || dst->op == GGML_OP_ROPE);
         dst->op == GGML_OP_SCALE || dst->op == GGML_OP_DIAG_MASK_INF || dst->op == GGML_OP_ROPE);
 
 
     const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
     const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
+    GGML_ASSERT(!(split && ne02 < ne12));
 
 
     const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
     const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
 
 
@@ -3151,7 +3392,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
             row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1];
             row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1];
         } else {
         } else {
             row_low = 0;
             row_low = 0;
-            row_high = nrows0;
+            row_high = nrows0*i02_divisor;
         }
         }
         if (row_low == row_high) {
         if (row_low == row_high) {
             continue;
             continue;
@@ -3199,16 +3440,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
             dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]);
             dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]);
         }
         }
 
 
-        const int64_t i03_max = flatten_rows ? 1 : ne03;
-        const int64_t i02_max = flatten_rows ? 1 : ne02;
-        const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;
-
         for (int64_t i03 = 0; i03 < i03_max; i03++) {
         for (int64_t i03 = 0; i03 < i03_max; i03++) {
             const int64_t i13 = i03 % ne13;
             const int64_t i13 = i03 % ne13;
             for (int64_t i02 = 0; i02 < i02_max; i02++) {
             for (int64_t i02 = 0; i02 < i02_max; i02++) {
                 const int64_t i12 = i02 % ne12;
                 const int64_t i12 = i02 % ne12;
 
 
-                const int64_t i0 = i03*ne02 + i02;
+                const int64_t i0 = i03*i02_max + i02;
 
 
                 // i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs
                 // i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs
                 const int64_t i0_offset_low = row_low/rows_per_iter;
                 const int64_t i0_offset_low = row_low/rows_per_iter;
@@ -3242,10 +3479,10 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                 const int64_t i11 = i13*ne12 + i12;
                 const int64_t i11 = i13*ne12 + i12;
 
 
                 // for split tensors the data begins at i0 == i0_offset_low
                 // for split tensors the data begins at i0 == i0_offset_low
-                char  * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
-                float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
+                char  * src0_ddq_i = src0_ddq[id] + (i0/i02_divisor - i0_offset_low)*src0_stride*src0_ts/src0_bs;
+                float * src0_ddf_i = src0_ddf[id] + (i0/i02_divisor - i0_offset_low)*src0_stride;
                 float * src1_ddf_i = src1_ddf[id] + i11*src1_stride;
                 float * src1_ddf_i = src1_ddf[id] + i11*src1_stride;
-                float * dst_ddf_i  =  dst_ddf[id] + (i0 - i0_offset_low)*dst_stride;
+                float * dst_ddf_i  =  dst_ddf[id] + (i0             - i0_offset_low)*dst_stride;
 
 
                 // for split tensors the data pointer needs to be rounded down
                 // for split tensors the data pointer needs to be rounded down
                 // to the bin edge for i03, i02 bins beyond the first
                 // to the bin edge for i03, i02 bins beyond the first
@@ -3284,11 +3521,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                     }
                     }
                 }
                 }
 
 
-                if (!src0_on_device || !src0_is_contiguous) {
+                if ((!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) {
                     if (src0_is_f32) {
                     if (src0_is_f32) {
-                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main));
                     } else {
                     } else {
-                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main));
                     }
                     }
                 }
                 }
 
 
@@ -3442,6 +3679,8 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
     const int64_t ne01 = src0->ne[1];
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
     const int64_t ne02 = src0->ne[2];
 
 
+    const int64_t ne12 = src1->ne[2];
+
     CUDA_CHECK(cudaSetDevice(g_main_device));
     CUDA_CHECK(cudaSetDevice(g_main_device));
     cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
     cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
 
 
@@ -3454,7 +3693,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
     struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
     struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
     float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
     float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
 
 
-    ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main);
+    ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, cudaStream_main);
 }
 }
 
 
 void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
 void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
@@ -3468,6 +3707,8 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     const int64_t ne01 = src0->ne[1];
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
     const int64_t ne02 = src0->ne[2];
 
 
+    const int64_t ne12 = src1->ne[2];
+
     const int64_t nb01 = src0->nb[1];
     const int64_t nb01 = src0->nb[1];
     const int64_t nb02 = src0->nb[2];
     const int64_t nb02 = src0->nb[2];
 
 
@@ -3486,7 +3727,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     const int row_stride_x = nb01 / sizeof(half);
     const int row_stride_x = nb01 / sizeof(half);
     const int channel_stride_x = nb02 / sizeof(half);
     const int channel_stride_x = nb02 / sizeof(half);
 
 
-    ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
+    ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, cudaStream_main);
 }
 }
 
 
 void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -3627,7 +3868,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
         size_t size = ggml_nbytes_split(tensor, nrows_split);
         size_t size = ggml_nbytes_split(tensor, nrows_split);
         const size_t original_size = size;
         const size_t original_size = size;
 
 
-        // pad last row to a multiple of 256 elements to avoid out-of-bounds memory accesses
+        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
         if (ne0 % MATRIX_ROW_PADDING != 0) {
         if (ne0 % MATRIX_ROW_PADDING != 0) {
             size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING)
             size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING)
                 * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
                 * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
@@ -3643,7 +3884,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
         }
         }
 
 
 
 
-        CUDA_CHECK(cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice));
+        CUDA_CHECK(cudaMemcpy(buf, buf_host, original_size, cudaMemcpyHostToDevice));
 
 
         extra->data_device[id] = buf;
         extra->data_device[id] = buf;
 
 
@@ -3723,7 +3964,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
         char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
         char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
         size_t offset = 0;
         size_t offset = 0;
         if (tensor->op == GGML_OP_VIEW) {
         if (tensor->op == GGML_OP_VIEW) {
-            memcpy(&offset, tensor->src[2]->data, sizeof(size_t));
+            memcpy(&offset, tensor->op_params, sizeof(size_t));
         }
         }
         extra = ggml_cuda_alloc_temp_tensor_extra();
         extra = ggml_cuda_alloc_temp_tensor_extra();
         extra->data_device[g_main_device] = src0_ddc + offset;
         extra->data_device[g_main_device] = src0_ddc + offset;
@@ -3825,18 +4066,23 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             }
             }
             func = ggml_cuda_mul;
             func = ggml_cuda_mul;
             break;
             break;
-        case GGML_OP_GELU:
-            if (!any_on_device) {
-                return false;
-            }
-            func = ggml_cuda_gelu;
-            break;
-        case GGML_OP_SILU:
-            if (!any_on_device) {
-                return false;
-            }
-            func = ggml_cuda_silu;
-            break;
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(tensor)) {
+                case GGML_UNARY_OP_GELU:
+                    if (!any_on_device) {
+                        return false;
+                    }
+                    func = ggml_cuda_gelu;
+                    break;
+                case GGML_UNARY_OP_SILU:
+                    if (!any_on_device) {
+                        return false;
+                    }
+                    func = ggml_cuda_silu;
+                    break;
+                default:
+                    return false;
+            } break;
         case GGML_OP_NORM:
         case GGML_OP_NORM:
             if (!any_on_device) {
             if (!any_on_device) {
                 return false;
                 return false;

+ 1 - 1
llama/ggml-cuda.h

@@ -1,5 +1,5 @@
 /**
 /**
- * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
+ * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b
  *
  *
  * MIT License
  * MIT License
  *
  *

+ 8 - 1
llama/ggml-metal.h

@@ -1,7 +1,7 @@
 //go:build darwin
 //go:build darwin
 
 
 /**
 /**
- * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
+ * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b
  *
  *
  * MIT License
  * MIT License
  *
  *
@@ -89,6 +89,13 @@ void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
 // get data from the device into host memory
 // get data from the device into host memory
 void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
 void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
 
 
+// try to find operations that can be run concurrently in the graph
+// you should run it again if the topology of your graph changes
+void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
+
+// if the graph has been optimized for concurrently dispatch
+bool ggml_metal_if_optimized(struct ggml_metal_context * ctx);
+
 // same as ggml_graph_compute but uses Metal
 // same as ggml_graph_compute but uses Metal
 // creates gf->n_threads command buffers in parallel
 // creates gf->n_threads command buffers in parallel
 void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
 void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);

+ 214 - 82
llama/ggml-metal.m

@@ -1,7 +1,7 @@
 //go:build darwin
 //go:build darwin
 
 
 /**
 /**
- * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
+ * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b
  *
  *
  * MIT License
  * MIT License
  *
  *
@@ -64,12 +64,16 @@ struct ggml_metal_context {
     int n_buffers;
     int n_buffers;
     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
 
+    int concur_list[GGML_MAX_NODES];
+    int concur_list_len;
+
     // custom kernels
     // custom kernels
 #define GGML_METAL_DECL_KERNEL(name) \
 #define GGML_METAL_DECL_KERNEL(name) \
     id<MTLFunction>             function_##name; \
     id<MTLFunction>             function_##name; \
     id<MTLComputePipelineState> pipeline_##name
     id<MTLComputePipelineState> pipeline_##name
 
 
     GGML_METAL_DECL_KERNEL(add);
     GGML_METAL_DECL_KERNEL(add);
+    GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(mul);
     GGML_METAL_DECL_KERNEL(mul);
     GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(scale);
     GGML_METAL_DECL_KERNEL(scale);
@@ -125,6 +129,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     ctx->device = MTLCreateSystemDefaultDevice();
     ctx->device = MTLCreateSystemDefaultDevice();
     ctx->queue  = [ctx->device newCommandQueue];
     ctx->queue  = [ctx->device newCommandQueue];
     ctx->n_buffers = 0;
     ctx->n_buffers = 0;
+    ctx->concur_list_len = 0;
 
 
     // determine if we can use MPS
     // determine if we can use MPS
     if (MPSSupportsMTLDevice(ctx->device)) {
     if (MPSSupportsMTLDevice(ctx->device)) {
@@ -185,6 +190,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
         fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
 
 
         GGML_METAL_ADD_KERNEL(add);
         GGML_METAL_ADD_KERNEL(add);
+        GGML_METAL_ADD_KERNEL(add_row);
         GGML_METAL_ADD_KERNEL(mul);
         GGML_METAL_ADD_KERNEL(mul);
         GGML_METAL_ADD_KERNEL(mul_row);
         GGML_METAL_ADD_KERNEL(mul_row);
         GGML_METAL_ADD_KERNEL(scale);
         GGML_METAL_ADD_KERNEL(scale);
@@ -243,6 +249,13 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
     ctx->n_cb = n_cb;
     ctx->n_cb = n_cb;
 }
 }
 
 
+bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
+    if (ctx->concur_list_len) {
+        return true;
+    }
+    return false;
+}
+
 // finds the Metal buffer that contains the tensor data on the GPU device
 // finds the Metal buffer that contains the tensor data on the GPU device
 // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
 // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
 // Metal buffer based on the host memory pointer
 // Metal buffer based on the host memory pointer
@@ -381,11 +394,98 @@ void ggml_metal_get_tensor(
     memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
     memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
 }
 }
 
 
+void ggml_metal_graph_find_concurrency(
+        struct ggml_metal_context * ctx,
+        struct ggml_cgraph * gf) {
+    int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
+    int nodes_unused[GGML_MAX_NODES];
+
+    for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;}
+    for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;}
+    ctx->concur_list_len = 0;
+
+    int n_left = gf->n_nodes;
+    int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
+    int level_pos = 0;  // at ctx->concur_list, the last layer (level) ends at level_pos
+
+    while (n_left > 0) {
+        // number of nodes at a layer (that can be issued concurrently)
+        int concurrency = 0;
+        for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
+            if (nodes_unused[i]) {
+                // if the requirements for gf->nodes[i] are satisfied
+                int exe_flag=1;
+                // scan all srcs
+                for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
+                    struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
+                    if (src_cur) {
+                        // if is leaf nodes it's satisfied.
+                        if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;}
+
+                        // otherwise this src should be the output from previous nodes.
+                        int is_found = 0;
+                        // scan 2*search_depth back because we inserted barrier.
+                        for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
+                            if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;}
+                        }
+                        if (is_found == 0) {exe_flag = 0; break;}
+                    }
+                }
+                if (exe_flag) {
+                    // check if nodes[i]'s data will be overwritten by a node before nodes[i].
+                    // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
+                    int64_t data_start = (int64_t) gf->nodes[i]->data;
+                    int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
+                    for (int j = n_start; j < i; j++) {
+                        if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
+                                            && gf->nodes[j]->op != GGML_OP_VIEW \
+                                            && gf->nodes[j]->op != GGML_OP_TRANSPOSE \
+                                            && gf->nodes[j]->op != GGML_OP_PERMUTE) {
+                            if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
+                                ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
+                                continue;
+                            } else {
+                                exe_flag = 0;
+                            }
+                        }
+                    }
+                }
+                if (exe_flag) {
+                    ctx->concur_list[level_pos + concurrency] = i;
+                    nodes_unused[i] = 0;
+                    concurrency++;
+                    ctx->concur_list_len++;
+                }
+            }
+        }
+        n_left -= concurrency;
+        // adding a barrier different layer
+        ctx->concur_list[level_pos + concurrency] = -1;
+        ctx->concur_list_len++;
+        // jump all sorted nodes at nodes_bak
+        while (!nodes_unused[n_start]) {n_start++;}
+        level_pos += concurrency + 1;
+    }
+
+    if (ctx->concur_list_len > GGML_MAX_NODES) {
+        fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
+    }
+}
+
 void ggml_metal_graph_compute(
 void ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
         struct ggml_metal_context * ctx,
                struct ggml_cgraph * gf) {
                struct ggml_cgraph * gf) {
     metal_printf("%s: evaluating graph\n", __func__);
     metal_printf("%s: evaluating graph\n", __func__);
 
 
+    // if there is ctx->concur_list, dispatch concurrently
+    // else fallback to serial dispatch
+    MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
+
+    const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
+
+    const int n_nodes  = has_concur ? ctx->concur_list_len      : gf->n_nodes;
+    edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
+
     // create multiple command buffers and enqueue them
     // create multiple command buffers and enqueue them
     // then, we encode the graph into the command buffers in parallel
     // then, we encode the graph into the command buffers in parallel
 
 
@@ -404,7 +504,7 @@ void ggml_metal_graph_compute(
     dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
     dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
 
 
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-        const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
+        const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
 
         dispatch_async(queue, ^{
         dispatch_async(queue, ^{
             size_t offs_src0 = 0;
             size_t offs_src0 = 0;
@@ -415,10 +515,21 @@ void ggml_metal_graph_compute(
 
 
             id<MTLComputeCommandEncoder> encoder = nil;
             id<MTLComputeCommandEncoder> encoder = nil;
 
 
-            const int node_start =                                      (cb_idx + 0) * n_nodes_per_cb;
-            const int node_end   = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+            const int node_start =                                  (cb_idx + 0) * n_nodes_per_cb;
+            const int node_end   = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+
+            for (int ind = node_start; ind < node_end; ++ind) {
+                const int i = has_concur ? ctx->concur_list[ind] : ind;
+
+                if (i == -1) {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                        continue;
+                    }
+                    [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+                    continue;
+                }
 
 
-            for (int i = node_start; i < node_end; ++i) {
                 metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
                 metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
 
 
                 struct ggml_tensor * src0 = gf->nodes[i]->src[0];
                 struct ggml_tensor * src0 = gf->nodes[i]->src[0];
@@ -489,13 +600,19 @@ void ggml_metal_graph_compute(
                     case GGML_OP_ADD:
                     case GGML_OP_ADD:
                         {
                         {
                             if (encoder == nil) {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
                             }
 
 
-                            [encoder setComputePipelineState:ctx->pipeline_add];
+                            if (ggml_nelements(src1) == ne10) {
+                                // src1 is a row
+                                [encoder setComputePipelineState:ctx->pipeline_add_row];
+                            } else {
+                                [encoder setComputePipelineState:ctx->pipeline_add];
+                            }
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
 
 
                             const int64_t n = ggml_nelements(dst);
                             const int64_t n = ggml_nelements(dst);
 
 
@@ -504,7 +621,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_MUL:
                     case GGML_OP_MUL:
                         {
                         {
                             if (encoder == nil) {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
                             }
 
 
                             if (ggml_nelements(src1) == ne10) {
                             if (ggml_nelements(src1) == ne10) {
@@ -525,7 +642,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_SCALE:
                     case GGML_OP_SCALE:
                         {
                         {
                             if (encoder == nil) {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
                             }
 
 
                             const float scale = *(const float *) src1->data;
                             const float scale = *(const float *) src1->data;
@@ -539,52 +656,60 @@ void ggml_metal_graph_compute(
 
 
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
                         } break;
-                    case GGML_OP_SILU:
-                        {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
-                            [encoder setComputePipelineState:ctx->pipeline_silu];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-
-                            const int64_t n = ggml_nelements(dst);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                        } break;
-                    case GGML_OP_RELU:
-                        {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
-                            [encoder setComputePipelineState:ctx->pipeline_relu];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-
-                            const int64_t n = ggml_nelements(dst);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    case GGML_OP_UNARY:
+                        switch (ggml_get_unary_op(gf->nodes[i])) {
+                            case GGML_UNARY_OP_SILU:
+                                {
+                                    if (encoder == nil) {
+                                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                                    }
+
+                                    [encoder setComputePipelineState:ctx->pipeline_silu];
+                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                                    const int64_t n = ggml_nelements(dst);
+
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                } break;
+                            case GGML_UNARY_OP_RELU:
+                                {
+                                    if (encoder == nil) {
+                                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                                    }
+
+                                    [encoder setComputePipelineState:ctx->pipeline_relu];
+                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                                    const int64_t n = ggml_nelements(dst);
+
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                } break;
+                            case GGML_UNARY_OP_GELU:
+                                {
+                                    if (encoder == nil) {
+                                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                                    }
+
+                                    [encoder setComputePipelineState:ctx->pipeline_gelu];
+                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                                    const int64_t n = ggml_nelements(dst);
+
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                } break;
+                            default:
+                                {
+                                    fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                                    GGML_ASSERT(false);
+                                }
                         } break;
                         } break;
-                    case GGML_OP_GELU:
-                    {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
-                            [encoder setComputePipelineState:ctx->pipeline_gelu];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-
-                            const int64_t n = ggml_nelements(dst);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                    } break;
                     case GGML_OP_SOFT_MAX:
                     case GGML_OP_SOFT_MAX:
                         {
                         {
                             if (encoder == nil) {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
                             }
 
 
                             const int nth = 32;
                             const int nth = 32;
@@ -602,10 +727,10 @@ void ggml_metal_graph_compute(
                     case GGML_OP_DIAG_MASK_INF:
                     case GGML_OP_DIAG_MASK_INF:
                         {
                         {
                             if (encoder == nil) {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
                             }
 
 
-                            const int n_past = ((int32_t *)(src1->data))[0];
+                            const int n_past = ((int32_t *)(dst->op_params))[0];
 
 
                             [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
                             [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -665,7 +790,7 @@ void ggml_metal_graph_compute(
                                 }
                                 }
                             } else {
                             } else {
                                 if (encoder == nil) {
                                 if (encoder == nil) {
-                                    encoder = [command_buffer computeCommandEncoder];
+                                    encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                                 }
                                 }
 
 
                                 int nth0 = 32;
                                 int nth0 = 32;
@@ -704,8 +829,8 @@ void ggml_metal_graph_compute(
                                             GGML_ASSERT(ne02 == 1);
                                             GGML_ASSERT(ne02 == 1);
                                             GGML_ASSERT(ne12 == 1);
                                             GGML_ASSERT(ne12 == 1);
 
 
-                                            nth0 = 4;
-                                            nth1 = 16;
+                                            nth0 = 2;
+                                            nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
                                         } break;
                                         } break;
                                     case GGML_TYPE_Q3_K:
                                     case GGML_TYPE_Q3_K:
@@ -713,8 +838,8 @@ void ggml_metal_graph_compute(
                                             GGML_ASSERT(ne02 == 1);
                                             GGML_ASSERT(ne02 == 1);
                                             GGML_ASSERT(ne12 == 1);
                                             GGML_ASSERT(ne12 == 1);
 
 
-                                            nth0 = 4;
-                                            nth1 = 16;
+                                            nth0 = 2;
+                                            nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
                                         } break;
                                         } break;
                                     case GGML_TYPE_Q4_K:
                                     case GGML_TYPE_Q4_K:
@@ -768,19 +893,21 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:14];
                                 [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:14];
 
 
                                 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
                                 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
-                                    src0t == GGML_TYPE_Q4_K) {
+                                    src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 }
+                                else if (src0t == GGML_TYPE_Q3_K) {
+#ifdef GGML_QKK_64
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#else
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#endif
+                                }
                                 else if (src0t == GGML_TYPE_Q5_K) {
                                 else if (src0t == GGML_TYPE_Q5_K) {
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 }
                                 else if (src0t == GGML_TYPE_Q6_K) {
                                 else if (src0t == GGML_TYPE_Q6_K) {
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                                }
-                                else if (src0t == GGML_TYPE_Q2_K ||
-                                         src0t == GGML_TYPE_Q3_K) {
-                                    [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
-                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 } else {
                                 } else {
                                     [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
                                     [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
                                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -790,7 +917,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_GET_ROWS:
                     case GGML_OP_GET_ROWS:
                         {
                         {
                             if (encoder == nil) {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
                             }
 
 
                             switch (src0->type) {
                             switch (src0->type) {
@@ -819,10 +946,11 @@ void ggml_metal_graph_compute(
                     case GGML_OP_RMS_NORM:
                     case GGML_OP_RMS_NORM:
                         {
                         {
                             if (encoder == nil) {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
                             }
 
 
-                            const float eps = 1e-6f;
+                            float eps;
+                            memcpy(&eps, dst->op_params, sizeof(float));
 
 
                             const int nth = 512;
                             const int nth = 512;
 
 
@@ -841,7 +969,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_NORM:
                     case GGML_OP_NORM:
                         {
                         {
                             if (encoder == nil) {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
                             }
 
 
                             const float eps = 1e-5f;
                             const float eps = 1e-5f;
@@ -863,14 +991,15 @@ void ggml_metal_graph_compute(
                     case GGML_OP_ALIBI:
                     case GGML_OP_ALIBI:
                         {
                         {
                             if (encoder == nil) {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
                             }
 
 
                             GGML_ASSERT((src0t == GGML_TYPE_F32));
                             GGML_ASSERT((src0t == GGML_TYPE_F32));
 
 
-                            const int   n_past   = ((int32_t *) src1->data)[0]; UNUSED(n_past);
-                            const int   n_head   = ((int32_t *) src1->data)[1];
-                            const float max_bias = ((float *)   src1->data)[2];
+                            const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
+                            const int n_head = ((int32_t *) dst->op_params)[1];
+                            float max_bias;
+                            memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
 
                             if (__builtin_popcount(n_head) != 1) {
                             if (__builtin_popcount(n_head) != 1) {
                                 GGML_ASSERT(false && "only power-of-two n_head implemented");
                                 GGML_ASSERT(false && "only power-of-two n_head implemented");
@@ -905,18 +1034,17 @@ void ggml_metal_graph_compute(
                     case GGML_OP_ROPE:
                     case GGML_OP_ROPE:
                         {
                         {
                             if (encoder == nil) {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
                             }
 
 
-                            const int n_dims = ((int32_t *) src1->data)[1];
-                            const int mode   = ((int32_t *) src1->data)[2];
-
-                            const int n_past = ((int32_t *)(src1->data))[0];
+                            const int n_past = ((int32_t *) dst->op_params)[0];
+                            const int n_dims = ((int32_t *) dst->op_params)[1];
+                            const int mode   = ((int32_t *) dst->op_params)[2];
 
 
                             float freq_base;
                             float freq_base;
                             float freq_scale;
                             float freq_scale;
-                            memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
-                            memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
+                            memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
+                            memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
 
                             [encoder setComputePipelineState:ctx->pipeline_rope];
                             [encoder setComputePipelineState:ctx->pipeline_rope];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -945,10 +1073,12 @@ void ggml_metal_graph_compute(
 
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
                         } break;
+                    case GGML_OP_DUP:
                     case GGML_OP_CPY:
                     case GGML_OP_CPY:
+                    case GGML_OP_CONT:
                         {
                         {
                             if (encoder == nil) {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
                             }
 
 
                             const int nth = 32;
                             const int nth = 32;
@@ -995,8 +1125,10 @@ void ggml_metal_graph_compute(
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
                         } break;
                     default:
                     default:
-                        fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
-                        GGML_ASSERT(false);
+                        {
+                            fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                            GGML_ASSERT(false);
+                        }
                 }
                 }
             }
             }
 
 

+ 289 - 210
llama/ggml-metal.metal

@@ -1,7 +1,7 @@
 //go:build darwin
 //go:build darwin
 
 
 /**
 /**
- * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
+ * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b
  *
  *
  * MIT License
  * MIT License
  *
  *
@@ -95,6 +95,17 @@ kernel void kernel_add(
     dst[tpig] = src0[tpig] + src1[tpig];
     dst[tpig] = src0[tpig] + src1[tpig];
 }
 }
 
 
+// assumption: src1 is a row
+// broadcast src1 into src0
+kernel void kernel_add_row(
+        device const float * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] + src1[tpig % ne00];
+}
+
 kernel void kernel_mul(
 kernel void kernel_mul(
         device const float * src0,
         device const float * src0,
         device const float * src1,
         device const float * src1,
@@ -379,7 +390,7 @@ kernel void kernel_rms_norm(
 
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
     threadgroup_barrier(mem_flags::mem_threadgroup);
     // broadcast, simd group number is ntg / 32
     // broadcast, simd group number is ntg / 32
-    for (int i = ntg / 32 / 2; i > 0; i /= 2) {
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
        if (tpitg < i) {
        if (tpitg < i) {
            sum[tpitg] += sum[tpitg + i];
            sum[tpitg] += sum[tpitg + i];
        }
        }
@@ -404,87 +415,90 @@ kernel void kernel_rms_norm(
     }
     }
 }
 }
 
 
-// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i])
-float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
+// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
     float d = qb_curr->d;
-    float4 acc = 0.f;
-    device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
-    for (int i = 0; i < 16; i+=2) {
-        acc[0] += yl[i]      * (qs[i / 2] & 0x000F);
-        acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
-        acc[2] += yl[i +  1] * (qs[i / 2] & 0x0F00);
-        acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
-    }
-    return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
+    float2 acc = 0.f;
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+                + yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+                + yl[i + 9] * (qs[i / 2] & 0xF000);
+    }
+    return d * (sumy * -8.f + acc[0] + acc[1]);
 }
 }
 
 
-// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i])
-float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
+// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
     float d = qb_curr->d;
     float m = qb_curr->m;
     float m = qb_curr->m;
-    float4 acc = 0.f;
-    device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
-    for (int i = 0; i < 16; i+=2) {
-        acc[0] += yl[i]      * (qs[i / 2] & 0x000F);
-        acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
-        acc[2] += yl[i +  1] * (qs[i / 2] & 0x0F00);
-        acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
-    }
-    return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+    float2 acc = 0.f;
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+                + yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+                + yl[i + 9] * (qs[i / 2] & 0xF000);
+    }
+    return d * (acc[0] + acc[1]) + sumy * m;
 }
 }
 
 
 // putting them in the kernel cause a significant performance penalty
 // putting them in the kernel cause a significant performance penalty
 #define N_DST 4 // each SIMD group works on 4 rows
 #define N_DST 4 // each SIMD group works on 4 rows
 #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
 #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
-template<typename block_q_type>
+//Note: This is a template, but strictly speaking it only applies to
+//      quantizations where the block size is 32. It also does not
+//      giard against the number of rows not being divisible by
+//      N_DST, so this is another explicit assumption of the implementation.
+template<typename block_q_type, int nr, int nsg, int nw>
 void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
 void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
                     int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
                     int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
                     uint2 tgpig, uint tiisg, uint sgitg) {
                     uint2 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;
     const int nb = ne00/QK4_0;
     const int r0 = tgpig.x;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int r1 = tgpig.y;
-    device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
+    const int first_row = (r0 * nsg + sgitg) * nr;
+    device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb;
     device const float      * y = (device const float      *) src1 + r1*ne10;
     device const float      * y = (device const float      *) src1 + r1*ne10;
-    float4 y_curr[8];       // src1 vector cache
-    float sumf[N_DST]={0.f}, all_sum;
-    thread float * yl=(thread float *)y_curr;
+    float yl[16];       // src1 vector cache
+    float sumf[nr]={0.f};
 
 
-    // each thread in a SIMD group deals with 1 block.
-    for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
-        float sumy = 0;
-        for (int i = 0; i < QK4_0 / 4; i++) {
-            y_curr[i] = *((device float4  *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
-            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
-        }
+    const int ix = tiisg/2;
+    const int il = 8*(tiisg%2);
 
 
-        for (int row = 0; row < N_DST; row++) {
-            sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
-        }
-    }
+    device const float * yb = y + ix * QK4_0 + il;
 
 
-    // from now loads two rows every time and 16 blocks per row
-    int ir = tiisg / (N_SIMDWIDTH / 2);
-    int ib = tiisg % (N_SIMDWIDTH / 2);
-    for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
-        int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
+    // each thread in a SIMD group deals with half a block.
+    for (int ib = ix; ib < nb; ib += nw/2) {
         float sumy = 0;
         float sumy = 0;
-        for (int i = 0; i < QK4_0 / 4; i++) {
-            y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
-            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
+        for (int i = 0; i < 8; i += 2) {
+            sumy += yb[i] + yb[i+1];
+            yl[i+0] = yb[i+ 0];
+            yl[i+1] = yb[i+ 1]/256.f;
+            sumy += yb[i+16] + yb[i+17];
+            yl[i+8] = yb[i+16]/16.f;
+            yl[i+9] = yb[i+17]/4096.f;
         }
         }
 
 
-        for (int row = 0; row < N_DST; row+=2) {
-            if (nb_start + ib < nb) {
-                sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
-            }
+        for (int row = 0; row < nr; row++) {
+            sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
         }
         }
+
+        yb += QK4_0 * 16;
     }
     }
 
 
-    for (int row = 0; row < N_DST; ++row) {
-        all_sum = simd_sum(sumf[row]);
-        if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
-            dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0 && first_row + row < ne01) {
+            dst[r1*ne0 + first_row + row] = tot;
         }
         }
     }
     }
 }
 }
@@ -500,7 +514,7 @@ kernel void kernel_mul_mat_q4_0_f32(
         uint2 tgpig[[threadgroup_position_in_grid]],
         uint2 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
 }
 }
 
 
 kernel void kernel_mul_mat_q4_1_f32(
 kernel void kernel_mul_mat_q4_1_f32(
@@ -514,7 +528,7 @@ kernel void kernel_mul_mat_q4_1_f32(
         uint2 tgpig[[threadgroup_position_in_grid]],
         uint2 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-     mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
 }
 }
 
 
 kernel void kernel_mul_mat_f16_f32(
 kernel void kernel_mul_mat_f16_f32(
@@ -1237,111 +1251,137 @@ kernel void kernel_mul_mat_q2_K_f32(
         constant   int64_t & ne00,
         constant   int64_t & ne00,
         constant   int64_t & ne10,
         constant   int64_t & ne10,
         constant   int64_t & ne0,
         constant   int64_t & ne0,
-        threadgroup float  * sum [[threadgroup(0)]],
+        constant   int64_t & ne01[[buffer(4)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpitg[[thread_position_in_threadgroup]],
-        uint2  tptg[[threads_per_threadgroup]]) {
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
 
     const int nb = ne00/QK_K;
     const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
 
 
-    const int64_t r0 = tgpig.x;
-    const int64_t r1 = tgpig.y;
-
-    device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
-    device const float     * yy = (device const float      *) src1 + r1*ne10;
-
-    const int nth = tptg.x*tptg.y;
-    const int ith = tptg.y*tpitg.x + tpitg.y;
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+    device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row;
+    device const float      * y = (device const float      *) src1 + r1*ne10;
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
 
 
-    float sumf = 0;
+    const int step = sizeof(block_q2_K) * nb;
 
 
 #if QK_K == 256
 #if QK_K == 256
-    const int tid = tpitg.y;    // 0...16
-    const int il  = tid/4;      // 0...3
-    const int ir  = tid%4;      // 0...3
-    const int ip  = il/2;       // 0 or 1
-    const int shift1 = 4*(il%2);// 0 or 4
-    const int shift2 = shift1+2;// 2 or 6
-    const int n   = 8;
-    const int is  = 4*il + (n*ir)/16;
+    const int ix = tiisg/8;  // 0...3
+    const int it = tiisg%8;  // 0...7
+    const int im = it/4;     // 0 or 1
+    const int ir = it%4;     // 0...3
+    const int is = (8*ir)/16;// 0 or 1
 
 
-    const int y_offset = 64*il + n*ir;
-    const int q_offset = 32*ip + n*ir;
+    device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
 
 
-    for (int i = tpitg.x; i < nb; i += tptg.x) {
+    for (int ib = ix; ib < nb; ib += 4) {
 
 
-        device const uint8_t * q = x[i].qs + q_offset;
-        device const uint8_t * scales = x[i].scales + is;
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+            yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
+            yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
+            yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
+        }
 
 
-        uint8_t d1 = scales[0] & 0xF;
-        uint8_t d2 = scales[2] & 0xF;
-        uint8_t m1 = scales[0] >>  4;
-        uint8_t m2 = scales[2] >>  4;
+        device const uint8_t  * sc = (device const uint8_t  *)x[ib].scales + 8*im + is;
+        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
+        device const half     * dh = &x[ib].d;
 
 
-        device const float   * y = yy + i*QK_K + y_offset;
+        for (int row = 0; row < N_DST; row++) {
 
 
-        float2 s = {0.f, 0.f};
-        float smin = 0;
-        for (int l = 0; l < n; ++l) {
-            s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
-            s[1] += y[l+32] * ((q[l] >> shift2) & 3);
-            smin += y[l+ 0] * m1 + y[l+32] * m2;
+            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
+                acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
+                acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
+                acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
+                acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
+                acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
+                acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
+                acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
+            }
+            float dall = dh[0];
+            float dmin = dh[1] * 1.f/16.f;
+            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
+                                 (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
+                                 (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
+                                 (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
+                         dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
+
+            qs += step/2;
+            sc += step;
+            dh += step/2;
         }
         }
 
 
-        const float dall = (float)x[i].d;
-        const float dmin = (float)x[i].dmin;
-
-        sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
-
+        y4 += 4 * QK_K;
     }
     }
 #else
 #else
-    const int il = 4 * tpitg.x;
+    const int ix = tiisg/2;  // 0...15
+    const int it = tiisg%2;  // 0...1
 
 
-    uint32_t aux[2];
-    thread const uint8_t * d = (thread const uint8_t *)aux;
-    thread const uint8_t * m = (thread const uint8_t *)aux + 4;
+    device const float * y4 = y + ix * QK_K + 8 * it;
 
 
-    for (int i = tpitg.y; i < nb; i += tptg.y) {
+    for (int ib = ix; ib < nb; ib += 16) {
 
 
-        device const uint8_t * q = x[i].qs + il;
-        device const float   * y = yy + i*QK_K + il;
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+            yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
+            yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
+            yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
+        }
 
 
-        const float dall = (float)x[i].d;
-        const float dmin = (float)x[i].dmin;
+        device const uint8_t  * sc = (device const uint8_t  *)x[ib].scales;
+        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
+        device const half     * dh = &x[ib].d;
 
 
-        device const uint32_t * a = (device const uint32_t *)x[i].scales;
-        aux[0] = a[0] & 0x0f0f0f0f;
-        aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
+        for (int row = 0; row < N_DST; row++) {
 
 
-        for (int l = 0; l < 4; ++l) {
-            sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
-                  + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
-                  + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
-                  + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
+            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
+                acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
+                acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
+                acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
+                acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
+                acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
+                acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
+                acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
+            }
+
+            float dall = dh[0];
+            float dmin = dh[1];
+            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
+                                 (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
+                                 (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
+                                 (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
+                         dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
+
+            qs += step/2;
+            sc += step;
+            dh += step/2;
         }
         }
+
+        y4 += 16 * QK_K;
     }
     }
 #endif
 #endif
 
 
-    sum[ith] = sumf;
-
-    //
-    // Accumulate the sum from all threads in the threadgroup
-    //
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[r1*ne0 + r0] = sum[0];
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + first_row + row] = all_sum;
+        }
     }
     }
 }
 }
 
 
+#if QK_K == 256
 kernel void kernel_mul_mat_q3_K_f32(
 kernel void kernel_mul_mat_q3_K_f32(
         device const  void * src0,
         device const  void * src0,
         device const float * src1,
         device const float * src1,
@@ -1350,40 +1390,41 @@ kernel void kernel_mul_mat_q3_K_f32(
         constant   int64_t & ne10,
         constant   int64_t & ne10,
         constant   int64_t & ne0,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   int64_t & ne1,
-        threadgroup float  * sum [[threadgroup(0)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpitg[[thread_position_in_threadgroup]],
-        uint2  tptg[[threads_per_threadgroup]]) {
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
 
     const int nb = ne00/QK_K;
     const int nb = ne00/QK_K;
 
 
     const int64_t r0 = tgpig.x;
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
     const int64_t r1 = tgpig.y;
 
 
-    device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
-    device const float     * yy = (device const float      *) src1 + r1*ne10;
-
-    const int nth = tptg.x*tptg.y;
-    const int ith = tptg.y*tpitg.x + tpitg.y;
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
 
 
-#if QK_K == 256
+    device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
 
 
-    const uint8_t m3 = 3;
-    const int8_t  m4 = 4;
+    float yl[16];
 
 
     const uint16_t kmask1 = 0x0303;
     const uint16_t kmask1 = 0x0303;
     const uint16_t kmask2 = 0x0f0f;
     const uint16_t kmask2 = 0x0f0f;
 
 
-    const int tid = tpitg.y;        // expecting 16
+    const int tid = tiisg/2;
+    const int ix  = tiisg%2;
     const int ip  = tid/8;          // 0 or 1
     const int ip  = tid/8;          // 0 or 1
     const int il  = tid/2 - 4*ip;   // 0...3
     const int il  = tid/2 - 4*ip;   // 0...3
     const int ir  = tid%2;
     const int ir  = tid%2;
     const int n   = 8;
     const int n   = 8;
     const int l0  = n*ir;
     const int l0  = n*ir;
 
 
-    const uint8_t m = 1 << (4*ip + il);
+    const uint16_t m1 = 1 << (4*ip + il);
+    const uint16_t m2 = m1 << 8;
 
 
     const int shift = 2*il;
     const int shift = 2*il;
+    const uint16_t qm1 = 0x0003 << shift;
+    const uint16_t qm2 = 0x0300 << shift;
+    const int32_t v1 = 4 << shift;
+    const int32_t v2 = 1024 << shift;
 
 
     const uint16_t s_shift1 = 4*ip;
     const uint16_t s_shift1 = 4*ip;
     const uint16_t s_shift2 = s_shift1 + 2*(il/2);
     const uint16_t s_shift2 = s_shift1 + 2*(il/2);
@@ -1392,93 +1433,132 @@ kernel void kernel_mul_mat_q3_K_f32(
     const int q_offset = 32*ip + l0;
     const int q_offset = 32*ip + l0;
     const int y_offset = 128*ip + 32*il + l0;
     const int y_offset = 128*ip + 32*il + l0;
 
 
-    //float sumf = 0;
-    float sumf1 = 0, sumf2 = 0;
-    for (int i = tpitg.x; i < nb; i += tptg.x) {
+    const int step = sizeof(block_q3_K) * nb / 2;
 
 
-        const float d_all = (float)(x[i].d);
-
-        device const uint8_t * q = x[i].qs + q_offset;
-        device const uint8_t * h = x[i].hmask + l0;
-        device const float   * y = yy + i * QK_K + y_offset;
+    device const float * y1 = yy + ix*QK_K + y_offset;
 
 
-        device const uint16_t * a = (device const uint16_t *)x[i].scales;
-        const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
+    float sumf1[2] = {0.f}, sumf2[2] = {0.f};
+    for (int i = ix; i < nb; i += 2) {
 
 
-        float s = 0;
-        for (int l = 0; l < n; ++l) {
-            s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4));
+        for (int l = 0; l < 8; ++l) {
+            yl[l+0] = y1[l+ 0];
+            yl[l+8] = y1[l+16];
         }
         }
-        float d = d_all * s;
-        sumf1 += d * scales[0];
-        sumf2 += d;
-        //sumf += d_all * s * (scales[0] - 32);
 
 
-        s = 0;
-        for (int l = 0; l < n; ++l) {
-            s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4));
+        device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
+        device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
+        device const uint16_t * a = (device const uint16_t *)(x[i].scales);
+        device const half * dh = &x[i].d;
+
+        for (int row = 0; row < 2; ++row) {
+
+            const float d_all = (float)dh[0];
+            const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
+
+            float s1 = 0, s2 = 0;
+            for (int l = 0; l < n; l += 2) {
+                const uint16_t qs = q[l/2];
+                s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
+                s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
+            }
+            float d = d_all * (s1 + 1.f/256.f * s2);
+            sumf1[row] += d * scales[0];
+            sumf2[row] += d;
+
+            s1 = s2 = 0;
+            for (int l = 0; l < n; l += 2) {
+                const uint16_t qs = q[l/2+8];
+                s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
+                s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
+            }
+            d = d_all * (s1 + 1.f/256.f * s2);
+            sumf1[row] += d * scales[1];
+            sumf2[row] += d;
+
+            q  += step;
+            h  += step;
+            a  += step;
+            dh += step;
+
         }
         }
-        d = d_all * s;
-        sumf1 += d * scales[1];
-        sumf2 += d;
-        //sumf += d_all * s * (scales[1] - 32);
+
+        y1 += 2 * QK_K;
 
 
     }
     }
 
 
-    //sum[ith] = sumf;
-    sum[ith] = sumf1 - 32.f*sumf2;
+    for (int row = 0; row < 2; ++row) {
+        const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
+        const float tot = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst[r1*ne0 + first_row + row] = tot;
+        }
+    }
+}
 #else
 #else
-    const int il = 4 * tpitg.x;  // 0, 4, 8, 12
+kernel void kernel_mul_mat_q3_K_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    const int row = 2 * r0 + sgitg;
+
+    device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
+    const int ix = tiisg/4;
+    const int il = 4 * (tiisg%4);// 0, 4, 8, 12
     const int im = il/8;         // 0, 0, 1, 1
     const int im = il/8;         // 0, 0, 1, 1
     const int in = il%8;         // 0, 4, 0, 4
     const int in = il%8;         // 0, 4, 0, 4
 
 
-    float sumf = 0;
+    float2 sum = {0.f, 0.f};
 
 
-    for (int i = tpitg.y; i < nb; i += tptg.y) {
+    for (int i = ix; i < nb; i += 8) {
 
 
         const float d_all = (float)(x[i].d);
         const float d_all = (float)(x[i].d);
 
 
-        device const uint8_t * q = x[i].qs + il;
-        device const uint8_t * h = x[i].hmask + in;
-        device const float   * y = yy + i * QK_K + il;
-
-        const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
-        const float d2 = d_all * ((x[i].scales[0] >>  4) - 8);
-        const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
-        const float d4 = d_all * ((x[i].scales[1] >>  4) - 8);
-
-        for (int l = 0; l < 4; ++l) {
-            const uint8_t hm = h[l] >> im;
-            sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
-                  + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
-                  + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
-                  + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
+        device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
+        device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
+        device const uint16_t * s = (device const uint16_t *)(x[i].scales);
+        device const float    * y = yy + i * QK_K + il;
+
+        const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
+        const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
+        const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
+        const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
+
+        for (int l = 0; l < 4; l += 2) {
+            const uint16_t hm = h[l/2] >> im;
+            sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 :  4))
+                    + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
+                    + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
+                    + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
+            sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
+                    + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
+                    + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
+                    + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
         }
         }
 
 
     }
     }
+    const float sumf = sum[0] + sum[1] * 1.f/256.f;
 
 
-    sum[ith] = sumf;
-
-#endif
-
-    //
-    // Accumulate the sum from all threads in the threadgroup
-    //
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[r1*ne0 + r0] = sum[0];
+    const float tot = simd_sum(sumf);
+    if (tiisg == 0) {
+        dst[r1*ne0 + row] = tot;
     }
     }
 
 
 }
 }
+#endif
 
 
 #if QK_K == 256
 #if QK_K == 256
 kernel void kernel_mul_mat_q4_K_f32(
 kernel void kernel_mul_mat_q4_K_f32(
@@ -1776,7 +1856,6 @@ kernel void kernel_mul_mat_q5_K_f32(
 
 
     for (int i = ix; i < nb; i += 8) {
     for (int i = ix; i < nb; i += 8) {
 
 
-        float4 sumy = {0.f, 0.f, 0.f, 0.f};
         for (int l = 0; l < 4; ++l) {
         for (int l = 0; l < 4; ++l) {
             yl[l+0] = y[l+ 0];
             yl[l+0] = y[l+ 0];
             yl[l+4] = y[l+16];
             yl[l+4] = y[l+16];

+ 1 - 1
llama/ggml-mpi.c

@@ -1,7 +1,7 @@
 //go:build mpi
 //go:build mpi
 
 
 /**
 /**
- * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
+ * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b
  *
  *
  * MIT License
  * MIT License
  *
  *

+ 1 - 1
llama/ggml-mpi.h

@@ -1,7 +1,7 @@
 //go:build mpi
 //go:build mpi
 
 
 /**
 /**
- * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
+ * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b
  *
  *
  * MIT License
  * MIT License
  *
  *

+ 1 - 1
llama/ggml-opencl.cpp

@@ -1,7 +1,7 @@
 //go:build opencl
 //go:build opencl
 
 
 /**
 /**
- * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
+ * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b
  *
  *
  * MIT License
  * MIT License
  *
  *

+ 1 - 1
llama/ggml-opencl.h

@@ -1,7 +1,7 @@
 //go:build opencl
 //go:build opencl
 
 
 /**
 /**
- * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
+ * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b
  *
  *
  * MIT License
  * MIT License
  *
  *

文件差異過大導致無法顯示
+ 173 - 354
llama/ggml.c


+ 87 - 21
llama/ggml.h

@@ -1,5 +1,5 @@
 /**
 /**
- * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
+ * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b
  *
  *
  * MIT License
  * MIT License
  *
  *
@@ -225,6 +225,7 @@
 #define GGML_MAX_CONTEXTS      64
 #define GGML_MAX_CONTEXTS      64
 #define GGML_MAX_SRC           6
 #define GGML_MAX_SRC           6
 #define GGML_MAX_NAME          48
 #define GGML_MAX_NAME          48
+#define GGML_MAX_OP_PARAMS     32
 #define GGML_DEFAULT_N_THREADS 4
 #define GGML_DEFAULT_N_THREADS 4
 
 
 
 
@@ -233,6 +234,7 @@
 
 
 #define GGML_UNUSED(x) (void)(x)
 #define GGML_UNUSED(x) (void)(x)
 
 
+#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
 
 
 #define GGML_ASSERT(x) \
 #define GGML_ASSERT(x) \
     do { \
     do { \
@@ -355,16 +357,6 @@ extern "C" {
         GGML_OP_ARGMAX,
         GGML_OP_ARGMAX,
         GGML_OP_REPEAT,
         GGML_OP_REPEAT,
         GGML_OP_REPEAT_BACK,
         GGML_OP_REPEAT_BACK,
-        GGML_OP_ABS,
-        GGML_OP_SGN,
-        GGML_OP_NEG,
-        GGML_OP_STEP,
-        GGML_OP_TANH,
-        GGML_OP_ELU,
-        GGML_OP_RELU,
-        GGML_OP_GELU,
-        GGML_OP_GELU_QUICK,
-        GGML_OP_SILU,
         GGML_OP_SILU_BACK,
         GGML_OP_SILU_BACK,
         GGML_OP_NORM, // normalize
         GGML_OP_NORM, // normalize
         GGML_OP_RMS_NORM,
         GGML_OP_RMS_NORM,
@@ -403,6 +395,8 @@ extern "C" {
         GGML_OP_WIN_PART,
         GGML_OP_WIN_PART,
         GGML_OP_WIN_UNPART,
         GGML_OP_WIN_UNPART,
 
 
+        GGML_OP_UNARY,
+
         GGML_OP_MAP_UNARY,
         GGML_OP_MAP_UNARY,
         GGML_OP_MAP_BINARY,
         GGML_OP_MAP_BINARY,
 
 
@@ -416,6 +410,24 @@ extern "C" {
         GGML_OP_COUNT,
         GGML_OP_COUNT,
     };
     };
 
 
+    enum ggml_unary_op {
+        GGML_UNARY_OP_ABS,
+        GGML_UNARY_OP_SGN,
+        GGML_UNARY_OP_NEG,
+        GGML_UNARY_OP_STEP,
+        GGML_UNARY_OP_TANH,
+        GGML_UNARY_OP_ELU,
+        GGML_UNARY_OP_RELU,
+        GGML_UNARY_OP_GELU,
+        GGML_UNARY_OP_GELU_QUICK,
+        GGML_UNARY_OP_SILU,
+    };
+
+    enum ggml_object_type {
+        GGML_OBJECT_TENSOR,
+        GGML_OBJECT_GRAPH,
+        GGML_OBJECT_WORK_BUFFER
+    };
 
 
     // ggml object
     // ggml object
     struct ggml_object {
     struct ggml_object {
@@ -424,7 +436,9 @@ extern "C" {
 
 
         struct ggml_object * next;
         struct ggml_object * next;
 
 
-        char padding[8];
+        enum ggml_object_type type;
+
+        char padding[4];
     };
     };
 
 
     static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
     static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
@@ -444,6 +458,9 @@ extern "C" {
         // compute data
         // compute data
         enum ggml_op op;
         enum ggml_op op;
 
 
+        // op params - allocated as int32_t for alignment
+        int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
+
         bool is_param;
         bool is_param;
 
 
         struct ggml_tensor * grad;
         struct ggml_tensor * grad;
@@ -460,7 +477,7 @@ extern "C" {
 
 
         void * extra; // extra things e.g. for ggml-cuda.cu
         void * extra; // extra things e.g. for ggml-cuda.cu
 
 
-        char padding[8];
+        char padding[4];
     };
     };
 
 
     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
@@ -481,6 +498,11 @@ extern "C" {
         void * abort_callback_data;
         void * abort_callback_data;
     };
     };
 
 
+    // next prime after GGML_MAX_NODES
+    // #define GGML_GRAPH_HASHTABLE_SIZE 4099
+    // next prime after GGML_MAX_NODES * 2 (nodes + leafs)
+    #define GGML_GRAPH_HASHTABLE_SIZE 8273
+
     // computation graph
     // computation graph
     struct ggml_cgraph {
     struct ggml_cgraph {
         int n_nodes;
         int n_nodes;
@@ -490,12 +512,16 @@ extern "C" {
         struct ggml_tensor * grads[GGML_MAX_NODES];
         struct ggml_tensor * grads[GGML_MAX_NODES];
         struct ggml_tensor * leafs[GGML_MAX_NODES];
         struct ggml_tensor * leafs[GGML_MAX_NODES];
 
 
+        void * visited_hash_table[GGML_GRAPH_HASHTABLE_SIZE];
+
         // performance
         // performance
         int     perf_runs;
         int     perf_runs;
         int64_t perf_cycles;
         int64_t perf_cycles;
         int64_t perf_time_us;
         int64_t perf_time_us;
     };
     };
 
 
+    static const size_t GGML_GRAPH_SIZE = sizeof(struct ggml_cgraph);
+
     // scratch buffer
     // scratch buffer
     struct ggml_scratch {
     struct ggml_scratch {
         size_t offs;
         size_t offs;
@@ -557,6 +583,7 @@ extern "C" {
 
 
     GGML_API const char * ggml_type_name(enum ggml_type type);
     GGML_API const char * ggml_type_name(enum ggml_type type);
     GGML_API const char * ggml_op_name  (enum ggml_op   op);
     GGML_API const char * ggml_op_name  (enum ggml_op   op);
+    GGML_API const char * ggml_op_symbol(enum ggml_op   op);
 
 
     GGML_API size_t  ggml_element_size(const struct ggml_tensor * tensor);
     GGML_API size_t  ggml_element_size(const struct ggml_tensor * tensor);
 
 
@@ -580,6 +607,7 @@ extern "C" {
     GGML_API size_t  ggml_used_mem(const struct ggml_context * ctx);
     GGML_API size_t  ggml_used_mem(const struct ggml_context * ctx);
 
 
     GGML_API size_t  ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch);
     GGML_API size_t  ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch);
+    GGML_API bool    ggml_get_no_alloc(struct ggml_context * ctx);
     GGML_API void    ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);
     GGML_API void    ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);
 
 
     GGML_API void *  ggml_get_mem_buffer     (const struct ggml_context * ctx);
     GGML_API void *  ggml_get_mem_buffer     (const struct ggml_context * ctx);
@@ -639,9 +667,11 @@ extern "C" {
     GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);
     GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);
     GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
     GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
 
 
-    GGML_API const char *         ggml_get_name(const struct ggml_tensor * tensor);
-    GGML_API struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name);
-    GGML_API struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char * fmt, ...);
+    GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
+
+    GGML_API const char *         ggml_get_name   (const struct ggml_tensor * tensor);
+    GGML_API struct ggml_tensor * ggml_set_name   (      struct ggml_tensor * tensor, const char * name);
+    GGML_API struct ggml_tensor * ggml_format_name(      struct ggml_tensor * tensor, const char * fmt, ...);
 
 
     //
     //
     // operations on tensors with backpropagation
     // operations on tensors with backpropagation
@@ -651,6 +681,11 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
             struct ggml_tensor  * a);
 
 
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_dup_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_add(
     GGML_API struct ggml_tensor * ggml_add(
             struct ggml_context * ctx,
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * a,
@@ -875,14 +910,17 @@ extern "C" {
 
 
     GGML_API struct ggml_tensor * ggml_rms_norm(
     GGML_API struct ggml_tensor * ggml_rms_norm(
             struct ggml_context * ctx,
             struct ggml_context * ctx,
-            struct ggml_tensor  * a);
+            struct ggml_tensor  * a,
+            float                 eps);
 
 
     GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
     GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
             struct ggml_context * ctx,
             struct ggml_context * ctx,
-            struct ggml_tensor  * a);
+            struct ggml_tensor  * a,
+            float                 eps);
 
 
     // a - x
     // a - x
     // b - dy
     // b - dy
+    // TODO: update with configurable eps
     GGML_API struct ggml_tensor * ggml_rms_norm_back(
     GGML_API struct ggml_tensor * ggml_rms_norm_back(
             struct ggml_context * ctx,
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * a,
@@ -974,11 +1012,22 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
             struct ggml_tensor  * b);
 
 
+    // a -> b, in-place, return view(b)
+    GGML_API struct ggml_tensor * ggml_cpy_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     // make contiguous
     // make contiguous
     GGML_API struct ggml_tensor * ggml_cont(
     GGML_API struct ggml_tensor * ggml_cont(
             struct ggml_context * ctx,
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
             struct ggml_tensor  * a);
 
 
+    // make contiguous, in-place
+    GGML_API struct ggml_tensor * ggml_cont_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // return view(a), b specifies the new shape
     // return view(a), b specifies the new shape
     // TODO: when we start computing gradient, make a copy instead of view
     // TODO: when we start computing gradient, make a copy instead of view
     GGML_API struct ggml_tensor * ggml_reshape(
     GGML_API struct ggml_tensor * ggml_reshape(
@@ -1154,9 +1203,9 @@ extern "C" {
             int                   n_past,
             int                   n_past,
             int                   n_dims,
             int                   n_dims,
             int                   mode,
             int                   mode,
+            int                   n_ctx,
             float                 freq_base,
             float                 freq_base,
-            float                 freq_scale,
-            int                   n_ctx);
+            float                 freq_scale);
 
 
     // rotary position embedding backward, i.e compute dx from dy
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
     // a - dy
@@ -1165,7 +1214,8 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * a,
             int                   n_past,
             int                   n_past,
             int                   n_dims,
             int                   n_dims,
-            int                   mode);
+            int                   mode,
+            int                   n_ctx);
 
 
     // alibi position embedding
     // alibi position embedding
     // in-place, returns view(a)
     // in-place, returns view(a)
@@ -1289,6 +1339,16 @@ extern "C" {
     typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
     typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
     typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
     typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
 
 
+    GGML_API struct ggml_tensor * ggml_unary(
+            struct ggml_context * ctx,
+             struct ggml_tensor * a,
+             enum ggml_unary_op op);
+
+    GGML_API struct ggml_tensor * ggml_unary_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        enum ggml_unary_op op);
+
     GGML_API struct ggml_tensor * ggml_map_unary_f32(
     GGML_API struct ggml_tensor * ggml_map_unary_f32(
             struct ggml_context        * ctx,
             struct ggml_context        * ctx,
             struct ggml_tensor         * a,
             struct ggml_tensor         * a,
@@ -1368,11 +1428,17 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_context * ctx,
             struct ggml_tensor  * tensor);
             struct ggml_tensor  * tensor);
 
 
+
     GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
     GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
 
 
     GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
     GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
     GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
     GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
 
 
+    // graph allocation in a context
+    GGML_API struct ggml_cgraph * ggml_new_graph        (struct ggml_context * ctx);
+    GGML_API struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor);
+    GGML_API size_t ggml_graph_overhead(void);
+
     // ggml_graph_plan() has to be called before ggml_graph_compute()
     // ggml_graph_plan() has to be called before ggml_graph_compute()
     // when plan.work_size > 0, caller must allocate memory for plan.work_data
     // when plan.work_size > 0, caller must allocate memory for plan.work_data
     GGML_API struct ggml_cplan ggml_graph_plan   (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
     GGML_API struct ggml_cplan ggml_graph_plan   (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);

+ 328 - 4
llama/k_quants.c

@@ -1,5 +1,5 @@
 /**
 /**
- * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
+ * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b
  *
  *
  * MIT License
  * MIT License
  *
  *
@@ -1692,6 +1692,62 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 
 
     *s = hsum_float_8(acc) + summs;
     *s = hsum_float_8(acc) + summs;
 
 
+#elif defined __AVX__
+
+    const __m128i m3 = _mm_set1_epi8(3);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint32_t ud, um;
+    const uint8_t * restrict db = (const uint8_t *)&ud;
+    const uint8_t * restrict mb = (const uint8_t *)&um;
+
+    float summs = 0;
+
+    // TODO: optimize this
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
+        ud = (sc[0] >> 0) & 0x0f0f0f0f;
+        um = (sc[0] >> 4) & 0x0f0f0f0f;
+
+        int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
+        summs += dmin * smin;
+
+        const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
+        const __m128i q2_0 = _mm_and_si128(q2bits, m3);
+        const __m128i q2_1 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
+        const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
+        const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m128i p0 = _mm_maddubs_epi16(q2_0, _mm256_extractf128_si256(q8_0, 0));
+        const __m128i p1 = _mm_maddubs_epi16(q2_1, _mm256_extractf128_si256(q8_0, 1));
+        const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0));
+        const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1));
+
+        const __m256i p_0 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0));
+        const __m256i p_1 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1));
+        const __m256i p_2 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2));
+        const __m256i p_3 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3));
+
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2)), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3)), acc);
+    }
+
+    *s = hsum_float_8(acc) + summs;
+
 #else
 #else
 
 
     float sumf = 0;
     float sumf = 0;
@@ -2321,6 +2377,93 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
 
     *s = hsum_float_8(acc);
     *s = hsum_float_8(acc);
 
 
+#elif defined __AVX__
+
+    const __m128i m3 = _mm_set1_epi8(3);
+    const __m128i m1 = _mm_set1_epi8(1);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint64_t aux64;
+
+    uint16_t aux16[2];
+    const int8_t * aux8 = (const int8_t *)aux16;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint16_t a = *(const uint16_t *)x[i].scales;
+        aux16[0] = a & 0x0f0f;
+        aux16[1] = (a >> 4) & 0x0f0f;
+
+        const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8);
+        const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8);
+        const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8);
+        const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8);
+
+        memcpy(&aux64, x[i].hmask, 8);
+
+        __m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
+        __m128i q3h_1 = _mm_srli_epi16(q3h_0, 2);
+        __m128i q3h_2 = _mm_srli_epi16(q3h_0, 4);
+        __m128i q3h_3 = _mm_srli_epi16(q3h_0, 6);
+        q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2);
+        q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2);
+        q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2);
+        q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2);
+
+        // load low 2 bits
+        const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
+
+        // prepare low and high bits
+        const __m128i q3l_0 = _mm_and_si128(q3bits, m3);
+        const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3);
+        const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3);
+        const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3);
+
+        // load Q8 quants
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm_maddubs_epi16,
+        // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
+        // and 2 if the high bit was set)
+        const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0));
+        const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1));
+        const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0));
+        const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1));
+
+        __m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0));
+        __m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1));
+        __m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0));
+        __m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1));
+
+        p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+        p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+        p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+        p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+
+        // multiply with scales
+        p16_0 = _mm_madd_epi16(scale_0, p16_0);
+        p16_1 = _mm_madd_epi16(scale_1, p16_1);
+        p16_2 = _mm_madd_epi16(scale_2, p16_2);
+        p16_3 = _mm_madd_epi16(scale_3, p16_3);
+
+        p16_0 = _mm_add_epi32(p16_0, p16_2);
+        p16_1 = _mm_add_epi32(p16_1, p16_3);
+        __m256i p16 = _mm256_set_m128i(p16_1, p16_0);
+
+        // multiply with block scale and accumulate
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
 #else
 #else
 
 
     int8_t  aux8[QK_K];
     int8_t  aux8[QK_K];
@@ -2807,6 +2950,60 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
 
     *s = hsum_float_8(acc) - summs;
     *s = hsum_float_8(acc) - summs;
 
 
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0;
+
+    uint16_t aux16[2];
+    const uint8_t * scales = (const uint8_t *)aux16;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
+        const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
+        const __m256 vd = _mm256_set1_ps(d);
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
+        const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0);
+        const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1);
+        const __m128i q4_0 = _mm_and_si128(q4bits_0, m4);
+        const __m128i q4_1 = _mm_and_si128(q4bits_1, m4);
+        const __m128i q4_2 = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4);
+        const __m128i q4_3 = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
+        const __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
+        const __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
+        const __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
+
+        const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0);
+        const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_set_m128i(p32_1, p32_0))), acc);
+
+        const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2);
+        const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_set_m128i(p32_3, p32_2))), acc);
+
+    }
+
+    *s = hsum_float_8(acc) - summs;
+
 #else
 #else
 
 
     uint8_t aux8[QK_K];
     uint8_t aux8[QK_K];
@@ -3321,10 +3518,66 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
 
     *s = hsum_float_8(acc);
     *s = hsum_float_8(acc);
 
 
-#else
+#elif defined __AVX__
 
 
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i mone  = _mm_set1_epi8(1);
 
 
-    uint8_t aux8[QK_K];
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+        const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
+
+        const __m128i scale_0 = _mm_set1_epi16(x[i].scales[0]);
+        const __m128i scale_1 = _mm_set1_epi16(x[i].scales[1]);
+        const __m128i scale_2 = _mm_set1_epi16(x[i].scales[2]);
+        const __m128i scale_3 = _mm_set1_epi16(x[i].scales[3]);
+
+        int64_t aux64;
+        memcpy(&aux64, x[i].qh, 8);
+        const __m128i haux128_0 = _mm_set_epi64x(aux64 >> 1, aux64);
+        const __m128i haux128_1 = _mm_srli_epi16(haux128_0, 2);
+
+        const __m128i q5h_0 = _mm_slli_epi16(_mm_andnot_si128(haux128_0, mone), 4);
+        const __m128i q5h_1 = _mm_slli_epi16(_mm_andnot_si128(haux128_1, mone), 4);
+        const __m128i q5h_2 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_0, 4), mone), 4);
+        const __m128i q5h_3 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_1, 4), mone), 4);
+
+        const __m128i q5l_0 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 0), m4);
+        const __m128i q5l_1 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 1), m4);
+        const __m128i q5l_2 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 0), 4), m4);
+        const __m128i q5l_3 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 1), 4), m4);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m128i p16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5l_0, _mm256_extractf128_si256(q8_0, 0)));
+        const __m128i p16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5l_1, _mm256_extractf128_si256(q8_0, 1)));
+        const __m128i p16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5l_2, _mm256_extractf128_si256(q8_1, 0)));
+        const __m128i p16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5l_3, _mm256_extractf128_si256(q8_1, 1)));
+        const __m128i s16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5h_0, _mm256_extractf128_si256(q8_0, 0)));
+        const __m128i s16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5h_1, _mm256_extractf128_si256(q8_0, 1)));
+        const __m128i s16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5h_2, _mm256_extractf128_si256(q8_1, 0)));
+        const __m128i s16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5h_3, _mm256_extractf128_si256(q8_1, 1)));
+
+        const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2));
+        const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3));
+
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_set_m128i(dot_1, dot_0))), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#else
+
+    int8_t aux8[QK_K];
     int16_t aux16[16];
     int16_t aux16[16];
     float   sums [8];
     float   sums [8];
     memset(sums, 0, 8*sizeof(float));
     memset(sums, 0, 8*sizeof(float));
@@ -3334,7 +3587,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
         const uint8_t * restrict q4 = x[i].qs;
         const uint8_t * restrict q4 = x[i].qs;
         const uint8_t * restrict hm = x[i].qh;
         const uint8_t * restrict hm = x[i].qh;
         const  int8_t * restrict q8 = y[i].qs;
         const  int8_t * restrict q8 = y[i].qs;
-        uint8_t * restrict a = aux8;
+        int8_t * restrict a = aux8;
         for (int l = 0; l < 32; ++l) {
         for (int l = 0; l < 32; ++l) {
             a[l+ 0] = q4[l] & 0xF;
             a[l+ 0] = q4[l] & 0xF;
             a[l+32] = q4[l]  >> 4;
             a[l+32] = q4[l]  >> 4;
@@ -3884,6 +4137,77 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
 
     *s = hsum_float_8(acc);
     *s = hsum_float_8(acc);
 
 
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i m2 = _mm_set1_epi8(3);
+    const __m128i m32s = _mm_set1_epi8(32);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
+        const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
+        const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
+        const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
+        const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
+
+        const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
+        const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
+
+        const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4);
+        const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4);
+        const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4);
+        const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4);
+
+        const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0);
+        const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1);
+        const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2);
+        const __m128i q4_3 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        __m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0));
+        __m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1));
+        __m128i q8s_2 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0));
+        __m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1));
+
+        __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
+        __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
+        __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
+        __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
+
+        p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+        p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+        p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+        p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+
+        p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
+        p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
+        p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
+        p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
+
+        sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+        sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
+
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(_mm256_set_m128i(sumi_1, sumi_0))), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
 #else
 #else
 
 
     int8_t  aux8[QK_K];
     int8_t  aux8[QK_K];

+ 1 - 1
llama/k_quants.h

@@ -1,5 +1,5 @@
 /**
 /**
- * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
+ * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b
  *
  *
  * MIT License
  * MIT License
  *
  *

+ 1 - 1
llama/llama-util.h

@@ -1,5 +1,5 @@
 /**
 /**
- * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
+ * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b
  *
  *
  * MIT License
  * MIT License
  *
  *

文件差異過大導致無法顯示
+ 513 - 120
llama/llama.cpp


+ 64 - 10
llama/llama.h

@@ -1,5 +1,5 @@
 /**
 /**
- * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
+ * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b
  *
  *
  * MIT License
  * MIT License
  *
  *
@@ -79,6 +79,10 @@
 #define LLAMA_SUPPORTS_GPU_OFFLOAD
 #define LLAMA_SUPPORTS_GPU_OFFLOAD
 #endif
 #endif
 
 
+#ifndef LLAMA_DEFAULT_RMS_EPS
+#define LLAMA_DEFAULT_RMS_EPS 5e-6f
+#endif
+
 #ifdef __cplusplus
 #ifdef __cplusplus
 extern "C" {
 extern "C" {
 #endif
 #endif
@@ -109,12 +113,15 @@ extern "C" {
     typedef void (*llama_progress_callback)(float progress, void *ctx);
     typedef void (*llama_progress_callback)(float progress, void *ctx);
 
 
    struct llama_context_params {
    struct llama_context_params {
-        uint32_t seed;                         // RNG seed, -1 for random
-        int32_t  n_ctx;                        // text context
-        int32_t  n_batch;                      // prompt processing batch size
-        int32_t  n_gpu_layers;                 // number of layers to store in VRAM
-        int32_t  main_gpu;                     // the GPU that is used for scratch and small tensors
-        float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs
+        uint32_t seed;         // RNG seed, -1 for random
+        int32_t  n_ctx;        // text context
+        int32_t  n_batch;      // prompt processing batch size
+        int32_t  n_gqa;        // grouped-query attention (TEMP - will be moved to model hparams)
+        float    rms_norm_eps; // rms norm epsilon (TEMP - will be moved to model hparams)
+        int32_t  n_gpu_layers; // number of layers to store in VRAM
+        int32_t  main_gpu;     // the GPU that is used for scratch and small tensors
+
+        const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
 
 
         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
         float    rope_freq_base;  // RoPE base frequency
         float    rope_freq_base;  // RoPE base frequency
@@ -165,6 +172,40 @@ extern "C" {
         bool quantize_output_tensor; // quantize output.weight
         bool quantize_output_tensor; // quantize output.weight
     } llama_model_quantize_params;
     } llama_model_quantize_params;
 
 
+    // grammar types
+    struct llama_grammar;
+
+    // grammar element type
+    enum llama_gretype {
+        // end of rule definition
+        LLAMA_GRETYPE_END            = 0,
+
+        // start of alternate definition for rule
+        LLAMA_GRETYPE_ALT            = 1,
+
+        // non-terminal element: reference to rule
+        LLAMA_GRETYPE_RULE_REF       = 2,
+
+        // terminal element: character (code point)
+        LLAMA_GRETYPE_CHAR           = 3,
+
+        // inverse char(s) ([^a], [^a-b] [^abc])
+        LLAMA_GRETYPE_CHAR_NOT       = 4,
+
+        // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
+        // be an inclusive range ([a-z])
+        LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
+
+        // modifies a preceding LLAMA_GRETYPE_CHAR or
+        // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
+        LLAMA_GRETYPE_CHAR_ALT       = 6,
+    };
+
+    typedef struct llama_grammar_element {
+        enum llama_gretype type;
+        uint32_t           value; // Unicode code point or rule ID
+    } llama_grammar_element;
+
     // performance timing information
     // performance timing information
     struct llama_timings {
     struct llama_timings {
         double t_start_ms;
         double t_start_ms;
@@ -357,6 +398,15 @@ extern "C" {
     LLAMA_API llama_token llama_token_eos();  // end-of-sentence
     LLAMA_API llama_token llama_token_eos();  // end-of-sentence
     LLAMA_API llama_token llama_token_nl();   // next-line
     LLAMA_API llama_token llama_token_nl();   // next-line
 
 
+    // Grammar
+    //
+    LLAMA_API struct llama_grammar * llama_grammar_init(
+            const llama_grammar_element ** rules,
+                                 size_t    n_rules,
+                                 size_t    start_rule_index);
+
+    LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
+
     // Sampling functions
     // Sampling functions
 
 
     /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
     /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
@@ -369,13 +419,11 @@ extern "C" {
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
     /// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
     /// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
     /// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
     /// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
-    /// @params smooth_factor Smooth factor between guidance logits and original logits. 1.0f means only use guidance logits. 0.0f means only original logits.
     LLAMA_API void llama_sample_classifier_free_guidance(
     LLAMA_API void llama_sample_classifier_free_guidance(
               struct llama_context * ctx,
               struct llama_context * ctx,
             llama_token_data_array * candidates,
             llama_token_data_array * candidates,
               struct llama_context * guidance_ctx,
               struct llama_context * guidance_ctx,
-                             float   scale,
-                             float   smooth_factor);
+                             float   scale);
 
 
     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
     LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
     LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
@@ -393,6 +441,9 @@ extern "C" {
     LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
     LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
     LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
     LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
 
 
+    /// @details Apply constraints from grammar
+    LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar);
+
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
     /// @param tau  The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
     /// @param tau  The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
@@ -414,6 +465,9 @@ extern "C" {
     /// @details Randomly selects a token from the candidates based on their probabilities.
     /// @details Randomly selects a token from the candidates based on their probabilities.
     LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
     LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
 
 
+    /// @details Accepts the sampled token into the grammar
+    LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
+
     // Performance information
     // Performance information
     LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
     LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
     LLAMA_API void llama_print_timings(struct llama_context * ctx);
     LLAMA_API void llama_print_timings(struct llama_context * ctx);

部分文件因文件數量過多而無法顯示