diff --git a/dev/cuda/rope.cu b/dev/cuda/rope.cu index 4e45a4711..d8a457c11 100644 --- a/dev/cuda/rope.cu +++ b/dev/cuda/rope.cu @@ -122,6 +122,66 @@ void rope_forward(int kernel_num, floatX *out, const floatX *inp, const floatX * } } +// ---------------------------------------------------------------------------- +// while we're at it, let's also briefly validate our backward kernel here + +void apply_rotary_emb_backward(float *dinp, const float *dout, const float *inp, const float *freqs_cis, int B, int T, int n_head, int head_dim) { + // backward pass of the RoPE embedding + for (int b = 0; b < B; b++) { + for (int t = 0; t < T; t++) { + int idx_bt = b * (T * n_head * head_dim) + t * (n_head * head_dim); + for (int h = 0; h < n_head; h++) { + int idx_bth = idx_bt + h * head_dim; + for (int d = 0; d < head_dim / 2; d++) { + // fetch the angle from freqs_cis + int freqs_idx = t * head_dim + 2 * d; + float freqs_cos = freqs_cis[freqs_idx]; + float freqs_sin = freqs_cis[freqs_idx + 1]; + // and the input index we'll be updating + int idx = idx_bth + 2 * d; + // backward pass is simple because freqs_cis is just scaling by a constant + dinp[idx] += dout[idx] * freqs_cos + dout[idx + 1] * freqs_sin; + dinp[idx + 1] += -dout[idx] * freqs_sin + dout[idx + 1] * freqs_cos; + } + } + } + } +} + +__global__ void rope_backward_inplace_kernel1(floatX *dinp, const floatX *dout, const floatX *freqs_cis, int B, int T, int n_head, int head_dim) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int head_dim_half = head_dim / 2; + if (idx >= B * T * n_head * head_dim_half) return; + // decode the individual indices + int b = idx / (T * n_head * head_dim_half); + int t = (idx / (n_head * head_dim_half)) % T; + int h = (idx / head_dim_half) % n_head; + int d = idx % head_dim_half; + // calculate the index in the input + int idx_bt = b * (T * n_head * head_dim) + t * (n_head * head_dim); + int idx_bth = idx_bt + h * head_dim; + int idxi = idx_bth + 2 * d; // index in the input + // fetch the freqs_cis + int freqs_idx = t * head_dim + 2 * d; + float freqs_cos = freqs_cis[freqs_idx]; + float freqs_sin = freqs_cis[freqs_idx + 1]; + // apply the rotation + float dout_real = (float)dout[idxi]; + float dout_imag = (float)dout[idxi + 1]; + dinp[idxi] = dout_real * freqs_cos + dout_imag * freqs_sin; + dinp[idxi + 1] = -dout_real * freqs_sin + dout_imag * freqs_cos; +} + +void rope_backward_inplace(floatX *dinp, const floatX *dout, const floatX *freqs_cis, int B, int T, int n_head, int head_dim, cudaStream_t stream) { + // backward pass of forward, mirrors the forward kernel in setup and indexing + const int block_size = 128; + int total_threads = B * T * 3 * n_head * head_dim / 2; + int num_blocks = ceil_div(total_threads, block_size); + rope_backward_inplace_kernel1<<>>(dinp, dout, freqs_cis, B, T, n_head, head_dim); + cudaCheck(cudaGetLastError()); +} + +// ---------------------------------------------------------------------------- // tester int main(int argc, char **argv) { srand(0); @@ -179,6 +239,16 @@ int main(int argc, char **argv) { printf("block_size %4d time %.4f ms\n", block_size, elapsed_time); } + // now also briefly validate the backward pass + // first, the reference CPU calculation + float *dinp = (float *)malloc(B * T * n_head * head_dim * sizeof(float)); + memset(dinp, 0, B * T * n_head * head_dim * sizeof(float)); // init at zero + apply_rotary_emb_backward(dinp, out, inp, freqs_cis, B, T, n_head, head_dim); + // now the GPU calculation (note it is done in-place, as we wish it to be to save space) + rope_backward_inplace(d_out, d_out, d_freqs_cis, B, T, n_head, head_dim, 0); + validate_result(d_out, dinp, "dinp", B * T * n_head * head_dim, 1e-5f); + printf("Backward pass result matches.\n"); + // free memory free(inp); free(freqs_cis);