Refactor residual kernel
gordicaleksa committed Jul 24, 2024
1 parent 347bb07 commit e4ee316
Showing 1 changed file with 11 additions and 23 deletions.
34 changes: 11 additions & 23 deletions llmc/layernorm.cuh
@@ -143,7 +143,7 @@ __global__ void layernorm_forward_kernel6(floatX* __restrict__ out, float* __res
 __global__ void fused_residual_forward_kernel5(floatX* residual, floatX* normed, float* mean, float* rstd,
                                                const floatX* inp1, const floatX* inp2,
                                                const floatX* weight, const floatX* bias,
-                                               int use_kv, int B, int T, int C) {
+                                               int use_kv, int kv_offset, int B, int T, int C) {
     assert(blockDim.x == WARP_SIZE);
     int N = B * (use_kv ? 1 : T);

@@ -167,10 +167,10 @@ __global__ void fused_residual_forward_kernel5(floatX* residual, floatX* normed,
     if(idx >= N) return;

     // adjust pointers to current token
-    residual += idx * C * (use_kv ? T : 1);
-    normed += idx * C * (use_kv ? T : 1);
-    inp1 += idx * C * (use_kv ? T : 1);
-    inp2 += idx * C * (use_kv ? T : 1);
+    residual += idx * C * (use_kv ? T : 1) + kv_offset * C;
+    normed += idx * C * (use_kv ? T : 1) + kv_offset * C;
+    inp1 += idx * C * (use_kv ? T : 1) + kv_offset * C;
+    inp2 += idx * C * (use_kv ? T : 1) + kv_offset * C;

     const float eps = 1e-5f;
     float sum = 0.0f;
@@ -480,13 +480,6 @@ void fused_residual_forward5(floatX* residual, floatX* normed, float* mean, floa
     const int grid_size = CEIL_DIV(N, block_y);
     size_t smem = (2 + block_y) * C * sizeof(floatX);

-    if (use_kv) {
-        inp1 += kv_offset * C;
-        inp2 += kv_offset * C;
-        residual += kv_offset * C;
-        normed += kv_offset * C;
-    }
-
     // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute
     // this may fail, in which case we fall back to the smem free implementation.
     cudaCheck(cudaGetLastError());
@@ -495,20 +488,15 @@ void fused_residual_forward5(floatX* residual, floatX* normed, float* mean, floa
     if(status == cudaSuccess) {
         fused_residual_forward_kernel5<<<grid_size, dim3(WARP_SIZE, block_y), smem, stream>>>(residual, normed,
                                                                                               mean, rstd, inp1, inp2,
-                                                                                              weight, bias, use_kv, B, T, C);
+                                                                                              weight, bias, use_kv, kv_offset, B, T, C);
     } else {
         assert(0); // this should never happen -> print warnings if it does use static var
+        if (use_kv) {
+            fprintf(stderr, "KV cache: fused_residual_forward_kernel5 failed to set shared memory size - exiting.\n");
+            exit(EXIT_FAILURE);
+        }
         residual_forward(residual, inp1, inp2, N*C, stream);
+        layernorm_forward(normed, mean, rstd, residual, weight, bias, use_kv, kv_offset, N, 1, C, stream);
     }

-    if (use_kv) {
-        inp1 -= kv_offset * C;
-        inp2 -= kv_offset * C;
-        residual -= kv_offset * C;
-        normed -= kv_offset * C;
-        layernorm_forward(normed, mean, rstd, residual, weight, bias, 0, 0, N, 1, C, stream);
-    }

     cudaCheck(cudaGetLastError());
 }
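For readers skimming the diff: the refactor moves the KV-cache pointer offset out of the host launcher and into the kernel itself, so fused_residual_forward5 no longer mutates and restores its pointer arguments around the launch. Below is a minimal host-side sketch of the indexing the kernel now applies; token_offset and the test values are hypothetical illustrations, not llm.c code, while B, T, C keep the diff's meaning of batch size, sequence length, and channels.

#include <cassert>
#include <cstdio>

// Offset (in elements) added to residual/normed/inp1/inp2 for flattened row
// index idx, mirroring fused_residual_forward_kernel5 after this commit.
// Without a KV cache, idx enumerates all B*T tokens and kv_offset is 0.
// With a KV cache, idx enumerates the B batch rows (N = B), so we stride by a
// full row of T tokens and then step to the current decode position kv_offset.
long token_offset(int idx, int use_kv, int kv_offset, int T, int C) {
    return (long)idx * C * (use_kv ? T : 1) + (long)kv_offset * C;
}

int main() {
    const int T = 8, C = 4;
    // Prefill/training path: token (b=1, t=3) lives at flattened idx = 1*T + 3.
    assert(token_offset(1 * T + 3, /*use_kv=*/0, /*kv_offset=*/0, T, C) == (1 * T + 3) * C);
    // Decode path with a KV cache: idx == b == 1, kv_offset == t == 3; same element.
    assert(token_offset(1, /*use_kv=*/1, /*kv_offset=*/3, T, C) == (1 * T + 3) * C);
    printf("offsets agree\n");
    return 0;
}

One practical upside of this design: the launcher stays stateless, so the smem-failure fallback and the fast path see the same untouched pointers, which is what lets the deleted restore block go away.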

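The context lines above also show the launcher's shared-memory opt-in: a kernel that wants more than 48 KiB of dynamic shared memory must request it via cudaFuncSetAttribute, and that call can fail on GPUs without enough smem, which is exactly the non-cudaSuccess branch in the diff. A self-contained sketch of the pattern, using a toy kernel rather than anything from llm.c:

#include <cstdio>
#include <cuda_runtime.h>

__global__ void big_smem_kernel(float* out) {
    extern __shared__ float buf[];          // dynamic shared memory
    buf[threadIdx.x] = (float)threadIdx.x;
    __syncthreads();
    out[threadIdx.x] = buf[threadIdx.x];
}

int main() {
    size_t smem = 64 * 1024;                // 64 KiB, above the 48 KiB default
    // Opt in to the larger dynamic smem carve-out; this can fail, and llm.c
    // falls back to a smem-free implementation in that case.
    cudaError_t status = cudaFuncSetAttribute(big_smem_kernel,
        cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem);
    if (status != cudaSuccess) {
        fprintf(stderr, "fallback path: %s\n", cudaGetErrorString(status));
        return 0;
    }
    float* out;
    cudaMalloc(&out, 256 * sizeof(float));
    big_smem_kernel<<<1, 256, smem>>>(out);
    cudaDeviceSynchronize();
    cudaFree(out);
    return 0;
}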