123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- /**
- * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
- *
- * MIT License
- *
- * Copyright (c) 2023-2024 The ggml authors
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
- #include "conv-transpose-1d.cuh"
// Transposed 1D convolution (F32 -> F32), one thread per dst element.
//
// Layouts, as indexed below (all tensors dense, no strides):
//   src0 (weights): element [k, oc, ic] at k + src0_ne0*oc + src0_ne0*src0_ne1*ic
//                   k  in [0, src0_ne0)  kernel tap
//                   oc in [0, src0_ne1)  dst row (output channel)
//                   ic in [0, src0_ne2)  input channel (summed over)
//   src1 (input):   element [i, ic] at i + src1_ne0*ic
//   dst  (output):  element [o, oc] at o + dst_ne0*oc  (flat index == global thread index)
//
// Input sample i of channel ic contributes to outputs i*s0 .. i*s0 + src0_ne0 - 1, so
//   dst[o, oc] = sum_ic sum_i src0[o - i*s0, oc, ic] * src1[i, ic]
// over the (i, tap) pairs whose tap lies inside [0, src0_ne0). Each thread gathers its own
// sum, so no atomics or shared memory are needed.
//
// Launch: 1D grid of 1D blocks with at least output_size threads in total; surplus threads
// exit immediately. Assumes output_size / dst_ne0 <= src0_ne1 (every dst row has a weight
// row) — TODO confirm at the call site.
// p0 and d0 are part of the signature but unused: the math assumes zero padding and unit
// dilation. The *_ne3 extents, src1_ne1..3 and dst_ne1..3 are likewise unused.
static __global__ void conv_transpose_1d_kernel(
        const int s0, const int p0, const int d0, const int output_size,
        const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
        const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
        const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
        const float * __restrict__ src0, const float * __restrict__ src1, float * __restrict__ dst) {
    const int global_index = threadIdx.x + blockIdx.x * blockDim.x;
    if (global_index >= output_size) {
        return;
    }

    // Both coordinates depend only on the thread, not on the channel loop: compute them once.
    const int out_index = global_index / dst_ne0;  // dst row, also selects src0 dim 1
    const int idx       = global_index % dst_ne0;  // position within the dst row

    // Inputs that reach idx satisfy idx - src0_ne0 < i*s0 <= idx. For positive strides, scan
    // only that window instead of all src1_ne0 inputs. The bounds check inside the loop still
    // guards every access, so the contributing terms (and their order) match a full scan.
    int i_begin = 0;
    int i_end   = src1_ne0;
    if (s0 > 0) {
        i_end = min(i_end, idx / s0 + 1);
        const int first = idx - src0_ne0 + 1;  // smallest admissible value of i*s0
        if (first > 0) {
            i_begin = (first + s0 - 1) / s0;   // ceil(first / s0)
        }
    }

    float accumulator = 0.0f;
    for (int c = 0; c < src0_ne2; c++) {
        const float * kernel_row = src0 + (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0);
        const float * input_row  = src1 + src1_ne0 * c;
        for (int i = i_begin; i < i_end; i++) {
            const int weight_idx = idx - i*s0;  // tap of the kernel that lands on idx
            if (weight_idx < 0 || weight_idx >= src0_ne0) {
                continue;
            }
            accumulator += kernel_row[weight_idx] * input_row[i];
        }
    }
    dst[global_index] = accumulator;
}
// Host launcher for conv_transpose_1d_kernel.
// Enqueues one thread per dst element (output_size threads, rounded up to whole blocks of
// CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE) on `stream`, with no dynamic shared memory. The launch
// is asynchronous; nothing is synchronized or error-checked here. All shape and
// convolution parameters are forwarded to the kernel unchanged.
static void conv_transpose_1d_f32_f32_cuda(
        const int s0, const int p0, const int d0, const int output_size,
        const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
        const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
        const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
        const float * src0, const float * src1, float * dst,
        cudaStream_t stream) {
    // An empty output needs no work, and a zero-block grid would be an invalid launch
    // configuration.
    if (output_size <= 0) {
        return;
    }

    const int num_blocks = (output_size + CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE;
    conv_transpose_1d_kernel<<<num_blocks, CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE, 0, stream>>>(
        s0, p0, d0, output_size,
        src0_ne0, src0_ne1, src0_ne2, src0_ne3,
        src1_ne0, src1_ne1, src1_ne2, src1_ne3,
        dst_ne0, dst_ne1, dst_ne2, dst_ne3,
        src0, src1, dst);
}
// CUDA implementation of the 1D transposed convolution op.
//   dst->src[0]: weights, contiguous F32
//   dst->src[1]: input,   contiguous F32
//   dst:         output,  contiguous F32 (the kernel writes it with a flat index)
// The stride s0 is read from op_params[0]. Padding and dilation are fixed to p0 = 0, d0 = 1.
// The work is enqueued on ctx.stream() and is not synchronized here.
void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32); // src1->data is reinterpreted as float below
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));
    GGML_ASSERT(ggml_is_contiguous(dst));      // the kernel stores dst[global_index] linearly

    const float * src0_d = (const float *)src0->data;
    const float * src1_d = (const float *)src1->data;
    float       * dst_d  = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    const int32_t * opts = (const int32_t *)dst->op_params;
    const int s0 = opts[0];
    // Padding and dilation are not read from op_params: this path only implements
    // p0 == 0 and d0 == 1. NOTE(review): confirm where the graph builder stores p0/d0
    // (the previous comments pointed at opts[3]/opts[4]) before wiring them through.
    const int p0 = 0;
    const int d0 = 1;

    const int64_t output_size = ggml_nelements(dst);
    // The launcher and kernel take the element count and index dst with 32-bit ints.
    GGML_ASSERT(output_size <= INT32_MAX);

    conv_transpose_1d_f32_f32_cuda(s0, p0, d0, (int) output_size,
        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
        src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
        src0_d, src1_d, dst_d, stream);
}
|