From 2ebf8f6b8a1259f4a6003a5d91d16532991ad5b9 Mon Sep 17 00:00:00 2001
From: Aleksa Gordic
Date: Wed, 25 Sep 2024 10:37:53 -0700
Subject: [PATCH] Add rmsnorm fused kernel

---
 llmc/layernorm.cuh | 95 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 95 insertions(+)

diff --git a/llmc/layernorm.cuh b/llmc/layernorm.cuh
index 1387b11ab..41f9ec6d3 100644
--- a/llmc/layernorm.cuh
+++ b/llmc/layernorm.cuh
@@ -278,6 +278,77 @@ __global__ void fused_residual_forward_kernel5(floatX* residual, floatX* normed,
     }
 }
 
+__global__ void fused_residual_rmsnorm_forward_kernel5(floatX* residual, floatX* normed, float* rrms,
+                                                       const floatX* inp1, const floatX* inp2,
+                                                       const floatX* weight,
+                                                       int N, int C) {
+    assert(blockDim.x == WARP_SIZE);
+
+    // load weights into shared memory (rmsnorm has no bias)
+    // do this before we allow any threads to exit!
+    extern __shared__ char* params[];
+    // load128/store128 sometimes generated multiple instructions when the types here were floatX*, so
+    // let's keep everything as x128
+    x128* s_weight = reinterpret_cast<x128*>(params);
+    x128* s_res = reinterpret_cast<x128*>(params) + ((1 + threadIdx.y) * C / x128::size);
+
+    int sidx = (threadIdx.x + WARP_SIZE * threadIdx.y) * x128::size;
+    for(int i = sidx; i < C; i += blockDim.y * WARP_SIZE * x128::size) {
+        s_weight[i/x128::size] = load128(weight + i);
+    }
+    __syncthreads();
+
+    int idx = blockIdx.x * blockDim.y + threadIdx.y;
+    if(idx >= N) return;
+
+    // adjust pointers to current token
+    residual += C * idx;
+    normed += C * idx;
+    inp1 += C * idx;
+    inp2 += C * idx;
+
+    const float eps = 1e-5f;
+    for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) {
+        const x128 in1 = load128cs(inp1 + c);
+        const x128 in2 = load128cs(inp2 + c);
+        x128 out;
+        for(int k = 0; k < x128::size; ++k) {
+            out[k] = (float)in1[k] + (float)in2[k];
+        }
+        store128cs(residual + c, out);
+        s_res[c / x128::size] = out;
+    }
+
+    float v = 0.f;
+
+    for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) {
+        const x128 res = s_res[c / x128::size];
+        for(int k = 0; k < x128::size; ++k) {
+            v += (float)res[k] * (float)res[k];
+        }
+    }
+
+    v = warpReduceSum(v) / C;
+    float s = rsqrtf(v + eps);
+
+    for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) {
+        const x128 res = s_res[c / x128::size];
+        const x128 w = s_weight[c / x128::size];
+        x128 out;
+        for(int k = 0; k < x128::size; ++k) {
+            float n = s * (float)res[k]; // normalized output
+            float o = n * (float)w[k]; // scale
+            out[k] = o;
+        }
+
+        store128cs(normed + c, out);
+    }
+    // cache the rrms for the backward pass later
+    if(threadIdx.x == 0) {
+        rrms[idx] = s;
+    }
+}
+
 __global__ void residual_forward_kernel(floatX* out, const floatX* inp1, const floatX* inp2) {
     int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
 
@@ -549,6 +620,30 @@ void fused_residual_forward5(floatX* residual, floatX* normed, float* mean, floa
     }
 }
 
+void fused_residual_rmsnorm_forward5(floatX* residual, floatX* normed, float* rrms,
+                                     const floatX* inp1, const floatX* inp2,
+                                     const floatX* weight,
+                                     int N, int C, cudaStream_t stream) {
+    const int block_size = 256;
+    int block_y = block_size / WARP_SIZE;
+    const int grid_size = CEIL_DIV(N, block_y);
+    size_t smem = (1 + block_y) * C * sizeof(floatX);
+
+    // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute
+    // this may fail; there is no smem-free rmsnorm fallback here, so we just assert on failure.
+    cudaCheck(cudaGetLastError());
+    auto status = cudaFuncSetAttribute(fused_residual_rmsnorm_forward_kernel5, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
+    cudaCheck(cudaGetLastError());
+    if(status == cudaSuccess) {
+        fused_residual_rmsnorm_forward_kernel5<<<grid_size, dim3(WARP_SIZE, block_y), smem, stream>>>(residual, normed,
+                                                                                                      rrms, inp1, inp2,
+                                                                                                      weight, N, C);
+    } else {
+        assert(false);
+    }
+    cudaCheck(cudaGetLastError());
+}
+
 void layernorm_backward(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,
                         const floatX* dout, const floatX* inp, const floatX* weight, const float* mean, const float* rstd,
                         int B, int T, int C, cudaStream_t stream) {
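
A minimal call-site sketch (not part of the patch) of how the new host wrapper could be invoked, mirroring how fused_residual_forward5 is used elsewhere in llm.c; the buffer names l_res, l_normed, l_rrms, l_inp1, l_inp2, l_weight and main_stream below are hypothetical placeholders for device pointers and the training stream:

    // hypothetical usage: fused residual add + rmsnorm over B*T tokens of C channels each.
    // l_res receives inp1 + inp2, l_normed the scaled rms-normalized output,
    // and l_rrms caches one reciprocal RMS per token for the backward pass.
    fused_residual_rmsnorm_forward5(l_res, l_normed, l_rrms,
                                    l_inp1, l_inp2, l_weight,
                                    B * T, C, main_stream);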