
Commit

Refactor layernorm kernel
gordicaleksa committed Jul 24, 2024
1 parent 5e30cb4 commit 347bb07
Showing 1 changed file with 8 additions and 16 deletions.
24 changes: 8 additions & 16 deletions llmc/layernorm.cuh
@@ -66,7 +66,7 @@ __global__ void layernorm_forward_kernel3(floatX* __restrict__ out, float* __res

 __global__ void layernorm_forward_kernel6(floatX* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd,
                                           const floatX* __restrict__ inp, const floatX* __restrict__ weight,
-                                          const floatX* __restrict__ bias, int use_kv, int B, int T, int C) {
+                                          const floatX* __restrict__ bias, int use_kv, int kv_offset, int B, int T, int C) {
     assert(blockDim.x == WARP_SIZE);
     int N = B * (use_kv ? 1 : T);

@@ -90,8 +90,8 @@ __global__ void layernorm_forward_kernel6(floatX* __restrict__ out, float* __res

     if(idx >= N) { return; } // guard

     // adjust pointers to current token
-    inp += idx * C * (use_kv ? T : 1);
-    out += idx * C * (use_kv ? T : 1);
+    inp += idx * C * (use_kv ? T : 1) + kv_offset * C;
+    out += idx * C * (use_kv ? T : 1) + kv_offset * C;

     const float eps = 1e-5f;
     float sum = 0.0f;
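
The in-kernel pointer arithmetic above is the substance of the change: the kv_offset shift that the host wrapper used to apply is now folded into the per-token offset. A minimal host-side sketch of the same indexing, runnable as plain C, with made-up B, T, C values and the assumption (not shown in the commit) that the caller passes kv_offset = 0 whenever use_kv == 0:

#include <stdio.h>

int main(void) {
    // Made-up example sizes; in llm.c the real B, T, C come from the model config.
    int B = 2, T = 4, C = 8;
    for (int use_kv = 0; use_kv <= 1; use_kv++) {
        int kv_offset = use_kv ? 3 : 0;      // assumption: host passes 0 when use_kv == 0
        int N = B * (use_kv ? 1 : T);        // KV-cache mode: one token per batch row
        for (int idx = 0; idx < N; idx++) {
            // same arithmetic as the kernel: inp += idx * C * (use_kv ? T : 1) + kv_offset * C;
            int offset = idx * C * (use_kv ? T : 1) + kv_offset * C;
            printf("use_kv=%d idx=%2d -> element offset %3d (token row %d of %d)\n",
                   use_kv, idx, offset, offset / C, B * T);
        }
    }
    return 0;
}

Without the KV cache the grid walks all B*T token rows; with it, only B rows are touched, each one kv_offset rows into its batch's (T, C) slab, which is exactly what the fused offset expresses.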
@@ -442,30 +442,22 @@ void layernorm_forward(floatX* out, float* mean, float* rstd,

     const int grid_size = CEIL_DIV(N, block_y);
     size_t smem = (2 + block_y) * C * sizeof(floatX);

-    if (use_kv) {
-        inp += kv_offset * C;
-        out += kv_offset * C;
-    }

     // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute
     // this may fail, in which case we fall back to the smem free implementation.
     cudaCheck(cudaGetLastError());
     auto status = cudaFuncSetAttribute(layernorm_forward_kernel6, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
     cudaCheck(cudaGetLastError());
     if (status == cudaSuccess) {
-        layernorm_forward_kernel6<<<grid_size, dim3(WARP_SIZE, block_y), smem, stream>>>(out, mean, rstd, inp, weight, bias, use_kv, B, T, C);
+        layernorm_forward_kernel6<<<grid_size, dim3(WARP_SIZE, block_y), smem, stream>>>(out, mean, rstd, inp, weight, bias, use_kv, kv_offset, B, T, C);
     } else {
-        assert(0); // this should never happen
+        if (use_kv) {
+            fprintf(stderr, "KV cache: layernorm_forward_kernel6 failed to set shared memory size - exiting.\n");
+            exit(EXIT_FAILURE);
+        }
         // fall back to the version without shared memory
         const int grid_size_fb = CEIL_DIV(N * WARP_SIZE, block_size);
         layernorm_forward_kernel3<<<grid_size_fb, block_size, 0, stream>>>(out, mean, rstd, inp, weight, bias, N, C);
     }

-    if (use_kv) {
-        inp -= kv_offset * C;
-        out -= kv_offset * C;
-    }

     cudaCheck(cudaGetLastError());
 }
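
For context on why the cudaFuncSetAttribute call and the fallback branch exist at all, a quick back-of-the-envelope on the smem = (2 + block_y) * C * sizeof(floatX) sizing. This is a standalone sketch, not part of the commit; the block_y choices, the channel sizes, and sizeof(floatX) == 2 (BF16) are illustrative assumptions:

#include <stdio.h>

int main(void) {
    size_t sizeof_floatX = 2;              // assumption: floatX is BF16/FP16
    int block_ys[] = {4, 8, 16, 32};       // hypothetical block_y choices
    int Cs[] = {768, 1024, 1280, 1600};    // GPT-2 small..XL channel sizes
    for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 4; j++) {
            // mirrors the host code: smem = (2 + block_y) * C * sizeof(floatX)
            size_t smem = (2 + (size_t)block_ys[i]) * Cs[j] * sizeof_floatX;
            printf("block_y=%2d C=%4d -> smem = %6zu bytes (%s the 48 KiB default)\n",
                   block_ys[i], Cs[j], smem,
                   smem > 48 * 1024 ? "exceeds" : "fits in");
        }
    }
    return 0;
}

Once the requested dynamic shared memory exceeds the 48 KiB per-block default, the launch only works after opting in via cudaFuncSetAttribute, which is why the code checks its status and keeps a shared-memory-free fallback around.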

