123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115 |
- /**
- * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
- *
- * MIT License
- *
- * Copyright (c) 2023-2024 The ggml authors
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
- #include "common.cuh"
- #include "wkv6.cuh"
// Computes the RWKV6 wkv recurrence in f32. Each thread block handles one (sequence, head) pair
// and each thread handles one channel of that head.
//
// Expected launch configuration (see ggml_cuda_op_rwkv_wkv6 in this file):
//   gridDim.x  == B * H                 blockIdx.x is decoded as sequence * H + head
//   blockDim.x == CUDA_WKV_BLOCK_SIZE   one thread per channel of the head (head size)
// Static shared memory: 4 * CUDA_WKV_BLOCK_SIZE floats.
//
// Arguments:
//   B    number of sequences; must be > 0 because T / B is evaluated
//   T    total number of tokens; every sequence covers T / B consecutive tokens
//   C    number of channels per token; the indexing assumes C == H * CUDA_WKV_BLOCK_SIZE
//   H    number of heads
//   k, v, r, td
//        per-token inputs, read at [token * C + head * head_size + channel]
//   tf   per-head input, read at [head * head_size + channel]
//   s    incoming states, C * head_size floats per sequence, read at
//        [sequence * C * head_size + head * head_size * head_size + i * head_size + channel]
//   dst  output buffer: T * C result values (same indexing as k/v/r/td) followed by the
//        updated states (same indexing as s, shifted by T * C)
//
// Per token, the thread owning channel c evaluates for every j in [0, head_size):
//   kv        = k[j] * v[c]
//   y        += r[j] * (tf[j] * kv + state[j])
//   state[j]  = state[j] * td[j] + kv
// and stores y as the result for (token, head, c).
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
    const int tid = threadIdx.x;
    const int bid = blockIdx.x;

    const int head_size    = CUDA_WKV_BLOCK_SIZE;
    const int batch_i      = bid / H;
    const int head_i       = bid % H;
    const int state_size   = C * head_size;
    const int n_seq_tokens = T / B;

    static_assert(CUDA_WKV_BLOCK_SIZE % 4 == 0, "head size must be a multiple of 4 for the float4 accesses");

    // These arrays are accessed four floats at a time through float4, which requires 16-byte
    // alignment; a plain float array only guarantees 4-byte alignment, so request it explicitly.
    __align__(16) float state[head_size];
    __shared__ __align__(16) float _k[head_size];
    __shared__ __align__(16) float _r[head_size];
    __shared__ __align__(16) float _tf[head_size];
    __shared__ __align__(16) float _td[head_size];

    // Offset of element (i = 0, channel = tid) of this block's state; consecutive i are
    // head_size floats apart.
    const int state_base = batch_i * state_size + head_i * head_size * head_size + tid;

    #pragma unroll
    for (int i = 0; i < head_size; i++) {
        state[i] = s[state_base + i * head_size];
    }

    // tf does not depend on the token, so it is staged into shared memory once.
    _tf[tid] = tf[head_i * head_size + tid];
    __syncthreads();

    // Walk over the tokens of this sequence. The bounds of all threads differ only by tid, so
    // every thread of the block runs the same number of iterations and the barriers inside the
    // loop are reached uniformly.
    const int t_begin = batch_i * n_seq_tokens * C + head_i * head_size + tid;
    const int t_end   = (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;

    for (int t = t_begin; t < t_end; t += C) {
        // Wait until every thread has finished reading the previous token's shared values
        // before they are overwritten.
        __syncthreads();
        _k[tid]  = k[t];
        _r[tid]  = r[t];
        _td[tid] = td[t];
        // Make the freshly staged values visible to the whole block.
        __syncthreads();

        const float _v = v[t];
        float y = 0.0f;

        for (int j = 0; j < head_size; j += 4) {
            const float4 & k4  = *reinterpret_cast<const float4 *>(&_k[j]);
            const float4 & r4  = *reinterpret_cast<const float4 *>(&_r[j]);
            const float4 & tf4 = *reinterpret_cast<const float4 *>(&_tf[j]);
            const float4 & td4 = *reinterpret_cast<const float4 *>(&_td[j]);
            float4       & s4  = *reinterpret_cast<float4 *>(&state[j]);

            float4 kv;
            kv.x = k4.x * _v;
            kv.y = k4.y * _v;
            kv.z = k4.z * _v;
            kv.w = k4.w * _v;

            // The output uses the state from before this token's update.
            y += r4.x * (tf4.x * kv.x + s4.x);
            y += r4.y * (tf4.y * kv.y + s4.y);
            y += r4.z * (tf4.z * kv.z + s4.z);
            y += r4.w * (tf4.w * kv.w + s4.w);

            s4.x = s4.x * td4.x + kv.x;
            s4.y = s4.y * td4.y + kv.y;
            s4.z = s4.z * td4.z + kv.z;
            s4.w = s4.w * td4.w + kv.w;
        }

        dst[t] = y;
    }

    // The updated states are stored behind the T * C result values.
    #pragma unroll
    for (int i = 0; i < head_size; i++) {
        dst[T * C + state_base + i * head_size] = state[i];
    }
}
// Host-side launcher for the RWKV6 wkv kernel.
//
// Reads the six source tensors of dst as f32 buffers (src[0..5] are passed to the kernel as
// k, v, r, tf, td and s respectively), derives the problem sizes from the tensor shapes and
// enqueues rwkv_wkv_f32 on the stream of ctx with one block per (B, H) pair and C / H threads
// per block. The launch is asynchronous; no error check or synchronization is performed here.
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    // Problem sizes as consumed by the kernel.
    const int64_t B = dst->src[5]->ne[1];
    const int64_t T = dst->src[0]->ne[3];
    const int64_t C = dst->ne[0];
    const int64_t H = dst->src[0]->ne[2];

    // Raw device buffers of the inputs and of the output.
    const float * key_data        = static_cast<const float *>(dst->src[0]->data);
    const float * value_data      = static_cast<const float *>(dst->src[1]->data);
    const float * receptance_data = static_cast<const float *>(dst->src[2]->data);
    const float * tf_data         = static_cast<const float *>(dst->src[3]->data);
    const float * td_data         = static_cast<const float *>(dst->src[4]->data);
    const float * state_data      = static_cast<const float *>(dst->src[5]->data);
    float       * out_data        = static_cast<float *>(dst->data);

    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
    GGML_ASSERT(C % H == 0);
    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64

    const int64_t n_blocks  = B * H;
    const int64_t n_threads = C / H;

    rwkv_wkv_f32<<<n_blocks, n_threads, 0, stream>>>(
        B, T, C, H,
        key_data, value_data, receptance_data, tf_data, td_data, state_data,
        out_data);
}
|