cross-entropy-loss.cu 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. /**
  2. * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
  3. *
  4. * MIT License
  5. *
  6. * Copyright (c) 2023-2024 The ggml authors
  7. *
  8. * Permission is hereby granted, free of charge, to any person obtaining a copy
  9. * of this software and associated documentation files (the "Software"), to deal
  10. * in the Software without restriction, including without limitation the rights
  11. * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  12. * copies of the Software, and to permit persons to whom the Software is
  13. * furnished to do so, subject to the following conditions:
  14. *
  15. * The above copyright notice and this permission notice shall be included in all
  16. * copies or substantial portions of the Software.
  17. *
  18. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  19. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  20. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  21. * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  22. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  23. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  24. * SOFTWARE.
  25. */
  26. #include "common.cuh"
  27. #include "cross-entropy-loss.cuh"
  28. #include "sumrows.cuh"
  29. #include <cmath>
  30. #include <cstdint>
  // Partial cross-entropy loss for one block of rows.
  //
  // logits/labels: k contiguous rows of nclasses floats each. Thread t of the grid handles row
  // blockDim.x*blockIdx.x + t; rows >= k are zero-padded (their labels are 0, so they add nothing).
  // dst[blockIdx.x] receives -(sum over this block's rows of sum_i labels[i]*log_softmax(logits)[i]) / k;
  // the caller must still add the per-block values together.
  //
  // Launch requirements:
  //   - 1D launch, blockDim.x a multiple of WARP_SIZE and at most WARP_SIZE*WARP_SIZE threads
  //     (warp 0 alone combines the per-warp partial sums).
  //   - Dynamic shared memory of at least 2*blockDim.x*nclasses floats (one logits and one labels tile per
  //     warp) and never less than blockDim.x/WARP_SIZE floats (reused for the per-warp partial sums).
  static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) {
      const int warp_id = threadIdx.x / WARP_SIZE;
      const int lane_id = threadIdx.x % WARP_SIZE;
      const int nwarps  = blockDim.x / WARP_SIZE;
      const int i0      = blockDim.x*blockIdx.x + warp_id*WARP_SIZE; // first row handled by this warp

      const int ne_tmp = WARP_SIZE*nclasses; // floats per tile: WARP_SIZE rows of nclasses values

      extern __shared__ float tmp_all[];
      float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp;
      float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp;

      // Each warp first loads its ne_tmp logits/labels into shared memory; consecutive lanes read consecutive
      // global addresses. Global offsets are 64-bit because k*nclasses may exceed INT_MAX.
      const int64_t ne_global = (int64_t) k*nclasses;
      for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) {
          const int64_t ig = (int64_t) i0*nclasses + i; // ig == i global

          tmp_logits[i] = ig < ne_global ? logits[ig] : 0.0f;
          tmp_labels[i] = ig < ne_global ? labels[ig] : 0.0f;
      }

      // The load above is strided across lanes, but below each lane reads a whole row, i.e. values written by
      // other lanes of this warp. Warps are not implicitly synchronous on Volta+, so this barrier is required.
      __syncwarp();

      // Each thread in the warp then calculates the cross entropy loss for a single row.
      // TODO: pad in order to avoid shared memory bank conflicts.
      float       * row_logits = tmp_logits + lane_id*nclasses;
      const float * row_labels = tmp_labels + lane_id*nclasses;

      // Find maximum for a numerically stable softmax:
      float row_max = -INFINITY;
      for (int i = 0; i < nclasses; ++i) {
          row_max = fmaxf(row_max, row_logits[i]);
      }

      // Shift the logits by the maximum in place and accumulate sum(exp(logits - max)):
      float sum_exp = 0.0f;
      for (int i = 0; i < nclasses; ++i) {
          const float val = row_logits[i] - row_max;
          sum_exp += expf(val);
          row_logits[i] = val;
      }
      const float log_sum_exp = logf(sum_exp);

      // log(softmax(logits)) = (logits - max) - log(sum(exp(logits - max)))
      float loss = 0.0f;
      for (int i = 0; i < nclasses; ++i) {
          loss += (row_logits[i] - log_sum_exp) * row_labels[i];
      }
      loss = -warp_reduce_sum(loss) / (float) k;

      // The start of shared memory is reused for the per-warp partial sums, so every warp must be done
      // with its tiles before anything is overwritten, and all writes must land before warp 0 reads them.
      __syncthreads();
      if (lane_id == 0) {
          tmp_all[warp_id] = loss;
      }
      __syncthreads();

      // Warp 0 combines the per-warp partial sums of this block.
      if (warp_id != 0) {
          return;
      }
      loss = lane_id < nwarps ? tmp_all[lane_id] : 0.0f;
      loss = warp_reduce_sum(loss);

      if (lane_id == 0) {
          dst[blockIdx.x] = loss;
      }
  }
  // Computes the scalar cross-entropy loss of dst->src[0] (logits) against dst->src[1] (labels).
  // Both sources must be contiguous F32 tensors; rows have length src0->ne[0] (the number of classes).
  // cross_entropy_loss_f32 writes one partial loss per block into a pool-allocated temporary buffer,
  // and sum_rows_f32_cuda (over a single row of blocks_num.x values) combines those partials into dst.
  // All work is enqueued on ctx.stream().
  void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
      const ggml_tensor * src0 = dst->src[0];
      const ggml_tensor * src1 = dst->src[1];

      GGML_ASSERT(src0->type == GGML_TYPE_F32);
      GGML_ASSERT(src1->type == GGML_TYPE_F32);
      GGML_ASSERT( dst->type == GGML_TYPE_F32);

      GGML_ASSERT(ggml_is_contiguous(src0));
      GGML_ASSERT(ggml_is_contiguous(src1));
      GGML_ASSERT(ggml_is_contiguous(dst));

      const int64_t ne00  = src0->ne[0];
      const int64_t nrows = ggml_nrows(src0);

      // The kernel takes the class count and row count as int.
      GGML_ASSERT(ne00  <= INT32_MAX);
      GGML_ASSERT(nrows <= INT32_MAX);

      const float * src0_d = (const float *) src0->data;
      const float * src1_d = (const float *) src1->data;
      float       * dst_d  = (float       *) dst->data;

      ggml_cuda_pool & pool = ctx.pool();
      cudaStream_t stream = ctx.stream();

      const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
      const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);

      // The kernel stages one logit and one label per class for every thread's row, then reuses the buffer for
      // one partial sum per warp, so it needs at least that many floats even when ne00 == 0.
      // Computed in size_t: the product does not fit into an int for large ne00.
      const size_t nf_rows   = 2*(size_t) CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*(size_t) ne00;
      const size_t nf_reduce = (size_t) (CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE);
      const size_t shmem     = (nf_rows > nf_reduce ? nf_rows : nf_reduce)*sizeof(float);

      // The staging buffer must fit into the per-block shared memory of this device; otherwise the launch fails
      // and the loss would be garbage. Large class counts are not supported by this kernel layout.
      GGML_ASSERT(shmem <= ggml_cuda_info().devices[ctx.device].smpb);

      ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);

      cross_entropy_loss_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);

      // Combine results from individual blocks:
      sum_rows_f32_cuda(dst_tmp.ptr, dst_d, blocks_num.x, 1, stream);
  }