From 703c4d9384d065f86d570a35899c508a5410fcf1 Mon Sep 17 00:00:00 2001
From: Aleksa Gordic
Date: Tue, 30 Jul 2024 11:16:41 +0200
Subject: [PATCH] Wrap up backward pass

---
 llmc/matmul.cuh |  6 +++---
 train_gpt2.cu   | 20 ++++++++++----------
 2 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/llmc/matmul.cuh b/llmc/matmul.cuh
index 3b19e136c..6a3f44794 100644
--- a/llmc/matmul.cuh
+++ b/llmc/matmul.cuh
@@ -260,7 +260,7 @@ void matmul_forward_fc1(floatX* out,
 void matmul_backward(floatX* dinp1, floatX* dinp2, floatX* dweight, floatX* dbias,
                      floatX* dout, floatX* inp, floatX* weight, float* dbias_buffer,
-                     int B, int T, int C, int OC, cudaStream_t stream,
+                     int B, int T, int C, int OC, cudaStream_t stream, int accumulate_input,
                      const char* act_func, floatX* pre_act1=NULL, floatX* pre_act2=NULL, int gelu_fusion=1) {
     NVTX_RANGE_FN();
@@ -294,8 +294,8 @@ void matmul_backward(floatX* dinp1, floatX* dinp2, floatX* dweight, floatX* dbia
     int is_gelu = strcmp(act_func, "gelu") == 0;
     // backward to input, uses = in the backward pass (set the gradient)
-    matmul_cublaslt(dinp1, weight, dout, NULL, C, B*T, OC, stream, false, false, 0, 0, 0, 0, false,
-                    is_gelu && gelu_fusion >= 2 ? pre_act1 : NULL, true);
+    matmul_cublaslt(dinp1, weight, dout, NULL, C, B*T, OC, stream, false, false, 0, 0, 0, 0,
+                    accumulate_input /* accumulate */, is_gelu && gelu_fusion >= 2 ? pre_act1 : NULL, true);
     // backward GELU (if it wasn't fused into the matmul above)
     if (is_gelu && gelu_fusion < 2 && pre_act1) {
diff --git a/train_gpt2.cu b/train_gpt2.cu
index fd6acf717..c74cf20ce 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -840,7 +840,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
     // technically that is a small, inline backward() pass of calculating
     // total, final loss as the mean over all losses over all (B,T) positions in the batch
     // next: backward the classifier matmul
-    matmul_backward(model->acts.scratch_bt4c, NULL, grads.wte, NULL, acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, main_stream, model->act_func);
+    matmul_backward(model->acts.scratch_bt4c, NULL, grads.wte, NULL, acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, main_stream, 0, model->act_func);
     // backward the final layernorm
     floatX* residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3
     layernorm_backward(dresidual, grads.lnfw, grads.lnfb, scratchF, model->acts.scratch_bt4c, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C, main_stream);
@@ -875,6 +875,8 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
         floatX* dl_ln2b = grads.ln2b + l * C;
         floatX* dl_fcw = grads.fcw + l * 4*C * C;
         floatX* dl_fcb = grads.fcb + l * 4*C;
+        floatX* dl_gatew = gated_ffn ? grads.gatew + l * 4*C * C : NULL;
+        floatX* dl_gateb = gated_ffn ? grads.gateb + l * 4*C : NULL;
         floatX* dl_fcprojw = grads.fcprojw + l * C * 4*C;
         floatX* dl_fcprojb = grads.fcprojb + l * C;
         // get the pointers of the activations for this layer
@@ -900,28 +902,26 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
         // start the backward pass for this layer
         if(model->recompute >= 1) {
             if (strcmp(model->act_func, "gelu") == 0) {
-                // recompute >= 1 means we recompute gelu. in this case,
-                // l_fch_act is just a buffer, so re-compute the gelu from l_fch_pre_act here
+                // l_fch_act is just a B*T*4*C buffer (vs L*B*T*4*C), so re-compute the gelu from l_fch_pre_act here
                 gelu_forward(l_fch_act, l_fch_pre_act, B*T*4*C, main_stream);
             } else {
-                assert(strcmp(model->act_func, "swiglu") == 0);
+                assert(strcmp(model->act_func, "swiglu") == 0); // only swiglu is supported atm
                 swiglu_forward(l_fch_act, l_fch_pre_act, l_fch_pre_act2, B*T*4*C, main_stream);
             }
         }
-        matmul_backward(dl_bt4c, dl_bt4c2, dl_fcprojw, dl_fcprojb, dresidual, l_fch_act, l_fcprojw, scratchF, B, T, 4*C, C, main_stream, model->act_func, l_fch_pre_act, l_fch_pre_act2, model->act_func_fusion);
+        matmul_backward(dl_bt4c, dl_bt4c2, dl_fcprojw, dl_fcprojb, dresidual, l_fch_act, l_fcprojw, scratchF, B, T, 4*C, C, main_stream, 0, model->act_func, l_fch_pre_act, l_fch_pre_act2, model->act_func_fusion);
         if(model->recompute >= 2) {
             // same as gelu/{act_func} above, l_ln1 and l_ln2 are just buffers if recompute >= 2, recompute them here on demand
             layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C, main_stream);
         }
         if (dl_bt4c2 != NULL) {
-            assert(0); // TODO(gordicaleksa): implement
             // backprop into gate
-            // matmul_backward(dl_bt4c2, NULL, dl_gatedw, dl_gatedb, dresidual, l_fch_act, l_gatedw, scratchF, B, T, 4*C, C, main_stream, model->act_func);
+            matmul_backward(dl_btc, NULL, dl_gatew, dl_gateb, dl_bt4c2, l_ln2, l_fcw, scratchF, B, T, C, 4 * C, main_stream, 0, model->act_func);
         }
-        matmul_backward(dl_btc, NULL, dl_fcw, dl_fcb, dl_bt4c, l_ln2, l_fcw, scratchF, B, T, C, 4 * C, main_stream, model->act_func);
+        matmul_backward(dl_btc, NULL, dl_fcw, dl_fcb, dl_bt4c, l_ln2, l_fcw, scratchF, B, T, C, 4 * C, main_stream, dl_bt4c2 != NULL, model->act_func);
         // layernorm backward does += to the dresidual, so it correctly accumulates grad from the MLP block above
         layernorm_backward(dresidual, dl_ln2w, dl_ln2b, scratchF, dl_btc, l_residual2, l_ln2w, l_ln2_mean, l_ln2_rstd, B, T, C, main_stream);
-        matmul_backward(dl_btc, NULL, dl_attprojw, dl_attprojb, dresidual, l_atty, l_attprojw, scratchF, B, T, C, C, main_stream, model->act_func);
+        matmul_backward(dl_btc, NULL, dl_attprojw, dl_attprojb, dresidual, l_atty, l_attprojw, scratchF, B, T, C, C, main_stream, 0, model->act_func);
         #ifdef ENABLE_CUDNN
         float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor
@@ -937,7 +937,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
             layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C, main_stream);
         }
         // QKV parameter gradients
-        matmul_backward(dl_btc, NULL, dl_qkvw, dl_qkvb, dl_bt4c, l_ln1, l_qkvw, scratchF, B, T, C, 3 * C, main_stream, model->act_func);
+        matmul_backward(dl_btc, NULL, dl_qkvw, dl_qkvb, dl_bt4c, l_ln1, l_qkvw, scratchF, B, T, C, 3 * C, main_stream, 0, model->act_func);
         // layernorm backward does += to dresidual, so it correctly accumulates gradient for the Attention block above
         layernorm_backward(dresidual, dl_ln1w, dl_ln1b, scratchF, dl_btc, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C, main_stream);
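
Note on the new accumulate_input argument: with a gated FFN, the second layernorm output feeds two matmuls (fc and gate), so its input gradient is the sum of the two branch gradients. That is why the gate call above passes 0 (overwrite dl_btc) and the fc call passes dl_bt4c2 != NULL (accumulate on top). Below is a minimal CPU sketch of that accumulation semantics, assuming plain float buffers and a loop matmul; the names (matmul_backward_input_ref, W_fc, W_gate, d_pre_fc, d_pre_gate, d_ln2) are hypothetical stand-ins, not identifiers from this repo, and the sketch is an illustration rather than the actual matmul_cublaslt path.

// accumulate_input_sketch.cpp -- illustrative only (hypothetical names, CPU reference)
#include <cstdio>
#include <vector>

// dinp (BT x C) gets dout (BT x OC) * W (OC x C); 'accumulate' mirrors accumulate_input ('+=' vs '=').
static void matmul_backward_input_ref(float* dinp, const float* dout, const float* W,
                                      int BT, int C, int OC, bool accumulate) {
    for (int i = 0; i < BT; i++) {
        for (int c = 0; c < C; c++) {
            float acc = accumulate ? dinp[i * C + c] : 0.0f;
            for (int o = 0; o < OC; o++) {
                acc += dout[i * OC + o] * W[o * C + c];
            }
            dinp[i * C + c] = acc;
        }
    }
}

int main() {
    const int BT = 2, C = 3, OC = 4;  // tiny shapes for the demo
    std::vector<float> W_fc(OC * C, 0.1f), W_gate(OC * C, 0.2f);
    std::vector<float> d_pre_fc(BT * OC, 1.0f), d_pre_gate(BT * OC, 1.0f);  // analogues of dl_bt4c / dl_bt4c2
    std::vector<float> d_ln2(BT * C);                                       // analogue of dl_btc

    // gate branch first: overwrite (accumulate = false), like passing 0 to the gate matmul_backward
    matmul_backward_input_ref(d_ln2.data(), d_pre_gate.data(), W_gate.data(), BT, C, OC, /*accumulate=*/false);
    // fc branch second: accumulate on top (accumulate_input = dl_bt4c2 != NULL)
    matmul_backward_input_ref(d_ln2.data(), d_pre_fc.data(), W_fc.data(), BT, C, OC, /*accumulate=*/true);

    printf("d_ln2[0] = %f (expected 4*0.2 + 4*0.1 = 1.2)\n", d_ln2[0]);
    return 0;
}

The non-gated calls (classifier, attention projection, QKV) keep passing 0 because each of those inputs feeds a single matmul, so the gradient is simply set rather than accumulated.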