Merge pull request #769 from gordicaleksa/fused_rmsnorm
Fused rmsnorm reference
karpathy authored Sep 25, 2024
2 parents 52c7254 + 2ebf8f6 commit 6538df6
Showing 1 changed file with 95 additions and 0 deletions.
95 changes: 95 additions & 0 deletions llmc/layernorm.cuh
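
A note on the math this diff implements (paraphrasing the kernel body below; the formula is not stated in the diff itself): for one token row x with C channels and a learnable scale w, the kernel fuses the residual addition x = inp1 + inp2 with RMSNorm,

    rrms = 1 / sqrt(mean(x^2) + eps)
    y_i  = w_i * x_i * rrms        (eps = 1e-5)

and caches rrms per token for the backward pass.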
@@ -278,6 +278,77 @@ __global__ void fused_residual_forward_kernel5(floatX* residual, floatX* normed,
}
}

__global__ void fused_residual_rmsnorm_forward_kernel5(floatX* residual, floatX* normed, float* rrms,
const floatX* inp1, const floatX* inp2,
const floatX* weight,
int N, int C) {
assert(blockDim.x == WARP_SIZE);

    // load weights into shared memory (rmsnorm has no bias)
// do this before we allow any threads to exit!
    extern __shared__ char params[];
// load128/store128 sometimes generated multiple instructions when the types here were floatX*, so
// let's keep everything as x128
x128* s_weight = reinterpret_cast<x128*>(params);
x128* s_res = reinterpret_cast<x128*>(params) + ((1 + threadIdx.y) * C / x128::size);

int sidx = (threadIdx.x + WARP_SIZE * threadIdx.y) * x128::size;
for(int i = sidx; i < C; i += blockDim.y * WARP_SIZE * x128::size) {
s_weight[i/x128::size] = load128(weight + i);
}
__syncthreads();

int idx = blockIdx.x * blockDim.y + threadIdx.y;
    if(idx >= N) return; // rows are 0..N-1; CEIL_DIV may launch a few extra warps past N

// adjust pointers to current token
residual += C * idx;
normed += C * idx;
inp1 += C * idx;
inp2 += C * idx;

const float eps = 1e-5f;
for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) {
const x128 in1 = load128cs(inp1 + c);
const x128 in2 = load128cs(inp2 + c);
x128 out;
for(int k = 0; k < x128::size; ++k) {
out[k] = (float)in1[k] + (float)in2[k];
}
store128cs(residual + c, out);
s_res[c / x128::size] = out;
}

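    // accumulate this thread's partial sum of squares over the residual row (reduced across the warp below)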
float v = 0.f;

for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) {
const x128 res = s_res[c / x128::size];
for(int k = 0; k < x128::size; ++k) {
v += (float)res[k] * (float)res[k];
}
}

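    // combine the partial sums across the warp, then divide by C to get the mean of squares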
v = warpReduceSum(v) / C;
float s = rsqrtf(v + eps);

for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) {
const x128 res = s_res[c / x128::size];
const x128 w = s_weight[c / x128::size];
x128 out;
for(int k = 0; k < x128::size; ++k) {
float n = s * (float)res[k]; // normalized output
float o = n * (float)w[k]; // scale
out[k] = o;
}

store128cs(normed + c, out);
}
// cache the rrms for the backward pass later
if(threadIdx.x == 0) {
rrms[idx] = s;
}
}
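
For a sanity check of what the kernel above computes, here is a minimal CPU sketch of the same fused op, assuming plain float buffers; this helper and its name are illustrative, not part of this diff:

#include <math.h>

// residual = inp1 + inp2; normed = rmsnorm(residual) * weight; rrms cached per token
void fused_residual_rmsnorm_forward_cpu(float* residual, float* normed, float* rrms,
                                        const float* inp1, const float* inp2,
                                        const float* weight, int N, int C) {
    const float eps = 1e-5f;
    for (int idx = 0; idx < N; idx++) {
        float* res = residual + idx * C;
        float* out = normed + idx * C;
        float v = 0.0f;
        for (int c = 0; c < C; c++) {
            res[c] = inp1[idx * C + c] + inp2[idx * C + c]; // residual add
            v += res[c] * res[c];                           // sum of squares
        }
        float s = 1.0f / sqrtf(v / C + eps); // reciprocal RMS
        for (int c = 0; c < C; c++) {
            out[c] = s * res[c] * weight[c]; // normalize, then scale
        }
        rrms[idx] = s; // matches what the kernel stores for the backward pass
    }
}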

__global__ void residual_forward_kernel(floatX* out, const floatX* inp1, const floatX* inp2) {
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;

@@ -549,6 +620,30 @@ void fused_residual_forward5(floatX* residual, floatX* normed, float* mean, floa
cudaCheck(cudaGetLastError());
}

void fused_residual_rmsnorm_forward5(floatX* residual, floatX* normed, float* rrms,
const floatX* inp1, const floatX* inp2,
const floatX* weight,
int N, int C, cudaStream_t stream) {
const int block_size = 256;
int block_y = block_size / WARP_SIZE;
const int grid_size = CEIL_DIV(N, block_y);
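    // smem: one shared copy of the weights plus one cached residual row per warp (hence 1 + block_y)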
size_t smem = (1 + block_y) * C * sizeof(floatX);

    // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute
    // this may fail; there is currently no smem-free fallback for the rmsnorm path, so we assert below
cudaCheck(cudaGetLastError());
auto status = cudaFuncSetAttribute(fused_residual_rmsnorm_forward_kernel5, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
cudaCheck(cudaGetLastError());
if(status == cudaSuccess) {
fused_residual_rmsnorm_forward_kernel5<<<grid_size, dim3(WARP_SIZE, block_y), smem, stream>>>(residual, normed,
rrms, inp1, inp2,
weight, N, C);
} else {
assert(false);
}
cudaCheck(cudaGetLastError());
}
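
For orientation, a hedged sketch of a call site, assuming an llm.c-style forward pass over B*T tokens; the l_* buffer names and main_stream are illustrative, not taken from this diff:

    // residual/normed/rrms are per-token outputs; weight has C entries
    fused_residual_rmsnorm_forward5(l_residual, l_normed, l_rrms,
                                    l_inp1, l_inp2, l_weight,
                                    B * T, C, main_stream);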

void layernorm_backward(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,
const floatX* dout, const floatX* inp, const floatX* weight, const float* mean, const float* rstd,
int B, int T, int C, cudaStream_t stream) {
