From 075e430d23f8fa887bb290cb64b509a753c8c0b5 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Fri, 27 Sep 2024 03:37:53 +0000 Subject: [PATCH] just pushing what i have. it's epsilon away from working sigh. basically at this point of where prints happen, gradients match. but once we backward attention, rope and repkv, gradients don't match. attention hasn't changed so that can't be wrong (?), so it's either repkv or rope. i have to go slower and double check the backward pass of both of these in detail. also had to introduce one more additional buffer for backward --- llmc/repkv.cuh | 58 +++++++++++++++++++++++++++++++++++++++++++++++++ llmc/rope.cuh | 41 +++++++++++++++++++++++++++++++--- train_llama3.cu | 48 +++++++++++++++++++++++----------------- train_llama3.py | 14 +++++++----- 4 files changed, 132 insertions(+), 29 deletions(-) diff --git a/llmc/repkv.cuh b/llmc/repkv.cuh index 666ad8c44..f4c517eaa 100644 --- a/llmc/repkv.cuh +++ b/llmc/repkv.cuh @@ -48,6 +48,54 @@ __global__ void repkv_forward_kernel1(floatX* replicated_qkv, replicated_qkv[idx_flat] = __ldcs(&gqa_qkv[inp_idx]); } +__global__ void repkv_backward_kernel1(floatX* dinp, const floatX* dout, + int B, int N, int NH, int replicate_factor, int HD) { + // we have a single tensor dout of shapae of (B, N 3 * NH * HD) + // we want to reduce sum (for K and V) into (B, N, (NH + 2*(NH/replicate_factor)) * HD) + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= B * N * 3 * NH * HD) { return;} + int dout_idx = idx; // keep backup + + // decode the dout index + int d = idx % HD; + idx /= HD; + int nh = idx % NH; + idx /= NH; + int c = idx % 3; + idx /= 3; + int n = idx % N; + int b = idx / N; + + int dinp_idx; + int nh_total = NH + 2 * (NH / replicate_factor); + + if (c == 0) { + dinp_idx = b * N * nh_total * HD + n * nh_total * HD + 0 * NH * HD + nh * HD + d; + dinp[dinp_idx] = __ldcs(&dout[dout_idx]); + } else if (c == 1) { + if (nh % replicate_factor == 0) { + float reduced_sum = 0; + for (int i = 0; i < replicate_factor; i++) { + reduced_sum += (float) __ldcs(&dout[dout_idx+HD*i]); + } + + dinp_idx = b * N * nh_total * HD + n * nh_total * HD + 1 * NH * HD + (nh / replicate_factor) * HD + d; + dinp[dinp_idx] = reduced_sum; + } + + } else { + if (nh % replicate_factor == 0) { + float reduced_sum = 0; + for (int i = 0; i < replicate_factor; i++) { + reduced_sum += (float) __ldcs(&dout[dout_idx+HD*i]); + } + dinp_idx = b * N * nh_total * HD + n * nh_total * HD + (NH * HD + (NH / replicate_factor) * HD) + (nh / replicate_factor) * HD + d; + dinp[dinp_idx] = reduced_sum; + } + } +} + +// kernel launchers void repkv_forward(floatX* out, const floatX* inp, int B, int T, int NH, int NH_KV, int HD, cudaStream_t stream) { // NH = number of query heads, NH_KV = number of key and value heads, HD = head dimension const int block_size = 128; @@ -61,3 +109,13 @@ void repkv_forward(floatX* out, const floatX* inp, int B, int T, int NH, int NH_ } cudaCheck(cudaGetLastError()); } + +void repkv_backward(floatX* dinp, const floatX* dout, + const int B, const int T, const int NH, const int NH_KV, const int d) { + const int block_size = 128; + int total_threads = B * T * (3 * NH) * d; + int num_blocks = CEIL_DIV(total_threads, block_size); + int replicate_factor = NH / NH_KV; + repkv_backward_kernel1<<>>(dinp, dout, B, T, NH, replicate_factor, d); + cudaCheck(cudaGetLastError()); +} diff --git a/llmc/rope.cuh b/llmc/rope.cuh index ca5fc56f9..50371c47b 100644 --- a/llmc/rope.cuh +++ b/llmc/rope.cuh @@ -58,18 +58,44 @@ __global__ void rope_forward_kernel1(floatX *out, const floatX *inp, const float int idx_bt = b * (T * 3 * n_head * head_dim) + t * (3 * n_head * head_dim); int idx_bth = idx_bt + qkv * (n_head * head_dim) + h * head_dim; int idxi = idx_bth + 2 * d; // index in the input - // fetch the input - float x_real = inp[idxi]; - float x_imag = inp[idxi + 1]; // fetch the freqs_cis int freqs_idx = t * head_dim + 2 * d; float freqs_cos = freqs_cis[freqs_idx]; float freqs_sin = freqs_cis[freqs_idx + 1]; + // fetch the input + float x_real = inp[idxi]; + float x_imag = inp[idxi + 1]; // apply the rotation out[idxi] = x_real * freqs_cos - x_imag * freqs_sin; out[idxi + 1] = x_real * freqs_sin + x_imag * freqs_cos; } +__global__ void rope_backward_inplace_kernel1(floatX *dinp, const floatX *dout, const floatX *freqs_cis, int B, int T, int n_head, int head_dim) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int head_dim_half = head_dim / 2; + if (idx >= B * T * 3 * n_head * head_dim_half) return; + // decode the qkv index early so we can early exit if it's a value index + int qkv = (idx / (n_head * head_dim_half)) % 3; + if (qkv == 2) return; // no-op for v + // decode the individual indices and get the input index + int b = idx / (T * 3 * n_head * head_dim_half); + int t = (idx / (3 * n_head * head_dim_half)) % T; + int h = (idx / head_dim_half) % n_head; + int d = idx % head_dim_half; + int idx_bt = b * (T * 3 * n_head * head_dim) + t * (3 * n_head * head_dim); + int idx_bth = idx_bt + qkv * (n_head * head_dim) + h * head_dim; + int idxi = idx_bth + 2 * d; // index in the input + // fetch the freqs_cis + int freqs_idx = t * head_dim + 2 * d; + float freqs_cos = freqs_cis[freqs_idx]; + float freqs_sin = freqs_cis[freqs_idx + 1]; + // backward + float dout_real = (float)dout[idxi]; + float dout_imag = (float)dout[idxi + 1]; + dinp[idxi] = dout_real * freqs_cos + dout_imag * freqs_sin; + dinp[idxi + 1] = -dout_real * freqs_sin + dout_imag * freqs_cos; +} + void rope_forward(floatX *out, const floatX *inp, const floatX *freqs_cis, int B, int T, int n_head, int head_dim, cudaStream_t stream) { // the input and output to this kernel are (B, T, 3, NH, HD) where the 3 is q,k,v // we are going to launch exactly one thread per element of the output, @@ -81,3 +107,12 @@ void rope_forward(floatX *out, const floatX *inp, const floatX *freqs_cis, int B rope_forward_kernel1<<>>(out, inp, freqs_cis, B, T, n_head, head_dim); cudaCheck(cudaGetLastError()); } + +void rope_backward_inplace(floatX *dinp, const floatX *dout, const floatX *freqs_cis, int B, int T, int n_head, int head_dim, cudaStream_t stream) { + // backward pass of forward, mirrors the forward kernel in setup and indexing + const int block_size = 128; + int total_threads = B * T * 3 * n_head * head_dim / 2; + int num_blocks = CEIL_DIV(total_threads, block_size); + rope_backward_inplace_kernel1<<>>(dinp, dout, freqs_cis, B, T, n_head, head_dim); + cudaCheck(cudaGetLastError()); +} diff --git a/train_llama3.cu b/train_llama3.cu index 630342ab7..2cc554a0b 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -64,9 +64,9 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage. #include "llmc/adamw.cuh" // defines: global_norm_squared #include "llmc/global_norm.cuh" -// defines: repkv_forward +// defines: repkv_forward, repkv_backward #include "llmc/repkv.cuh" -// defines: precompute_freqs_cis, rope_forward +// defines: precompute_freqs_cis, rope_forward, rope_backward_inplace #include "llmc/rope.cuh" // defines: swiglu_forward, swiglu_backward #include "llmc/swiglu.cuh" @@ -197,7 +197,7 @@ void* malloc_and_point_parameters(ParameterTensors* params, size_t* param_elemen return params_memory; } -constexpr int NUM_ACTIVATION_TENSORS = 21; +constexpr int NUM_ACTIVATION_TENSORS = 22; typedef struct { floatX* encoded; // (B, T, C) floatX* ln1; // (L, B, T, C) @@ -234,6 +234,7 @@ typedef struct { // some additional scratch buffers floatX* scratch_bt4c; // (B, T, 4*C) floatX* scratch_btc; // (B, T, C) + floatX* scratch_bt4c2; // (B, T, 4*C), for simplicify use this one for backward pass too, probably not needed } ActivationTensors; @@ -292,6 +293,7 @@ void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensor tensors[18] = TENSOR_SPEC(data->output, B * T * max(qkv_channels, max(ffn_channels, max(NH*T, Vp)))); tensors[19] = TENSOR_SPEC(data->scratch_bt4c, B * T * ffn_channels); tensors[20] = TENSOR_SPEC(data->scratch_btc, B * T * C); + tensors[21] = TENSOR_SPEC(data->scratch_bt4c2, B * T * ffn_channels); } void* malloc_and_point_activations(TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS]) { @@ -839,7 +841,6 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // get the pointers of the weights for this layer floatX* l_ln1w = params.ln1w + l * C; - floatX* l_ln1b = params.ln1b + l * C; floatX* l_qkvw = params.qkvw + l * qkv_channels * C; floatX* l_attprojw = params.attprojw + l * C * C; floatX* l_ln2w = params.ln2w + l * C; @@ -860,13 +861,11 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int floatX* dl_fcprojb = grads.fcprojb + l * C; // get the pointers of the activations for this layer floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf; - float* l_ln1_mean = acts.ln1_mean + l * B * T; float* l_ln1_rstd = acts.ln1_rstd + l * B * T; floatX* l_qkvr = acts.qkvr + l * B * T * qkv_channels; floatX* l_atty = acts.atty + l * B * T * C; floatX* l_residual2 = acts.residual2 + l * B * T * C; floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf; - float* l_ln2_mean = acts.ln2_mean + l * B * T; float* l_ln2_rstd = acts.ln2_rstd + l * B * T; floatX* l_fch_pre_gelu = acts.fch + l * B * T * ffn_channels; floatX* l_fch_gelu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * ffn_channels_post_gelu : acts.fch_gelu; @@ -875,6 +874,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // re-using this memory in every Transformer block as we calculate backward pass floatX* dl_bt4c = (floatX*)model->acts.scratch_bt4c; + floatX* dl_bt4c2 = (floatX*)model->acts.scratch_bt4c2; // same size as dl_bt4c, just a second buffer // start the backward pass for this layer if(model->recompute >= 1) { @@ -886,13 +886,16 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // backward the 2nd matmul of MLP matmul_backward(dl_bt4c, dl_fcprojw, dl_fcprojb, dresidual, l_fch_gelu, l_fcprojw, scratchF, B, T, ffn_channels_post_gelu, C, main_stream); // backward the swiglu here, use scratchX to hold the grad because SwiGLU can't be inplace - swiglu_backward(scratchX, dl_bt4c, l_fch_pre_gelu, B, T, ffn_channels_post_gelu, main_stream); + swiglu_backward(dl_bt4c2, dl_bt4c, l_fch_pre_gelu, B, T, ffn_channels_post_gelu, main_stream); // backward the 1st matmul of MLP if(model->recompute >= 2) { // same as gelu above, l_ln1 and l_ln2 are just buffers if recompute >= 2, recompute them here on demand rmsnorm_forward(l_ln2, l_ln2_rstd, l_residual2, l_ln2w, B, T, C, main_stream); } - matmul_backward(dl_btc, dl_fcw, dl_fcb, scratchX, l_ln2, l_fcw, scratchF, B, T, C, ffn_channels, main_stream); + matmul_backward(dl_btc, dl_fcw, dl_fcb, dl_bt4c2, l_ln2, l_fcw, scratchF, B, T, C, ffn_channels, main_stream); + // rmsnorm backward does += to the dresidual, so it correctly accumulates grad from the MLP block above + rmsnorm_backward(dresidual, dl_ln2w, scratchF, dl_btc, l_residual2, l_ln2w, l_ln2_rstd, B, T, C, main_stream); + matmul_backward(dl_btc, dl_attprojw, dl_attprojb, dresidual, l_atty, l_attprojw, scratchF, B, T, C, C, main_stream); // ------------------------------------------------------------------------ // DEBUGGING: we only work until this point right now, so exit here @@ -905,19 +908,18 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int } // write to .bin file // move output to cpu - floatX* cpu_output = (floatX*)mallocCheck(B*T*C * sizeof(floatX)); - cudaCheck(cudaMemcpy(cpu_output, output, B*T*C * sizeof(floatX), cudaMemcpyDeviceToHost)); + // int sz = B*T*qkv_channels; //B*T*C; + int sz = B*T*C; + floatX* cpu_output = (floatX*)mallocCheck(sz * sizeof(floatX)); + cudaCheck(cudaMemcpy(cpu_output, output, sz * sizeof(floatX), cudaMemcpyDeviceToHost)); FILE* f = fopen("out.bin", "wb"); - fwrite(cpu_output, sizeof(floatX), B*T*C, f); + fwrite(cpu_output, sizeof(floatX), sz, f); fclose(f); exit(0); // ------------------------------------------------------------------------ - // layernorm backward does += to the dresidual, so it correctly accumulates grad from the MLP block above - layernorm_backward(dresidual, dl_ln2w, dl_ln2b, scratchF, dl_btc, l_residual2, l_ln2w, l_ln2_mean, l_ln2_rstd, B, T, C, main_stream); - matmul_backward(dl_btc, dl_attprojw, dl_attprojb, dresidual, l_atty, l_attprojw, scratchF, B, T, C, C, main_stream); - #ifdef ENABLE_CUDNN + printf("cuDNN path TODO\n"); exit(0); float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor attention_backward_cudnn(dl_bt4c, dl_btc, l_qkvr, l_atty, (float*)l_att, B, T, NH, C, main_stream); #else @@ -927,13 +929,19 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int floatX* buffer_b = l_fch_pre_gelu; // this is B x T x 4C, so even larger than what we need attention_backward(dl_bt4c, buffer_b, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH, main_stream); #endif + // backward rope (this can be done in-place) + rope_backward_inplace(dl_bt4c, dl_bt4c, model->freqs_cis, B, T, NH, hd, main_stream); + // backward repkv (use scratchX as gradient buffer here) + repkv_backward(dl_bt4c2, dl_bt4c, B, T, NH, n_kv_head, hd); + + // <--- here the gradients don't match, so there is an issue in between + + // backward QKV projection if(model->recompute >= 2) { - layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C, main_stream); + rmsnorm_forward(l_ln1, l_ln1_rstd, residual, l_ln1w, B, T, C, main_stream); } - // QKV parameter gradients - matmul_backward(dl_btc, dl_qkvw, dl_qkvb, dl_bt4c, l_ln1, l_qkvw, scratchF, B, T, C, 3 * C, main_stream); - // layernorm backward does += to dresidual, so it correctly accumulates gradient for the Attention block above - layernorm_backward(dresidual, dl_ln1w, dl_ln1b, scratchF, dl_btc, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C, main_stream); + matmul_backward(dl_btc, dl_qkvw, dl_qkvb, dl_bt4c2, l_ln1, l_qkvw, scratchF, B, T, C, qkv_channels, main_stream); + rmsnorm_backward(dresidual, dl_ln1w, scratchF, dl_btc, residual, l_ln1w, l_ln1_rstd, B, T, C, main_stream); // Accumulate gradients from this layer in a background stream. if(last_step) { diff --git a/train_llama3.py b/train_llama3.py index 455282e1d..cd1549b42 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -197,6 +197,12 @@ def forward(self, x, freqs_cis=None, start_pos=None, mask=None): att = F.softmax(scores.float(), dim=-1).type_as(q) y = att @ v # (B, NH, T, T) x (B, NH, T, HD) -> (B, NH, T, HD) y = y.transpose(1, 2).contiguous().view(B, T, C) + + DEBUG_POINT = y.detach() + DEBUG_POINT = DEBUG_POINT.requires_grad_(True) + self.DEBUG_POINT = DEBUG_POINT + y = DEBUG_POINT + y = self.c_proj(y) return y @@ -234,11 +240,7 @@ def __init__(self, config): def forward(self, x, freqs_cis=None, start_pos=None, mask=None): x = x + self.attn(self.ln_1(x), freqs_cis, start_pos, mask) - MLP_INPUT = self.ln_2(x) - MLP_INPUT = MLP_INPUT.detach() - MLP_INPUT.requires_grad = True - self.MLP_INPUT = MLP_INPUT - x = x + self.mlp(MLP_INPUT) + x = x + self.mlp(self.ln_2(x)) return x # ----------------------------------------------------------------------------- @@ -1260,7 +1262,7 @@ def get_lr(it): # --------------------------------------------------------------------- # DEBUGGING: print first 32 elements of x - x = model.transformer.h[-1].MLP_INPUT.grad + x = model.transformer.h[-1].attn.DEBUG_POINT.grad for i in range(32): print("q[{}]: {:.8f}".format(i, x.view(-1)[i].item())) # write to .bin file