Skip to content

Commit

Permalink
add backward kernel to dev/cuda for rope, to ensure correctness. but …
Browse files Browse the repository at this point in the history
…i mean, it's trivial. this can't possibly be the issue. it must be the repkv
  • Loading branch information
karpathy committed Sep 27, 2024
1 parent 075e430 commit 8d49062
Showing 1 changed file with 70 additions and 0 deletions.
70 changes: 70 additions & 0 deletions dev/cuda/rope.cu
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,66 @@ void rope_forward(int kernel_num, floatX *out, const floatX *inp, const floatX *
}
}

// ----------------------------------------------------------------------------
// while we're at it, let's also briefly validate our backward kernel here

void apply_rotary_emb_backward(float *dinp, const float *dout, const float *inp, const float *freqs_cis, int B, int T, int n_head, int head_dim) {
    // CPU reference for the RoPE backward pass.
    // Buffers dinp/dout are indexed as [b][t][h][head_dim]; freqs_cis is indexed as
    // [t][head_dim] and holds interleaved (cos, sin) values, one pair per rotated pair.
    // Gradients are ACCUMULATED into dinp (+=), so the caller decides the initial contents.
    // `inp` is accepted for signature symmetry but is never read here.
    const int pairs_per_head = head_dim / 2;
    const int token_stride = n_head * head_dim;
    for (int batch = 0; batch < B; ++batch) {
        for (int pos = 0; pos < T; ++pos) {
            const int token_base = (batch * T + pos) * token_stride;
            const int freqs_base = pos * head_dim; // the angles depend only on the position
            for (int head = 0; head < n_head; ++head) {
                const int head_base = token_base + head * head_dim;
                for (int pair = 0; pair < pairs_per_head; ++pair) {
                    // (cos, sin) of the rotation applied to this pair
                    const float cos_v = freqs_cis[freqs_base + 2 * pair];
                    const float sin_v = freqs_cis[freqs_base + 2 * pair + 1];
                    // positions of the "real" and "imaginary" halves of the pair
                    const int re = head_base + 2 * pair;
                    const int im = re + 1;
                    // the rotation is linear with constant coefficients, so the gradient
                    // is dout multiplied by the transposed rotation matrix
                    dinp[re] += dout[re] * cos_v + dout[im] * sin_v;
                    dinp[im] += -dout[re] * sin_v + dout[im] * cos_v;
                }
            }
        }
    }
}

__global__ void rope_backward_inplace_kernel1(floatX *dinp, const floatX *dout, const floatX *freqs_cis, int B, int T, int n_head, int head_dim) {
    // RoPE backward: one thread handles one (b, t, h, pair) element pair.
    // Expects a 1D launch covering at least B * T * n_head * (head_dim / 2) threads;
    // surplus threads exit immediately. Uses no shared memory.
    // Both dout values of a pair are loaded before either dinp value is stored, and no
    // other thread touches that pair, so calling with dinp == dout (in-place) is fine.
    const int pairs_per_head = head_dim / 2;
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= B * T * n_head * pairs_per_head) { return; }

    // peel the flat thread id apart, fastest-varying coordinate first
    int rest = tid;
    const int pair = rest % pairs_per_head;
    rest /= pairs_per_head;
    const int head = rest % n_head;
    rest /= n_head;
    const int pos = rest % T;
    const int batch = rest / T;

    // offset of the pair's first element inside the [b][t][h][head_dim] buffers
    const int token_stride = n_head * head_dim;
    const int re = batch * (T * token_stride) + pos * token_stride + head * head_dim + 2 * pair;
    const int im = re + 1;

    // interleaved (cos, sin) for this position and pair
    const int freqs_at = pos * head_dim + 2 * pair;
    const float cos_v = freqs_cis[freqs_at];
    const float sin_v = freqs_cis[freqs_at + 1];

    // load the incoming gradient pair, then rotate it by the inverse angle
    const float g_re = (float)dout[re];
    const float g_im = (float)dout[im];
    dinp[re] = g_re * cos_v + g_im * sin_v;
    dinp[im] = -g_re * sin_v + g_im * cos_v;
}

void rope_backward_inplace(floatX *dinp, const floatX *dout, const floatX *freqs_cis, int B, int T, int n_head, int head_dim, cudaStream_t stream) {
    // Host launcher for rope_backward_inplace_kernel1 (RoPE backward pass) on `stream`.
    // dinp and dout may be the same buffer; the tester calls it that way to save space.
    // The launch is asynchronous: only launch-configuration errors are checked here.
    const int block_size = 128;
    // The kernel assigns one thread per (b, t, head, pair) and bounds-checks against
    // exactly B * T * n_head * (head_dim / 2), so size the grid to that count.
    // (A previous extra factor of 3 only spawned threads that exited immediately.)
    int total_threads = B * T * n_head * (head_dim / 2);
    if (total_threads <= 0) {
        return; // nothing to rotate; a 0-block launch would be an invalid configuration
    }
    int num_blocks = ceil_div(total_threads, block_size);
    rope_backward_inplace_kernel1<<<num_blocks, block_size, 0, stream>>>(dinp, dout, freqs_cis, B, T, n_head, head_dim);
    cudaCheck(cudaGetLastError());
}

// ----------------------------------------------------------------------------
// tester
int main(int argc, char **argv) {
srand(0);
Expand Down Expand Up @@ -179,6 +239,16 @@ int main(int argc, char **argv) {
printf("block_size %4d time %.4f ms\n", block_size, elapsed_time);
}

// now also briefly validate the backward pass
// first, the reference CPU calculation
float *dinp = (float *)malloc(B * T * n_head * head_dim * sizeof(float));
memset(dinp, 0, B * T * n_head * head_dim * sizeof(float)); // init at zero
apply_rotary_emb_backward(dinp, out, inp, freqs_cis, B, T, n_head, head_dim);
// now the GPU calculation (note it is done in-place, as we wish it to be to save space)
rope_backward_inplace(d_out, d_out, d_freqs_cis, B, T, n_head, head_dim, 0);
validate_result(d_out, dinp, "dinp", B * T * n_head * head_dim, 1e-5f);
printf("Backward pass result matches.\n");

// free memory
free(inp);
free(freqs_cis);
Expand Down

0 comments on commit 8d49062

Please sign in to comment.