From e4ee316b1aee1b4dd612d6cb92dc6d55a0a7b85a Mon Sep 17 00:00:00 2001
From: Aleksa Gordic
Date: Wed, 24 Jul 2024 12:37:57 +0200
Subject: [PATCH] Refactor residual kernel

---
 llmc/layernorm.cuh | 34 +++++++++++-----------------------
 1 file changed, 11 insertions(+), 23 deletions(-)

diff --git a/llmc/layernorm.cuh b/llmc/layernorm.cuh
index 9e6e0a6bc..ed54beb36 100644
--- a/llmc/layernorm.cuh
+++ b/llmc/layernorm.cuh
@@ -143,7 +143,7 @@ __global__ void layernorm_forward_kernel6(floatX* __restrict__ out, float* __res
 __global__ void fused_residual_forward_kernel5(floatX* residual, floatX* normed, float* mean, float* rstd,
                                                const floatX* inp1, const floatX* inp2,
                                                const floatX* weight, const floatX* bias,
-                                               int use_kv, int B, int T, int C) {
+                                               int use_kv, int kv_offset, int B, int T, int C) {
     assert(blockDim.x == WARP_SIZE);
     int N = B * (use_kv ? 1 : T);
 
@@ -167,10 +167,10 @@ __global__ void fused_residual_forward_kernel5(floatX* residual, floatX* normed,
     if(idx >= N) return;
 
     // adjust pointers to current token
-    residual += idx * C * (use_kv ? T : 1);
-    normed += idx * C * (use_kv ? T : 1);
-    inp1 += idx * C * (use_kv ? T : 1);
-    inp2 += idx * C * (use_kv ? T : 1);
+    residual += idx * C * (use_kv ? T : 1) + kv_offset * C;
+    normed += idx * C * (use_kv ? T : 1) + kv_offset * C;
+    inp1 += idx * C * (use_kv ? T : 1) + kv_offset * C;
+    inp2 += idx * C * (use_kv ? T : 1) + kv_offset * C;
 
     const float eps = 1e-5f;
     float sum = 0.0f;
@@ -480,13 +480,6 @@ void fused_residual_forward5(floatX* residual, floatX* normed, float* mean, floa
     const int grid_size = CEIL_DIV(N, block_y);
     size_t smem = (2 + block_y) * C * sizeof(floatX);
 
-    if (use_kv) {
-        inp1 += kv_offset * C;
-        inp2 += kv_offset * C;
-        residual += kv_offset * C;
-        normed += kv_offset * C;
-    }
-
     // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute
     // this may fail, in which case we fall back to the smem free implementation.
     cudaCheck(cudaGetLastError());
@@ -495,20 +488,15 @@ void fused_residual_forward5(floatX* residual, floatX* normed, float* mean, floa
 
     if(status == cudaSuccess) {
         fused_residual_forward_kernel5<<<grid_size, dim3(WARP_SIZE, block_y), smem, stream>>>(residual, normed,
                                                                                               mean, rstd, inp1, inp2,
-                                                                                              weight, bias, use_kv, B, T, C);
+                                                                                              weight, bias, use_kv, kv_offset, B, T, C);
     } else {
-        assert(0); // this should never happen -> print warnings if it does use static var
+        if (use_kv) {
+            fprintf(stderr, "KV cache: fused_residual_forward_kernel5 failed to set shared memory size - exiting.\n");
+            exit(EXIT_FAILURE);
+        }
         residual_forward(residual, inp1, inp2, N*C, stream);
-        layernorm_forward(normed, mean, rstd, residual, weight, bias, use_kv, kv_offset, N, 1, C, stream);
-    }
-
-    if (use_kv) {
-        inp1 -= kv_offset * C;
-        inp2 -= kv_offset * C;
-        residual -= kv_offset * C;
-        normed -= kv_offset * C;
+        layernorm_forward(normed, mean, rstd, residual, weight, bias, 0, 0, N, 1, C, stream);
     }
-
     cudaCheck(cudaGetLastError());
 }
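
Note (reviewer sketch, not part of the patch): the kernel now folds the KV offset
into its own pointer arithmetic, `ptr += idx * C * (use_kv ? T : 1) + kv_offset * C`,
instead of the host wrapper shifting the four pointers before launch and restoring
them afterwards. Below is a minimal host-side C sketch of that arithmetic;
`token_offset` is a hypothetical helper name chosen for illustration, not a
function in llmc/layernorm.cuh.

    /* Minimal sketch of the kernel's offset math; token_offset is a
       hypothetical helper, not part of llmc/layernorm.cuh. */
    #include <assert.h>
    #include <stdio.h>

    /* Mirrors: ptr += idx * C * (use_kv ? T : 1) + kv_offset * C; */
    static long token_offset(int idx, int use_kv, int kv_offset, int T, int C) {
        return (long)idx * C * (use_kv ? T : 1) + (long)kv_offset * C;
    }

    int main(void) {
        int T = 8, C = 4;
        /* training / prefill: one block per (b, t) token, kv_offset is 0,
           so the offset is simply idx * C */
        assert(token_offset(5, /*use_kv=*/0, /*kv_offset=*/0, T, C) == 5 * C);
        /* decode: one block per batch row (N = B); each row is T tokens wide
           and the single active token sits at position kv_offset inside it */
        assert(token_offset(1, /*use_kv=*/1, /*kv_offset=*/3, T, C) == (1 * T + 3) * C);
        printf("offsets check out\n");
        return 0;
    }

In decode mode (use_kv == 1) there is one block per batch row rather than per
token, which is why moving the offset into the kernel lets fused_residual_forward5
drop the save/restore dance entirely: the wrapper's pointers are never mutated.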