Merge pull request #131 from jmorganca/update-llama-cpp

update llama.cpp to e782c9e735f93ab4767ffc37462c523b73a17ddc
Michael Yang, 1 year ago
parent
commit
dde880290c
13 changed files with 1690 additions and 629 deletions
  1. llama/ggml-cuda.cu (+537 -62)
  2. llama/ggml-cuda.h (+1 -1)
  3. llama/ggml-metal.h (+1 -1)
  4. llama/ggml-metal.m (+45 -35)
  5. llama/ggml-metal.metal (+390 -329)
  6. llama/ggml.c (+440 -138)
  7. llama/ggml.h (+49 -2)
  8. llama/k_quants.c (+1 -1)
  9. llama/k_quants.h (+9 -1)
  10. llama/llama-util.h (+4 -4)
  11. llama/llama.cpp (+120 -53)
  12. llama/llama.h (+32 -2)
  13. llama/update-llama-cpp.sh (+61 -0)

File diff suppressed because it is too large
+ 537 - 62
llama/ggml-cuda.cu


+ 1 - 1
llama/ggml-cuda.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 5bf2a2771886ee86137e01dbc7492f78fb392066
+ * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-metal.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 5bf2a2771886ee86137e01dbc7492f78fb392066
+ * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
  *
  * MIT License
  *

+ 45 - 35
llama/ggml-metal.m

@@ -1,7 +1,7 @@
 // +build darwin
 
 /**
- * llama.cpp - git 5bf2a2771886ee86137e01dbc7492f78fb392066
+ * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
  *
  * MIT License
  *
@@ -722,8 +722,8 @@ void ggml_metal_graph_compute(
                                             GGML_ASSERT(ne02 == 1);
                                             GGML_ASSERT(ne12 == 1);
 
-                                            nth0 = 4;
-                                            nth1 = 16;
+                                            nth0 = 2;
+                                            nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
                                         } break;
                                     case GGML_TYPE_Q5_K:
@@ -731,8 +731,8 @@ void ggml_metal_graph_compute(
                                             GGML_ASSERT(ne02 == 1);
                                             GGML_ASSERT(ne12 == 1);
 
-                                            nth0 = 4;
-                                            nth1 = 16;
+                                            nth0 = 2;
+                                            nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
                                         } break;
                                     case GGML_TYPE_Q6_K:
@@ -740,8 +740,8 @@ void ggml_metal_graph_compute(
                                             GGML_ASSERT(ne02 == 1);
                                             GGML_ASSERT(ne12 == 1);
 
-                                            nth0 = 4;
-                                            nth1 = 16;
+                                            nth0 = 2;
+                                            nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
                                         } break;
                                     default:
@@ -767,15 +767,18 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:13];
                                 [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:14];
 
-                                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
-                                    [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
-                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+                                    src0t == GGML_TYPE_Q4_K) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
+                                else if (src0t == GGML_TYPE_Q5_K) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
+                                else if (src0t == GGML_TYPE_Q6_K) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q2_K ||
-                                         src0t == GGML_TYPE_Q3_K ||
-                                         src0t == GGML_TYPE_Q4_K ||
-                                         src0t == GGML_TYPE_Q5_K ||
-                                         src0t == GGML_TYPE_Q6_K) {
+                                         src0t == GGML_TYPE_Q3_K) {
                                     [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
                                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 } else {
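
The new grid sizes follow from how many rows each threadgroup now produces. In the reworked kernels (see ggml-metal.metal below), a threadgroup holds N_SIMDGROUP = 2 SIMD groups; for Q4_0/Q4_1/Q4_K each SIMD group accumulates N_DST = 4 rows, so one threadgroup covers 2 * 4 = 8 rows and the grid needs (ne01 + 7) / 8 threadgroups. Q5_K processes 2 rows per SIMD group (4 per threadgroup, hence (ne01 + 3) / 4), and Q6_K processes 1 (2 per threadgroup, hence (ne01 + 1) / 2).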
@@ -821,7 +824,7 @@ void ggml_metal_graph_compute(
 
                             const float eps = 1e-6f;
 
-                            const int nth = 256;
+                            const int nth = 512;
 
                             [encoder setComputePipelineState:ctx->pipeline_rms_norm];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -829,7 +832,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                             [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
                             [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+                            [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
 
                             const int64_t nrows = ggml_nrows(src0);
 
@@ -910,28 +913,35 @@ void ggml_metal_graph_compute(
 
                             const int n_past = ((int32_t *)(src1->data))[0];
 
+                            float freq_base;
+                            float freq_scale;
+                            memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
+                            memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
+
                             [encoder setComputePipelineState:ctx->pipeline_rope];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                            [encoder setBytes:&ne00   length:sizeof( int64_t) atIndex:2];
-                            [encoder setBytes:&ne01   length:sizeof( int64_t) atIndex:3];
-                            [encoder setBytes:&ne02   length:sizeof( int64_t) atIndex:4];
-                            [encoder setBytes:&ne03   length:sizeof( int64_t) atIndex:5];
-                            [encoder setBytes:&nb00   length:sizeof(uint64_t) atIndex:6];
-                            [encoder setBytes:&nb01   length:sizeof(uint64_t) atIndex:7];
-                            [encoder setBytes:&nb02   length:sizeof(uint64_t) atIndex:8];
-                            [encoder setBytes:&nb03   length:sizeof(uint64_t) atIndex:9];
-                            [encoder setBytes:&ne0    length:sizeof( int64_t) atIndex:10];
-                            [encoder setBytes:&ne1    length:sizeof( int64_t) atIndex:11];
-                            [encoder setBytes:&ne2    length:sizeof( int64_t) atIndex:12];
-                            [encoder setBytes:&ne3    length:sizeof( int64_t) atIndex:13];
-                            [encoder setBytes:&nb0    length:sizeof(uint64_t) atIndex:14];
-                            [encoder setBytes:&nb1    length:sizeof(uint64_t) atIndex:15];
-                            [encoder setBytes:&nb2    length:sizeof(uint64_t) atIndex:16];
-                            [encoder setBytes:&nb3    length:sizeof(uint64_t) atIndex:17];
-                            [encoder setBytes:&n_past length:sizeof(     int) atIndex:18];
-                            [encoder setBytes:&n_dims length:sizeof(     int) atIndex:19];
-                            [encoder setBytes:&mode   length:sizeof(     int) atIndex:20];
+                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
+                            [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:5];
+                            [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:6];
+                            [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:7];
+                            [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:8];
+                            [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:9];
+                            [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:10];
+                            [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:11];
+                            [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:12];
+                            [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:13];
+                            [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:14];
+                            [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:15];
+                            [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:16];
+                            [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:17];
+                            [encoder setBytes:&n_past  length:sizeof(     int) atIndex:18];
+                            [encoder setBytes:&n_dims  length:sizeof(     int) atIndex:19];
+                            [encoder setBytes:&mode    length:sizeof(     int) atIndex:20];
+                            [encoder setBytes:&freq_base  length:sizeof(float) atIndex:21];
+                            [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
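
For reference, a minimal sketch of how the host side might pack the parameter buffer (src1) that this kernel reads. The float slots 4 and 5 match the memcpy offsets above; the meaning of slots 1..3 is an assumption here, not taken from this diff.

    #include <stdint.h>
    #include <string.h>

    /* Hypothetical packing of the rope parameter buffer read above.
     * Slots 4/5 are bit-copied floats per the memcpy in the diff;
     * slots 1..3 are assumed for illustration. */
    static void pack_rope_params(int32_t p[6], int n_past, int n_dims, int mode,
                                 float freq_base, float freq_scale) {
        p[0] = n_past;                              /* read as ((int32_t *) src1->data)[0] */
        p[1] = n_dims;                              /* assumed slot */
        p[2] = mode;                                /* assumed slot */
        p[3] = 0;                                   /* assumed unused here */
        memcpy(p + 4, &freq_base,  sizeof(float));  /* slot 4 */
        memcpy(p + 5, &freq_scale, sizeof(float));  /* slot 5 */
    }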

+ 390 - 329
llama/ggml-metal.metal

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 5bf2a2771886ee86137e01dbc7492f78fb392066
+ * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
  *
  * MIT License
  *
@@ -357,26 +357,33 @@ kernel void kernel_rms_norm(
         threadgroup float  * sum [[threadgroup(0)]],
         uint tgpig[[threadgroup_position_in_grid]],
         uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
         uint   ntg[[threads_per_threadgroup]]) {
-    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
+    device const float * x_scalar = (device const float *) x;
+    float4 sumf=0;
+    float all_sum=0;
 
     // parallel sum
-    sum[tpitg] = 0.0f;
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        sum[tpitg] += x[i00] * x[i00];
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        sumf += x[i00] * x[i00];
+    }
+    all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
+    all_sum = simd_sum(all_sum);
+    if (tiisg == 0) {
+        sum[sgitg] = all_sum;
     }
 
-    // reduce
     threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg/2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            sum[tpitg] += sum[tpitg + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    // broadcast, simd group number is ntg / 32
+    for (int i = ntg / 32 / 2; i > 0; i /= 2) {
+       if (tpitg < i) {
+           sum[tpitg] += sum[tpitg + i];
+       }
     }
-
-    // broadcast
     if (tpitg == 0) {
+        for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
         sum[0] /= ne00;
     }
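
The rewritten kernel computes the same quantity as before, scale = 1/sqrt(mean + eps) with mean = (1/ne00) * sum_i x_i^2, but it reads x as float4, reduces each 32-wide SIMD group in registers with simd_sum, and passes only one partial per SIMD group through threadgroup memory; the scalar tail loops pick up the trailing ne00 mod 4 elements when ne00 is not a multiple of 4. This is also why ggml-metal.m above now allocates nth/32 * sizeof(float) of threadgroup memory, e.g. 16 floats for nth = 512 instead of 512.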
 
@@ -385,147 +392,127 @@ kernel void kernel_rms_norm(
     const float mean  = sum[0];
     const float scale = 1.0f/sqrt(mean + eps);
 
-    device float * y = dst + tgpig*ne00;
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+    device float4 * y = (device float4 *) (dst + tgpig*ne00);
+    device float * y_scalar = (device float *) y;
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         y[i00] = x[i00] * scale;
     }
+    if (tpitg == 0) {
+        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
+    }
 }
 
-kernel void kernel_mul_mat_q4_0_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne10,
-        constant   int64_t & ne0,
-        threadgroup float  * sum [[threadgroup(0)]],
-        uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpitg[[thread_position_in_threadgroup]],
-        uint2  tptg[[threads_per_threadgroup]]) {
-    const int nb = ne00/QK4_0;
+// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i])
+float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
+    float d = qb_curr->d;
+    float4 acc = 0.f;
+    device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
+    for (int i = 0; i < 16; i+=2) {
+        acc[0] += yl[i]      * (qs[i / 2] & 0x000F);
+        acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
+        acc[2] += yl[i +  1] * (qs[i / 2] & 0x0F00);
+        acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+    }
+    return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
+}
 
-    const int64_t r0 = tgpig.x;
-    const int64_t r1 = tgpig.y;
+// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i])
+float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
+    float d = qb_curr->d;
+    float m = qb_curr->m;
+    float4 acc = 0.f;
+    device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
+    for (int i = 0; i < 16; i+=2) {
+        acc[0] += yl[i]      * (qs[i / 2] & 0x000F);
+        acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
+        acc[2] += yl[i +  1] * (qs[i / 2] & 0x0F00);
+        acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+    }
+    return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
+}
 
-    device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
+// putting them in the kernel cause a significant performance penalty
+#define N_DST 4 // each SIMD group works on 4 rows
+#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
+template<typename block_q_type>
+void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
+                    int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
+                    uint2 tgpig, uint tiisg, uint sgitg) {
+    const int nb = ne00/QK4_0;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
     device const float      * y = (device const float      *) src1 + r1*ne10;
-
-    const int nth = tptg.x*tptg.y;
-    const int ith = tptg.y*tpitg.x + tpitg.y;
-
-    const int ix = tpitg.y/4;           // 0 or 1
-    const int iy = tpitg.y - 4*ix;      // 0...3
-
-    const int first = 4 * iy;
-
-    float sumf = 0;
-
-    for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
-
-        const float d = (float)x[i].d;
-
-        device const uint8_t * xl = x[i].qs + first;
-        device const float   * yl = y + i * QK4_0 + first;
-
-        float2 acc = {0.0f, 0.0f};
-
-        for (int j = 0; j < 4; ++j) {
-
-            acc[0] += yl[j] * (xl[j] & 0xF) + yl[j+16] * (xl[j] >> 4);
-            acc[1] += yl[j] + yl[j+16];
-
+    float4 y_curr[8];       // src1 vector cache
+    float sumf[N_DST]={0.f}, all_sum;
+    thread float * yl=(thread float *)y_curr;
+
+    // each thread in a SIMD group deals with 1 block.
+    for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
+        float sumy = 0;
+        for (int i = 0; i < QK4_0 / 4; i++) {
+            y_curr[i] = *((device float4  *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
+            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
         }
 
-        sumf += d * (acc[0] - 8.f*acc[1]);
+        for (int row = 0; row < N_DST; row++) {
+            sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
+        }
     }
 
-    sum[ith] = sumf;
+    // from now loads two rows every time and 16 blocks per row
+    int ir = tiisg / (N_SIMDWIDTH / 2);
+    int ib = tiisg % (N_SIMDWIDTH / 2);
+    for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
+        int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
+        float sumy = 0;
+        for (int i = 0; i < QK4_0 / 4; i++) {
+            y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
+            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
+        }
 
-    //
-    // Accumulate the sum from all threads in the threadgroup
-    //
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
+        for (int row = 0; row < N_DST; row+=2) {
+            if (nb_start + ib < nb) {
+                sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
+            }
+        }
     }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[r1*ne0 + r0] = sum[0];
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
+            dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
+        }
     }
 }
 
-kernel void kernel_mul_mat_q4_1_f32(
+kernel void kernel_mul_mat_q4_0_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne10,
         constant   int64_t & ne0,
-        threadgroup float  * sum [[threadgroup(0)]],
+        constant   int64_t & ne01[[buffer(4)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpitg[[thread_position_in_threadgroup]],
-        uint2  tptg[[threads_per_threadgroup]]) {
-    const int nb = ne00/QK4_1;
-
-    const int64_t r0 = tgpig.x;
-    const int64_t r1 = tgpig.y;
-
-    device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
-    device const float      * y = (device const float      *) src1 + r1*ne10;
-
-    const uint nth = tptg.x*tptg.y;
-    const uint ith = tptg.y*tpitg.x + tpitg.y;
-
-    const int ix = tpitg.y/4;           // 0 or 1
-    const int iy = tpitg.y - 4*ix;      // 0...3
-
-    const int first = 4 * iy;
-
-    float sumf = 0;
-
-    for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
-
-        const float d = (float)x[i].d;
-        const float m = (float)x[i].m;
-
-        device const uint8_t * xl = x[i].qs + first;
-        device const float   * yl = y + i * QK4_1 + first;
-
-        float2 acc = {0.0f, 0.0f};
-
-        for (int j = 0; j < 4; ++j) {
-
-            acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
-            acc[1] += yl[j+16] * (d * (xl[j] >>  4) + m);
-
-        }
-
-        sumf += acc[0] + acc[1];
-    }
-
-    sum[ith] = sumf;
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+}
 
-    //
-    // Accumulate the sum from all threads in the threadgroup
-    //
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[r1*ne0 + r0] = sum[0];
-    }
+kernel void kernel_mul_mat_q4_1_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        constant   int64_t & ne01[[buffer(4)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+     mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
 }
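
A scalar C reference may help decode the masked arithmetic in block_q_n_dot_y: each byte packs two 4-bit quants, a Q4_0 weight dequantizes as d * (q - 8), and the 1/16, 1/256, 1/4096 factors merely undo the bit positions of nibbles read through a uint16_t view. The sketch below is an equivalent plain form under that reading, not the shipped code.

    #include <stdint.h>

    /* Sketch: Q4_0 block dot product in scalar form. A block holds a scale d
     * and 32 4-bit quants packed two per byte; a weight dequantizes to d*(q-8). */
    static float q4_0_dot_ref(float d, const uint8_t qs[16], const float y[32]) {
        float sum = 0.0f;
        for (int j = 0; j < 16; ++j) {
            sum += y[j]      * d * (float)((qs[j] & 0x0F) - 8);  /* low nibbles  -> y[0..15]  */
            sum += y[j + 16] * d * (float)((qs[j] >> 4)   - 8);  /* high nibbles -> y[16..31] */
        }
        return sum;
    }

The kernel's d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) is the same sum with the -8 offset factored out through sumy.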
 
 kernel void kernel_mul_mat_f16_f32(
@@ -641,17 +628,19 @@ kernel void kernel_rope(
         constant       int & n_past,
         constant       int & n_dims,
         constant       int & mode,
+        constant     float & freq_base,
+        constant     float & freq_scale,
         uint3 tpig[[thread_position_in_grid]]) {
     const int64_t i3 = tpig[2];
     const int64_t i2 = tpig[1];
     const int64_t i1 = tpig[0];
 
     const bool is_neox = mode & 2;
-    const float theta_scale = pow(10000.0, -2.0f/n_dims);
+    const float theta_scale = pow(freq_base, -2.0f/n_dims);
 
     const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
 
-    float theta = (float)p;
+    float theta = freq_scale * (float)p;
 
     if (!is_neox) {
         for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
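
With the two new parameters the rotation angle for dimension pair i becomes theta_i = freq_scale * p * freq_base^(-2i/n_dims); freq_base = 10000.0 and freq_scale = 1.0 reproduce the old hard-coded behavior exactly, which is what the defaults added in llama.cpp below preserve.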
@@ -1489,6 +1478,7 @@ kernel void kernel_mul_mat_q3_K_f32(
 
 }
 
+#if QK_K == 256
 kernel void kernel_mul_mat_q4_K_f32(
         device const  void * src0,
         device const float * src1,
@@ -1496,131 +1486,180 @@ kernel void kernel_mul_mat_q4_K_f32(
         constant   int64_t & ne00,
         constant   int64_t & ne10,
         constant   int64_t & ne0,
-        threadgroup float  * sum [[threadgroup(0)]],
+        constant   int64_t & ne01[[buffer(4)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpitg[[thread_position_in_threadgroup]],
-        uint2  tptg[[threads_per_threadgroup]]) {
-
-    const int nb = ne00/QK_K;
-
-    const int64_t r0 = tgpig.x;
-    const int64_t r1 = tgpig.y;
-
-    const int nth = tptg.x*tptg.y;
-    const int ith = tptg.y*tpitg.x + tpitg.y;
-
-    device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
-    device const float     * yy = (device const float      *) src1 + r1*ne10;
-
-    float sumf = 0;
-
-#if QK_K == 256
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const uint16_t kmask1 = 0x3f3f;
     const uint16_t kmask2 = 0x0f0f;
     const uint16_t kmask3 = 0xc0c0;
 
-    const int tid = tpitg.y;   // 0...16
-    const int il  = tid/4;     // 0...3
-    const int ir  = tid - 4*il;// 0...3
-    const int n   = 4;
+    const int ix = tiisg/8;  // 0...3
+    const int it = tiisg%8;  // 0...7
+    const int im = it/4;     // 0 or 1
+    const int ir = it%4;     // 0...3
 
-    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
-    const int in = il%2;
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+    device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
+    device const float      * y = (device const float      *) src1 + r1*ne10;
+    float yl[16];
+    float yh[16];
+    float sumf[N_DST]={0.f}, all_sum;
 
-    const int l0 = n*(2*ir + in);
-    const int q_offset = 32*im + l0;
-    const int y_offset = 64*im + l0;
+    const int step = sizeof(block_q4_K) * nb / 2;
 
-    uchar2 sc1, sc2, sc3, sc4;
+    device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
 
-    for (int i = tpitg.x; i < nb; i += tptg.x) {
+    uint16_t sc16[4];
+    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
 
-        device const uint8_t * q1 = (x + i)->qs + q_offset;
-        device const uint8_t * q2 = q1 + 64;
-        device const float   * y1 = yy + i*QK_K + y_offset;
-        device const float   * y2 = y1 + 128;
+    for (int ib = ix; ib < nb; ib += 4) {
 
-        const float dall = (float)((x + i)->d);
-        const float dmin = (float)((x + i)->dmin);
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+0] = y4[i+  0]; sumy[0] += yl[i+0];
+            yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
+            yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
+            yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
+        }
 
-        device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
-        sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
-        sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
-        sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
-        sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
+        device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
+        device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
+        device const half     * dh = &x[ib].d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            sc16[0] = sc[0] & kmask1;
+            sc16[1] = sc[2] & kmask1;
+            sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
+            sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
+
+            device const uint16_t * q2 = q1 + 32;
+
+            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
+                acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
+                acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
+                acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
+                acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
+                acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
+                acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
+                acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
+            }
 
-        float4 s = {0.f, 0.f, 0.f, 0.f};
-        float smin = 0;
-        for (int l = 0; l < n; ++l) {
+            float dall = dh[0];
+            float dmin = dh[1];
+            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
+                                 (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
+                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
+                                 (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
+                         dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+            q1 += step;
+            sc += step;
+            dh += step;
+        }
 
-            s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4);
-            s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4);
-            smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+        y4 += 4 * QK_K;
+    }
 
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + first_row + row] = all_sum;
         }
-        sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
-
     }
+}
 #else
-    uint16_t aux16[2];
-    thread const uint8_t * scales = (thread const uint8_t *)aux16;
+kernel void kernel_mul_mat_q4_K_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        constant   int64_t & ne01[[buffer(4)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    const int il  = 4*tpitg.x;
+    const int ix = tiisg/4;  // 0...7
+    const int it = tiisg%4;  // 0...3
 
-    for (int i = tpitg.y; i < nb; i += tptg.y) {
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+    device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
+    device const float      * y = (device const float      *) src1 + r1*ne10;
+    float yl[8];
+    float yh[8];
+    float sumf[N_DST]={0.f}, all_sum;
 
-        device const uint8_t * q = x[i].qs + il;
-        device const float   * y = yy + i * QK_K + il;
+    const int step = sizeof(block_q4_K) * nb / 2;
 
-        const float d = (float)x[i].d[0];
-        const float m = (float)x[i].d[1];
+    device const float * y4 = y + ix * QK_K + 8 * it;
 
-        device const uint16_t * a = (device const uint16_t *)x[i].scales;
-        aux16[0] = a[0] & 0x0f0f;
-        aux16[1] = (a[0] >> 4) & 0x0f0f;
+    uint16_t sc16[4];
 
-        for (int l = 0; l < 4; ++l) {
-            sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
-                  + d * scales[1] * (y[l+32] * (q[l] >>  4) + y[l+48] * (q[l+16] >>  4)) - m * scales[3] * (y[l+32] + y[l+48]);
+    for (int ib = ix; ib < nb; ib += 8) {
+
+        float2 sumy = {0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i] = y4[i+ 0]; sumy[0] += yl[i];
+            yh[i] = y4[i+32]; sumy[1] += yh[i];
         }
-    }
-#endif
 
-    sum[ith] = sumf;
+        device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
+        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
+        device const half     * dh = x[ib].d;
 
-    //
-    // Accumulate the sum from all threads in the threadgroup
-    // This version is slightly faster than the commented out one below,
-    // which I copy-pasted from ggerganov's q4_0 dot product for metal.
-    //
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[r1*ne0 + r0] = sum[0];
-    }
+        for (int row = 0; row < N_DST; row++) {
+
+            sc16[0] = sc[0] & 0x000f;
+            sc16[1] = sc[0] & 0x0f00;
+            sc16[2] = sc[0] & 0x00f0;
+            sc16[3] = sc[0] & 0xf000;
+
+            float2 acc1 = {0.f, 0.f};
+            float2 acc2 = {0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
+                acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
+                acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
+                acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
+            }
+
+            float dall = dh[0];
+            float dmin = dh[1];
+            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
+                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
+                         dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
+
+            qs += step;
+            sc += step;
+            dh += step;
+        }
 
-    //// accumulate the sum from all threads in the threadgroup
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-    //for (uint i = nth/2; i > 0; i /= 2) {
-    //    if (ith < i) {
-    //        sum[ith] += sum[ith + i];
-    //    }
-    //    threadgroup_barrier(mem_flags::mem_threadgroup);
-    //}
+        y4 += 8 * QK_K;
+    }
 
-    //if (ith == 0) {
-    //    dst[r1*ne0 + r0] = sum[0];
-    //}
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + first_row + row] = all_sum;
+        }
+    }
 }
+#endif
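
A note on the stride in the new q4_K kernel: q1, sc and dh are 16-bit pointers (uint16_t and half), so step = sizeof(block_q4_K) * nb / 2 advances each of them by exactly nb blocks, i.e. from a block in one row to the same block position in the next of the N_DST rows a SIMD group accumulates.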
 
 kernel void kernel_mul_mat_q5_K_f32(
         device const  void * src0,
@@ -1629,39 +1668,39 @@ kernel void kernel_mul_mat_q5_K_f32(
         constant   int64_t & ne00,
         constant   int64_t & ne10,
         constant   int64_t & ne0,
-        threadgroup float  * sum [[threadgroup(0)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpitg[[thread_position_in_threadgroup]],
-        uint2  tptg[[threads_per_threadgroup]]) {
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+
+    device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb;
     device const float     * yy = (device const float      *) src1 + r1*ne10;
 
-    const int nth = tptg.x*tptg.y;
-    const int ith = tptg.y*tpitg.x + tpitg.y;
+    float sumf[2]={0.f};
 
-    float sumf = 0;
+    const int step = sizeof(block_q5_K) * nb;
 
 #if QK_K == 256
+#
+    float yl[16], yh[16];
 
     const uint16_t kmask1 = 0x3f3f;
     const uint16_t kmask2 = 0x0f0f;
     const uint16_t kmask3 = 0xc0c0;
 
-    const int tid = tpitg.y;   // 0...16
-    const int il  = tid/4;     // 0...3
-    const int ir  = tid - 4*il;// 0...3
-    const int n   = 4;
-
-    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
-    const int in = il%2;
+    const int tid = tiisg/4;
+    const int ix  = tiisg%4;
+    const int im  = tid/4;
+    const int ir  = tid%4;
+    const int n   = 8;
 
-    const int l0 = n*(2*ir + in);
+    const int l0 = n*ir;
     const int q_offset = 32*im + l0;
     const int y_offset = 64*im + l0;
 
@@ -1670,78 +1709,114 @@ kernel void kernel_mul_mat_q5_K_f32(
     const uint8_t hm3 = hm1 << 4;
     const uint8_t hm4 = hm2 << 4;
 
-    uchar2 sc1, sc2, sc3, sc4;
+    uint16_t sc16[4];
+    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
 
-    for (int i = tpitg.x; i < nb; i += tptg.x) {
+    device const float * y1 = yy + ix*QK_K + y_offset;
+
+    for (int i = ix; i < nb; i += 4) {
+
+        device const uint8_t * q1 = x[i].qs + q_offset;
+        device const uint8_t * qh = x[i].qh + l0;
+        device const half * dh = &x[i].d;
+        device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
 
-        device const uint8_t * q1 = (x + i)->qs + q_offset;
-        device const uint8_t * q2 = q1 + 64;
-        device const uint8_t * qh = (x + i)->qh + l0;
-        device const float   * y1 = yy + i*QK_K + y_offset;
-        device const float   * y2 = y1 + 128;
+        device const float * y2 = y1 + 128;
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < 8; ++l) {
+            yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
+            yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
+            yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
+            yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
+        }
 
-        const float dall = (float)((x + i)->d);
-        const float dmin = (float)((x + i)->dmin);
+        for (int row = 0; row < 2; ++row) {
 
-        device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
-        sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
-        sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
-        sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
-        sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
+            device const uint8_t * q2 = q1 + 64;
 
-        float4 s = {0.f, 0.f, 0.f, 0.f};
-        float smin = 0;
-        for (int l = 0; l < n; ++l) {
+            sc16[0] = a[0] & kmask1;
+            sc16[1] = a[2] & kmask1;
+            sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
+            sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
 
-            s[0] += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0));
-            s[1] += y1[l+32] * ((q1[l] >>  4) + (qh[l] & hm2 ? 16 : 0));
-            s[2] += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0));
-            s[3] += y2[l+32] * ((q2[l] >>  4) + (qh[l] & hm4 ? 16 : 0));
-            smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+            float4 acc = {0.f, 0.f, 0.f, 0.f};
+            for (int l = 0; l < n; ++l) {
+                uint8_t h = qh[l];
+                acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
+                acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
+                acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
+                acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
+            }
+            const float dall = dh[0];
+            const float dmin = dh[1];
+            sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
+                         dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+            q1 += step;
+            qh += step;
+            dh += step/2;
+            a  += step/2;
 
         }
-        sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
+
+        y1 += 4 * QK_K;
 
     }
 #else
-    const int il  = 4 * tpitg.x;  // 0, 4, 8, 12
-    const int im  = il/8;         // 0, 0, 1, 1
-    const int in  = il%8;         // 0, 4, 0, 4
+    float yl[8], yh[8];
 
-    for (int i = tpitg.y; i < nb; i += tptg.y) {
+    const int il = 4 * (tiisg/8);  // 0, 4, 8, 12
+    const int ix = tiisg%8;
+    const int im = il/8;         // 0, 0, 1, 1
+    const int in = il%8;         // 0, 4, 0, 4
 
-        const float d = (float)x[i].d;
+    device const float * y = yy + ix*QK_K + il;
+
+    for (int i = ix; i < nb; i += 8) {
+
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < 4; ++l) {
+            yl[l+0] = y[l+ 0];
+            yl[l+4] = y[l+16];
+            yh[l+0] = y[l+32];
+            yh[l+4] = y[l+48];
+        }
+
+        device const half * dh = &x[i].d;
         device const uint8_t * q = x[i].qs + il;
         device const uint8_t * h = x[i].qh + in;
         device const int8_t  * s = x[i].scales;
-        device const float   * y = yy + i*QK_K + il;
 
-        for (int l = 0; l < 4; ++l) {
-            const uint8_t hl = h[l] >> im;
-            sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
-                  + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
-                  + y[l+32] * d * s[2] * ((q[l+ 0] >>  4) - (hl & 0x10 ? 0 : 16))
-                  + y[l+48] * d * s[3] * ((q[l+16] >>  4) - (hl & 0x40 ? 0 : 16));
+        for (int row = 0; row < 2; ++row) {
+
+            const float d = dh[0];
+
+            float2 acc = {0.f, 0.f};
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t hl = h[l] >> im;
+                acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
+                        + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
+                acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
+                        + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
+            }
+            sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
+
+            q += step;
+            h += step;
+            s += step;
+            dh += step/2;
+
         }
+
+        y += 8 * QK_K;
     }
 #endif
-    sum[ith] = sumf;
 
-    //
-    // Accumulate the sum from all threads in the threadgroup
-    //
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[r1*ne0 + r0] = sum[0];
+    for (int row = 0; row < 2; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + first_row + row] = tot;
+        }
     }
 
 }
@@ -1753,10 +1828,9 @@ kernel void kernel_mul_mat_q6_K_f32(
         constant   int64_t & ne00,
         constant   int64_t & ne10,
         constant   int64_t & ne0,
-        threadgroup float  * sum [[threadgroup(0)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpitg[[thread_position_in_threadgroup]],
-        uint2  tptg[[threads_per_threadgroup]]) {
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const uint8_t kmask1 = 0x03;
     const uint8_t kmask2 = 0x0C;
@@ -1768,19 +1842,18 @@ kernel void kernel_mul_mat_q6_K_f32(
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
-    device const float     * yy = (device const float      *) src1 + r1*ne10;
+    const int row = 2 * r0 + sgitg;
 
-    const int nth = tptg.x*tptg.y;
-    const int ith = tptg.y*tpitg.x + tpitg.y;
+    device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
 
     float sumf = 0;
 
 #if QK_K == 256
-    // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
-    const int iqs  = 16 * tpitg.y;
-    const int ip   = iqs / 128;         // 0 or 1
-    const int il   = (iqs - 128*ip)/16; // 0...7
+    const int tid  = tiisg/2;
+    const int ix   = tiisg%2;
+    const int ip   = tid/8;         // 0 or 1
+    const int il   = tid%8;
     const int n    = 4;
     const int l0   = n*il;
     const int is   = 8*ip + l0/16;
@@ -1789,9 +1862,10 @@ kernel void kernel_mul_mat_q6_K_f32(
     const int q_offset_l = 64*ip + l0;
     const int q_offset_h = 32*ip + l0;
 
-    for (int i = tpitg.x; i < nb; i += tptg.x) {
+    for (int i = ix; i < nb; i += 2) {
 
-        device const uint8_t * ql = x[i].ql + q_offset_l;
+        device const uint8_t * q1 = x[i].ql + q_offset_l;
+        device const uint8_t * q2 = q1 + 32;
         device const uint8_t * qh = x[i].qh + q_offset_h;
         device const int8_t  * sc = x[i].scales + is;
 
@@ -1801,19 +1875,21 @@ kernel void kernel_mul_mat_q6_K_f32(
 
         float4 sums = {0.f, 0.f, 0.f, 0.f};
         for (int l = 0; l < n; ++l) {
-            sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
-            sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
-            sums[2] += y[l+64] * ((int8_t)((ql[l+ 0]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
-            sums[3] += y[l+96] * ((int8_t)((ql[l+32]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+            sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+            sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+            sums[2] += y[l+64] * ((int8_t)((q1[l]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
+            sums[3] += y[l+96] * ((int8_t)((q2[l]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
         }
 
         sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
 
     }
+
 #else
-    const int il  = 4*tpitg.x;    // 0, 4, 8, 12
+    const int ix  = tiisg/4;
+    const int il  = 4*(tiisg%4);
 
-    for (int i = tpitg.y; i < nb; i += tptg.y) {
+    for (int i = ix; i < nb; i += 8) {
         device const float * y = yy + i * QK_K + il;
         device const uint8_t * ql = x[i].ql + il;
         device const uint8_t * qh = x[i].qh + il;
@@ -1833,23 +1909,8 @@ kernel void kernel_mul_mat_q6_K_f32(
 
 #endif
 
-    sum[ith] = sumf;
-
-    //
-    // Accumulate the sum from all threads in the threadgroup
-    //
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+    const float tot = simd_sum(sumf);
+    if (tiisg == 0) {
+        dst[r1*ne0 + row] = tot;
     }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[r1*ne0 + r0] = sum[0];
-    }
-
 }

File diff suppressed because it is too large
+ 440 - 138
llama/ggml.c


+ 49 - 2
llama/ggml.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 5bf2a2771886ee86137e01dbc7492f78fb392066
+ * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
  *
  * MIT License
  *
@@ -227,8 +227,13 @@
 #define GGML_MAX_NAME          48
 #define GGML_DEFAULT_N_THREADS 4
 
+
+#define GGML_EXIT_SUCCESS 0
+#define GGML_EXIT_ABORTED 1
+
 #define GGML_UNUSED(x) (void)(x)
 
+
 #define GGML_ASSERT(x) \
     do { \
         if (!(x)) { \
@@ -389,6 +394,8 @@ extern "C" {
         GGML_OP_CLAMP,
         GGML_OP_CONV_1D,
         GGML_OP_CONV_2D,
+        GGML_OP_POOL_1D,
+        GGML_OP_POOL_2D,
 
         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_FF,
@@ -468,6 +475,10 @@ extern "C" {
 
         // the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
         int n_tasks[GGML_MAX_NODES];
+
+        // abort ggml_graph_compute when true
+        bool (*abort_callback)(void * data);
+        void * abort_callback_data;
     };
 
     // computation graph
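
A hedged sketch of wiring up the new abort hook; the field names and return codes come from this diff, while the callback body and surrounding setup are illustrative.

    #include <stdbool.h>
    #include "ggml.h"

    static volatile bool g_stop = false;   /* e.g. set from another thread or a signal handler */

    static bool my_abort_cb(void * data) {
        (void) data;
        return g_stop;                     /* returning true aborts ggml_graph_compute */
    }

    static int compute_with_abort(struct ggml_cgraph * cgraph) {
        struct ggml_cplan plan = ggml_graph_plan(cgraph, GGML_DEFAULT_N_THREADS);
        plan.abort_callback      = my_abort_cb;
        plan.abort_callback_data = NULL;
        /* caller must allocate plan.work_data when plan.work_size > 0, per the comment below */
        return ggml_graph_compute(cgraph, &plan);  /* GGML_EXIT_SUCCESS or GGML_EXIT_ABORTED */
    }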
@@ -1136,6 +1147,17 @@ extern "C" {
             int                   mode,
             int                   n_ctx);
 
+    // custom RoPE, in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past,
+            int                   n_dims,
+            int                   mode,
+            float                 freq_base,
+            float                 freq_scale,
+            int                   n_ctx);
+
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
     GGML_API struct ggml_tensor * ggml_rope_back(
@@ -1190,6 +1212,31 @@ extern "C" {
             int                   s,
             int                   d);
 
+    enum ggml_op_pool {
+        GGML_OP_POOL_MAX,
+        GGML_OP_POOL_AVG,
+        GGML_OP_POOL_COUNT,
+    };
+
+    GGML_API struct ggml_tensor* ggml_pool_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_op_pool     op,
+            int                   k0, // kernel size
+            int                   s0, // stride
+            int                   p0); // padding
+
+    GGML_API struct ggml_tensor* ggml_pool_2d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_op_pool     op,
+            int                   k0,
+            int                   k1,
+            int                   s0,
+            int                   s1,
+            int                   p0,
+            int                   p1);
+
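
A hedged usage sketch for the new pooling API, with the argument meanings taken from the comments above (the context and input tensor are assumed to exist):

    /* average-pool a 1-D tensor with kernel 2, stride 2, no padding (halves its length) */
    struct ggml_tensor * pooled = ggml_pool_1d(ctx, a, GGML_OP_POOL_AVG, 2, 2, 0);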
     GGML_API struct ggml_tensor * ggml_flash_attn(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
@@ -1329,7 +1376,7 @@ extern "C" {
     // ggml_graph_plan() has to be called before ggml_graph_compute()
     // when plan.work_size > 0, caller must allocate memory for plan.work_data
     GGML_API struct ggml_cplan ggml_graph_plan   (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
-    GGML_API              void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
+    GGML_API               int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
     GGML_API              void ggml_graph_reset  (struct ggml_cgraph * cgraph);
 
     // same as ggml_graph_compute() but the work data is allocated as a part of the context

+ 1 - 1
llama/k_quants.c

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 5bf2a2771886ee86137e01dbc7492f78fb392066
+ * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
  *
  * MIT License
  *

+ 9 - 1
llama/k_quants.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 5bf2a2771886ee86137e01dbc7492f78fb392066
+ * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
  *
  * MIT License
  *
@@ -41,6 +41,14 @@
 #define K_SCALE_SIZE 12
 #endif
 
+#ifndef static_assert
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
+#define static_assert(cond, msg) _Static_assert(cond, msg)
+#else
+#define static_assert(cond, msg) struct global_scope_noop_trick
+#endif
+#endif
+
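
On a pre-C11 compiler the fallback branch turns every use into a harmless repeated declaration rather than a real check; for example:

    /* under the fallback macro this expands to: struct global_scope_noop_trick; */
    static_assert(sizeof(uint16_t) == 2, "wrong uint16_t size");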
 //
 // Super-block quantization structures
 //

+ 4 - 4
llama/llama-util.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 5bf2a2771886ee86137e01dbc7492f78fb392066
+ * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
  *
  * MIT License
  *
@@ -201,13 +201,13 @@ struct llama_mmap {
     llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */, bool numa = false) {
         size = file->size;
         int fd = fileno(file->fp);
-        int flags = MAP_PRIVATE;
+        int flags = MAP_SHARED;
         // prefetch/readahead impairs performance on NUMA systems
         if (numa) { prefetch = 0; }
 #ifdef __linux__
         if (prefetch) { flags |= MAP_POPULATE; }
 #endif
-        addr = mmap(NULL, file->size, PROT_READ | PROT_WRITE, flags, fd, 0);
+        addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0);
         if (addr == MAP_FAILED) {
             throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
         }
@@ -249,7 +249,7 @@ struct llama_mmap {
             throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str()));
         }
 
-        addr = MapViewOfFile(hMapping, FILE_MAP_COPY, 0, 0, 0);
+        addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
         error = GetLastError();
         CloseHandle(hMapping);
 

+ 120 - 53
llama/llama.cpp

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 5bf2a2771886ee86137e01dbc7492f78fb392066
+ * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
  *
  * MIT License
  *
@@ -127,14 +127,15 @@ static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph *
 // memory sizes
 //
 
-static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0()
+static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0(int n_ctx)
 {
     static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,    256ull * MB },
-        { MODEL_7B,    512ull * MB },
-        { MODEL_13B,   512ull * MB },
-        { MODEL_30B,   512ull * MB },
-        { MODEL_65B,  1024ull * MB },
+        /* empirical scaling, still a guess */
+        { MODEL_3B,   ((size_t) n_ctx / 16ull + 128ull) * MB },
+        { MODEL_7B,   ((size_t) n_ctx / 16ull + 256ull) * MB },
+        { MODEL_13B,  ((size_t) n_ctx / 12ull + 256ull) * MB },
+        { MODEL_30B,  ((size_t) n_ctx / 10ull + 256ull) * MB },
+        { MODEL_65B,  ((size_t) n_ctx /  8ull + 512ull) * MB },
     };
     return k_sizes;
 }
@@ -166,14 +167,14 @@ static const std::map<e_model, size_t> & MEM_REQ_KV_SELF()
 
 // this is mostly needed for temporary mul_mat buffers to dequantize the data
 // not actually needed if BLAS is disabled
-static const std::map<e_model, size_t> & MEM_REQ_EVAL()
+static const std::map<e_model, size_t> & MEM_REQ_EVAL(int n_ctx)
 {
     static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,   512ull * MB },
-        { MODEL_7B,   768ull * MB },
-        { MODEL_13B, 1024ull * MB },
-        { MODEL_30B, 1280ull * MB },
-        { MODEL_65B, 1536ull * MB },
+        { MODEL_3B,  ((size_t) n_ctx / 256ull +  512ull) * MB },
+        { MODEL_7B,  ((size_t) n_ctx / 256ull +  768ull) * MB },
+        { MODEL_13B, ((size_t) n_ctx / 256ull + 1024ull) * MB },
+        { MODEL_30B, ((size_t) n_ctx / 256ull + 1280ull) * MB },
+        { MODEL_65B, ((size_t) n_ctx / 256ull + 1536ull) * MB },
     };
     return k_sizes;
 }
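
As a worked example of the new scaling: for MODEL_7B at n_ctx = 2048, MEM_REQ_SCRATCH0 becomes 2048/16 + 256 = 384 MB (versus the old flat 512 MB) and MEM_REQ_EVAL becomes 2048/256 + 768 = 776 MB, so both reservations now grow with the context instead of being flat guesses.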
@@ -215,6 +216,10 @@ struct llama_hparams {
     uint32_t n_head  = 32;
     uint32_t n_layer = 32;
     uint32_t n_rot   = 64;
+
+    float rope_freq_base  = 10000.0f;
+    float rope_freq_scale = 1.0f;
+
     enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16;
 
     bool operator!=(const llama_hparams & other) const {
@@ -329,7 +334,7 @@ struct llama_model {
 };
 
 struct llama_context {
-    llama_context(const llama_model & model, const llama_vocab & vocab) : model(model), vocab(vocab), t_load_us(model.t_load_us), t_start_us(model.t_start_us) {}
+    llama_context(const llama_model & model) : model(model), t_load_us(model.t_load_us), t_start_us(model.t_start_us) {}
 #ifdef GGML_USE_METAL
     ~llama_context() {
         if (ctx_metal) {
@@ -350,7 +355,6 @@ struct llama_context {
     int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
 
     const llama_model & model;
-    const llama_vocab & vocab;
 
     bool model_owner = false;
 
@@ -577,7 +581,9 @@ struct llama_file_loader {
             }
 
             // skip to the next multiple of 32 bytes
-            file.seek(-static_cast<ptrdiff_t>(file.tell()) & 31, SEEK_CUR);
+            if (file_version >= LLAMA_FILE_VERSION_GGJT_V1) {
+                file.seek(-static_cast<ptrdiff_t>(file.tell()) & 31, SEEK_CUR);
+            }
 
             tensor.file_off = file.tell();
             tensor.name = name;
@@ -674,7 +680,7 @@ struct llama_model_loader {
         *ctx_size_p = *mmapped_size_p = 0;
         for (const llama_load_tensor & lt : tensors_map.tensors) {
             *ctx_size_p += sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE;
-            *(use_mmap ? mmapped_size_p : ctx_size_p) += lt.size;
+            *(use_mmap ? mmapped_size_p : ctx_size_p) += lt.size + 16;
         }
     }
 
@@ -870,6 +876,8 @@ struct llama_context_params llama_context_default_params() {
         /*.gpu_layers                  =*/ 0,
         /*.main_gpu                    =*/ 0,
         /*.tensor_split                =*/ {0},
+        /*.rope_freq_base              =*/ 10000.0f,
+        /*.rope_freq_scale             =*/ 1.0f,
         /*.progress_callback           =*/ nullptr,
         /*.progress_callback_user_data =*/ nullptr,
         /*.low_vram                    =*/ false,
@@ -895,6 +903,10 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
     return result;
 }
 
+int llama_max_devices() {
+    return LLAMA_MAX_DEVICES;
+}
+
 bool llama_mmap_supported() {
     return llama_mmap::SUPPORTED;
 }
@@ -993,6 +1005,8 @@ static void llama_model_load_internal(
         int n_gpu_layers,
         int main_gpu,
         const float * tensor_split,
+        float rope_freq_base,
+        float rope_freq_scale,
         bool low_vram,
         ggml_type memory_type,
         bool use_mmap,
@@ -1027,22 +1041,27 @@ static void llama_model_load_internal(
         }
 
         hparams.n_ctx = n_ctx;
+
+        hparams.rope_freq_base  = rope_freq_base;
+        hparams.rope_freq_scale = rope_freq_scale;
     }
 
     const uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
 
     {
-        fprintf(stderr, "%s: format     = %s\n",  __func__, llama_file_version_name(file_version));
-        fprintf(stderr, "%s: n_vocab    = %u\n",  __func__, hparams.n_vocab);
-        fprintf(stderr, "%s: n_ctx      = %u\n",  __func__, hparams.n_ctx);
-        fprintf(stderr, "%s: n_embd     = %u\n",  __func__, hparams.n_embd);
-        fprintf(stderr, "%s: n_mult     = %u\n",  __func__, hparams.n_mult);
-        fprintf(stderr, "%s: n_head     = %u\n",  __func__, hparams.n_head);
-        fprintf(stderr, "%s: n_layer    = %u\n",  __func__, hparams.n_layer);
-        fprintf(stderr, "%s: n_rot      = %u\n",  __func__, hparams.n_rot);
+        fprintf(stderr, "%s: format     = %s\n",   __func__, llama_file_version_name(file_version));
+        fprintf(stderr, "%s: n_vocab    = %u\n",   __func__, hparams.n_vocab);
+        fprintf(stderr, "%s: n_ctx      = %u\n",   __func__, hparams.n_ctx);
+        fprintf(stderr, "%s: n_embd     = %u\n",   __func__, hparams.n_embd);
+        fprintf(stderr, "%s: n_mult     = %u\n",   __func__, hparams.n_mult);
+        fprintf(stderr, "%s: n_head     = %u\n",   __func__, hparams.n_head);
+        fprintf(stderr, "%s: n_layer    = %u\n",   __func__, hparams.n_layer);
+        fprintf(stderr, "%s: n_rot      = %u\n",   __func__, hparams.n_rot);
+        fprintf(stderr, "%s: freq_base  = %.1f\n", __func__, hparams.rope_freq_base);
+        fprintf(stderr, "%s: freq_scale = %g\n",   __func__, hparams.rope_freq_scale);
         fprintf(stderr, "%s: ftype      = %u (%s)\n", __func__, hparams.ftype, llama_ftype_name(hparams.ftype));
-        fprintf(stderr, "%s: n_ff       = %u\n",  __func__, n_ff);
-        fprintf(stderr, "%s: model size = %s\n",  __func__, llama_model_type_name(model.type));
+        fprintf(stderr, "%s: n_ff       = %u\n",   __func__, n_ff);
+        fprintf(stderr, "%s: model size = %s\n",   __func__, llama_model_type_name(model.type));
     }
 
     if (file_version < LLAMA_FILE_VERSION_GGJT_V2) {
@@ -1191,9 +1210,9 @@ static void llama_model_load_internal(
         const size_t mem_required =
             ctx_size +
             mmapped_size - vram_weights + // weights in VRAM not in memory
-            MEM_REQ_SCRATCH0().at(model.type) +
+            MEM_REQ_SCRATCH0(hparams.n_ctx).at(model.type) +
             MEM_REQ_SCRATCH1().at(model.type) +
-            MEM_REQ_EVAL().at    (model.type);
+            MEM_REQ_EVAL(hparams.n_ctx).at(model.type);
 
         // this is the memory required by one llama_state
         const size_t mem_required_state =
@@ -1297,6 +1316,8 @@ static bool llama_model_load(
         int n_gpu_layers,
         int main_gpu,
         float * tensor_split,
+        float rope_freq_base,
+        float rope_freq_scale,
         bool low_vram,
         ggml_type memory_type,
         bool use_mmap,
@@ -1305,7 +1326,7 @@ static bool llama_model_load(
         llama_progress_callback progress_callback,
         void *progress_callback_user_data) {
     try {
-        llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, low_vram, memory_type,
+        llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type,
                                   use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data);
         return true;
     } catch (const std::exception & err) {
@@ -1357,6 +1378,9 @@ static bool llama_eval_internal(
     const int n_rot        = hparams.n_embd/hparams.n_head;
     const int n_gpu_layers = model.n_gpu_layers;
 
+    const float freq_base  = hparams.rope_freq_base;
+    const float freq_scale = hparams.rope_freq_scale;
+
     auto & mem_per_token = lctx.mem_per_token;
     auto & buf_compute   = lctx.buf_compute;
 
@@ -1454,11 +1478,11 @@ static bool llama_eval_internal(
             offload_func_kq(tmpq);
             ggml_set_name(tmpq, "tmpq");
 
-            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
+            struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, freq_base, freq_scale, 0);
             offload_func_kq(Kcur);
             ggml_set_name(Kcur, "Kcur");
 
-            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
+            struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, freq_base, freq_scale, 0);
             offload_func_kq(Qcur);
             ggml_set_name(Qcur, "Qcur");
 
@@ -2032,9 +2056,18 @@ void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array *
     }
 
     // Normalize the second derivatives
-    float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
-    for (float & value : second_derivatives) {
-        value /= second_derivatives_sum;
+    {
+        const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
+
+        if (second_derivatives_sum > 1e-6f) {
+            for (float & value : second_derivatives) {
+                value /= second_derivatives_sum;
+            }
+        } else {
+            for (float & value : second_derivatives) {
+                value = 1.0f / second_derivatives.size();
+            }
+        }
     }
 
     float cum_sum = 0.0f;
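
The tail-free normalization previously divided by the raw sum unconditionally; a degenerate distribution could make that sum (near) zero and produce NaNs, so the new code falls back to uniform weights. The guarded pattern in isolation:

#include <numeric>
#include <vector>

// Normalize in place; fall back to a uniform distribution when the sum is
// too small to divide by safely.
void normalize_or_uniform(std::vector<float> & v) {
    const float sum = std::accumulate(v.begin(), v.end(), 0.0f);
    if (sum > 1e-6f) {
        for (float & x : v) x /= sum;
    } else {
        for (float & x : v) x = 1.0f / v.size();
    }
}
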
@@ -2213,7 +2246,7 @@ void llama_sample_classifier_free_guidance(
           struct llama_context * guidance_ctx,
                          float   scale,
                          float   smooth_factor) {
-    int64_t t_start_sample_us = t_start_sample_us = ggml_time_us();
+    int64_t t_start_sample_us = ggml_time_us();
 
     assert(ctx);
     auto n_vocab = llama_n_vocab(ctx);
@@ -2701,8 +2734,9 @@ struct llama_model * llama_load_model_from_file(
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
     if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gpu_layers,
-                params.main_gpu, params.tensor_split, params.low_vram, memory_type, params.use_mmap, params.use_mlock,
-                params.vocab_only, params.progress_callback, params.progress_callback_user_data)) {
+                params.main_gpu, params.tensor_split, params.rope_freq_base, params.rope_freq_scale, params.low_vram,
+                memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback,
+                params.progress_callback_user_data)) {
         delete model;
         fprintf(stderr, "%s: failed to load model\n", __func__);
         return nullptr;
@@ -2723,7 +2757,7 @@ struct llama_context * llama_new_context_with_model(
         return nullptr;
     }
 
-    llama_context * ctx = new llama_context(*model, model->vocab);
+    llama_context * ctx = new llama_context(*model);
 
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
@@ -2777,9 +2811,9 @@ struct llama_context * llama_new_context_with_model(
             ctx->embedding.resize(hparams.n_embd);
         }
 
-        ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type));
+        ctx->buf_compute.resize(MEM_REQ_EVAL(hparams.n_ctx).at(ctx->model.type));
 
-        ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0().at(ctx->model.type));
+        ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type));
         ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));
     }
 
@@ -3561,13 +3595,13 @@ int llama_eval_export(struct llama_context * ctx, const char * fname) {
     return 0;
 }
 
-int llama_tokenize(
-        struct llama_context * ctx,
+int llama_tokenize_with_model(
+    const struct llama_model * model,
                   const char * text,
                  llama_token * tokens,
                          int   n_max_tokens,
                         bool   add_bos) {
-    auto res = llama_tokenize(ctx->vocab, text, add_bos);
+    auto res = llama_tokenize(model->vocab, text, add_bos);
 
     if (n_max_tokens < (int) res.size()) {
         fprintf(stderr, "%s: too many tokens\n", __func__);
@@ -3581,8 +3615,29 @@ int llama_tokenize(
     return res.size();
 }
 
+int llama_tokenize(
+        struct llama_context * ctx,
+                  const char * text,
+                 llama_token * tokens,
+                         int   n_max_tokens,
+                        bool   add_bos) {
+    return llama_tokenize_with_model(&ctx->model, text, tokens, n_max_tokens, add_bos);
+}
+
+int llama_n_vocab_from_model(const struct llama_model * model) {
+    return model->vocab.id_to_token.size();
+}
+
+int llama_n_ctx_from_model(const struct llama_model * model) {
+    return model->hparams.n_ctx;
+}
+
+int llama_n_embd_from_model(const struct llama_model * model) {
+    return model->hparams.n_embd;
+}
+
 int llama_n_vocab(const struct llama_context * ctx) {
-    return ctx->vocab.id_to_token.size();
+    return ctx->model.vocab.id_to_token.size();
 }
 
 int llama_n_ctx(const struct llama_context * ctx) {
@@ -3593,19 +3648,27 @@ int llama_n_embd(const struct llama_context * ctx) {
     return ctx->model.hparams.n_embd;
 }
 
-int llama_get_vocab(
-        const struct llama_context * ctx,
+int llama_get_vocab_from_model(
+        const struct llama_model * model,
         const char * * strings,
         float  * scores,
         int capacity) {
-    int n = std::min(capacity, (int) ctx->vocab.id_to_token.size());
+    int n = std::min(capacity, (int) model->vocab.id_to_token.size());
     for (int i = 0; i<n; ++i) {
-        strings[i] = ctx->vocab.id_to_token[i].tok.c_str();
-        scores[i]  = ctx->vocab.id_to_token[i].score;
+        strings[i] = model->vocab.id_to_token[i].tok.c_str();
+        scores[i]  = model->vocab.id_to_token[i].score;
     }
     return n;
 }
 
+int llama_get_vocab(
+        const struct llama_context * ctx,
+        const char * * strings,
+        float  * scores,
+        int capacity) {
+    return llama_get_vocab_from_model(&ctx->model, strings, scores, capacity);
+}
+
 float * llama_get_logits(struct llama_context * ctx) {
     return ctx->logits.data();
 }
@@ -3614,12 +3677,16 @@ float * llama_get_embeddings(struct llama_context * ctx) {
     return ctx->embedding.data();
 }
 
-const char * llama_token_to_str(const struct llama_context * ctx, llama_token token) {
-    if (token >= llama_n_vocab(ctx)) {
+const char * llama_token_to_str_with_model(const struct llama_model * model, llama_token token) {
+    if (token >= llama_n_vocab_from_model(model)) {
         return nullptr;
     }
 
-    return ctx->vocab.id_to_token[token].tok.c_str();
+    return model->vocab.id_to_token[token].tok.c_str();
+}
+
+const char * llama_token_to_str(const struct llama_context * ctx, llama_token token) {
+    return llama_token_to_str_with_model(&ctx->model, token);
 }
 
 llama_token llama_token_bos() {
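
With vocab access moved onto llama_model, tokenization and token-to-string lookups no longer require a full llama_context. A hedged sketch of the new *_with_model entry points (the model path is a placeholder, and error handling is kept minimal):

#include "llama.h"
#include <cstdio>
#include <vector>

int main() {
    struct llama_context_params params = llama_context_default_params();
    struct llama_model * model =
        llama_load_model_from_file("models/7B/ggml-model-q4_0.bin", params);
    if (!model) return 1;

    std::vector<llama_token> tokens(64);
    const int n = llama_tokenize_with_model(
        model, "Hello world", tokens.data(), (int) tokens.size(), /*add_bos=*/true);

    for (int i = 0; i < n; ++i) {
        printf("%d -> %s\n", tokens[i], llama_token_to_str_with_model(model, tokens[i]));
    }

    llama_free_model(model);
    return 0;
}
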

+ 32 - 2
llama/llama.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 5bf2a2771886ee86137e01dbc7492f78fb392066
+ * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
  *
  * MIT License
  *
@@ -115,6 +115,11 @@ extern "C" {
         int32_t  n_gpu_layers;                 // number of layers to store in VRAM
         int32_t  main_gpu;                     // the GPU that is used for scratch and small tensors
         float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs
+
+        // ref: https://github.com/ggerganov/llama.cpp/pull/2054
+        float    rope_freq_base;  // RoPE base frequency
+        float    rope_freq_scale; // RoPE frequency scaling factor
+
         // called with a progress value between 0 and 1, pass NULL to disable
         llama_progress_callback progress_callback;
         // context pointer passed to the progress callback
@@ -174,6 +179,8 @@ extern "C" {
         int32_t n_eval;
     };
 
+    LLAMA_API int llama_max_devices();
+
     LLAMA_API struct llama_context_params llama_context_default_params();
     LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params();
 
@@ -296,10 +303,21 @@ extern "C" {
                              int   n_max_tokens,
                             bool   add_bos);
 
+    LLAMA_API int llama_tokenize_with_model(
+        const struct llama_model * model,
+                      const char * text,
+                     llama_token * tokens,
+                             int   n_max_tokens,
+                            bool   add_bos);
+
     LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
     LLAMA_API int llama_n_ctx  (const struct llama_context * ctx);
     LLAMA_API int llama_n_embd (const struct llama_context * ctx);
 
+    LLAMA_API int llama_n_vocab_from_model(const struct llama_model * model);
+    LLAMA_API int llama_n_ctx_from_model  (const struct llama_model * model);
+    LLAMA_API int llama_n_embd_from_model (const struct llama_model * model);
+
     // Get the vocabulary as output parameters.
     // Returns number of results.
     LLAMA_API int llama_get_vocab(
@@ -308,6 +326,12 @@ extern "C" {
                                  float * scores,
                                    int   capacity);
 
+    LLAMA_API int llama_get_vocab_from_model(
+              const struct llama_model * model,
+                          const char * * strings,
+                                 float * scores,
+                                   int   capacity);
+
     // Token logits obtained from the last call to llama_eval()
     // The logits for the last token are stored in the last row
     // Can be mutated in order to change the probabilities of the next token
@@ -320,7 +344,13 @@ extern "C" {
     LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
 
     // Token Id -> String. Uses the vocabulary in the provided context
-    LLAMA_API const char * llama_token_to_str(const struct llama_context * ctx, llama_token token);
+    LLAMA_API const char * llama_token_to_str(
+            const struct llama_context * ctx,
+                           llama_token   token);
+
+    LLAMA_API const char * llama_token_to_str_with_model(
+              const struct llama_model * model,
+                           llama_token   token);
 
     // Special tokens
     LLAMA_API llama_token llama_token_bos();  // beginning-of-sentence

+ 61 - 0
llama/update-llama-cpp.sh

@@ -0,0 +1,61 @@
+#!/bin/sh
+
+set -eu
+
+
+status() { echo >&2 ">>> $*"; }
+error() { status "ERROR $*"; }
+usage() {
+    echo "usage: $(basename $0) /path/to/repo"
+    exit 1
+}
+
+OUT=$(dirname $0)
+while getopts "hC:" OPTION; do
+    case $OPTION in
+        C) OUT=$OPTARG ;;
+        *) usage ;;
+    esac
+done
+
+shift $(( $OPTIND - 1 ))
+[ $# -eq 1 ] || usage
+
+status "updating source..."
+cp -a "$1"/*.{c,h,cpp,m,metal,cu} "$OUT"
+
+status "removing incompatible files..."
+rm -f "$OUT"/build-info.h
+rm -f "$OUT"/ggml-{mpi,opencl}.*
+
+SHA1=$(git -C $1 rev-parse @)
+
+LICENSE=$(mktemp)
+cleanup() {
+    rm -f $LICENSE
+}
+trap cleanup 0
+
+cat <<EOF | sed 's/ *$//' >$LICENSE
+/**
+ * llama.cpp - git $SHA1
+ *
+$(sed 's/^/ * /' <$1/LICENSE)
+ */
+
+EOF
+
+for f in $OUT/*.{c,h,cpp,m,metal,cu}; do
+    TMP=$(mktemp)
+    status "updating license: $f"
+    cat $LICENSE $f >$TMP
+    mv $TMP $f
+done
+
+status "touching up MacOS files..."
+TMP=$(mktemp)
+{
+    echo "// +build darwin"
+    echo
+} | cat - $OUT/ggml-metal.m >$TMP
+mv $TMP $OUT/ggml-metal.m

Some files were not shown because too many files changed in this diff