
Commit

Wrap up backward pass
gordicaleksa committed Jul 30, 2024
1 parent 23856e0 commit 703c4d9
Showing 2 changed files with 13 additions and 13 deletions.
6 changes: 3 additions & 3 deletions llmc/matmul.cuh
@@ -260,7 +260,7 @@ void matmul_forward_fc1(floatX* out,
void matmul_backward(floatX* dinp1, floatX* dinp2, floatX* dweight, floatX* dbias,
floatX* dout, floatX* inp, floatX* weight,
float* dbias_buffer,
- int B, int T, int C, int OC, cudaStream_t stream,
+ int B, int T, int C, int OC, cudaStream_t stream, int accumulate_input,
const char* act_func, floatX* pre_act1=NULL, floatX* pre_act2=NULL, int gelu_fusion=1) {
NVTX_RANGE_FN();

@@ -294,8 +294,8 @@ void matmul_backward(floatX* dinp1, floatX* dinp2, floatX* dweight, floatX* dbias,
int is_gelu = strcmp(act_func, "gelu") == 0;

// backward to input, uses = in the backward pass (set the gradient)
- matmul_cublaslt(dinp1, weight, dout, NULL, C, B*T, OC, stream, false, false, 0, 0, 0, 0, false,
-     is_gelu && gelu_fusion >= 2 ? pre_act1 : NULL, true);
+ matmul_cublaslt(dinp1, weight, dout, NULL, C, B*T, OC, stream, false, false, 0, 0, 0, 0,
+     accumulate_input /* accumulate */, is_gelu && gelu_fusion >= 2 ? pre_act1 : NULL, true);

// backward GELU (if it wasn't fused into the matmul above)
if (is_gelu && gelu_fusion < 2 && pre_act1) {
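For reference, a minimal CPU sketch of what the new accumulate_input flag selects in the input-gradient matmul above: whether dinp1 is overwritten or added into. This is not the cuBLASLt path; the helper name is invented, and the shapes (dout of size (B*T, OC), weight stored (OC, C) row-major as in llm.c, dinp of size (B*T, C)) are assumptions for illustration.

#include <cstddef>

// dinp = dout @ weight  (with out = inp @ weight^T in the forward pass)
// accumulate == false -> '=' semantics: this call owns dinp
// accumulate == true  -> '+=' semantics: an earlier backward already wrote into dinp
void matmul_backward_input_ref(float* dinp, const float* dout, const float* weight,
                               size_t BT, size_t C, size_t OC, bool accumulate) {
    for (size_t i = 0; i < BT; i++) {
        for (size_t c = 0; c < C; c++) {
            float g = 0.0f;
            for (size_t o = 0; o < OC; o++) {
                g += dout[i * OC + o] * weight[o * C + c];
            }
            dinp[i * C + c] = accumulate ? dinp[i * C + c] + g : g;
        }
    }
}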
20 changes: 10 additions & 10 deletions train_gpt2.cu
@@ -840,7 +840,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
// technically that is a small, inline backward() pass of calculating
// total, final loss as the mean over all losses over all (B,T) positions in the batch
// next: backward the classifier matmul
- matmul_backward(model->acts.scratch_bt4c, NULL, grads.wte, NULL, acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, main_stream, model->act_func);
+ matmul_backward(model->acts.scratch_bt4c, NULL, grads.wte, NULL, acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, main_stream, 0, model->act_func);
// backward the final layernorm
floatX* residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3
layernorm_backward(dresidual, grads.lnfw, grads.lnfb, scratchF, model->acts.scratch_bt4c, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C, main_stream);
@@ -875,6 +875,8 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
floatX* dl_ln2b = grads.ln2b + l * C;
floatX* dl_fcw = grads.fcw + l * 4*C * C;
floatX* dl_fcb = grads.fcb + l * 4*C;
+ floatX* dl_gatew = gated_ffn ? grads.gatew + l * 4*C * C : NULL;
+ floatX* dl_gateb = gated_ffn ? grads.gateb + l * 4*C : NULL;
floatX* dl_fcprojw = grads.fcprojw + l * C * 4*C;
floatX* dl_fcprojb = grads.fcprojb + l * C;
// get the pointers of the activations for this layer
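For reference, the two new pointers above slice layer l's block out of the gate-weight and gate-bias gradient buffers, mirroring how dl_fcw and dl_fcb are indexed. A small illustrative sketch of that offset arithmetic (function names invented, assuming (L, 4C, C) and (L, 4C) layouts):

#include <cstddef>

// Layer l's slice of the gate gradients, matching grads.gatew + l*4*C*C and grads.gateb + l*4*C.
size_t gate_weight_grad_offset(int l, size_t C) { return (size_t)l * 4 * C * C; }
size_t gate_bias_grad_offset(int l, size_t C)   { return (size_t)l * 4 * C; }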
@@ -900,28 +902,26 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
// start the backward pass for this layer
if(model->recompute >= 1) {
if (strcmp(model->act_func, "gelu") == 0) {
// recompute >= 1 means we recompute gelu. in this case,
- // l_fch_act is just a buffer, so re-compute the gelu from l_fch_pre_act here
+ // l_fch_act is just a B*T*4*C buffer (vs L*B*T*4*C), so re-compute the gelu from l_fch_pre_act here
gelu_forward(l_fch_act, l_fch_pre_act, B*T*4*C, main_stream);
} else {
- assert(strcmp(model->act_func, "swiglu") == 0);
+ assert(strcmp(model->act_func, "swiglu") == 0); // only swiglu is supported atm
swiglu_forward(l_fch_act, l_fch_pre_act, l_fch_pre_act2, B*T*4*C, main_stream);
}
}
- matmul_backward(dl_bt4c, dl_bt4c2, dl_fcprojw, dl_fcprojb, dresidual, l_fch_act, l_fcprojw, scratchF, B, T, 4*C, C, main_stream, model->act_func, l_fch_pre_act, l_fch_pre_act2, model->act_func_fusion);
+ matmul_backward(dl_bt4c, dl_bt4c2, dl_fcprojw, dl_fcprojb, dresidual, l_fch_act, l_fcprojw, scratchF, B, T, 4*C, C, main_stream, 0, model->act_func, l_fch_pre_act, l_fch_pre_act2, model->act_func_fusion);
if(model->recompute >= 2) {
// same as gelu/{act_func} above, l_ln1 and l_ln2 are just buffers if recompute >= 2, recompute them here on demand
layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C, main_stream);
}
if (dl_bt4c2 != NULL) {
- assert(0); // TODO(gordicaleksa): implement
- // backprop into gate
- // matmul_backward(dl_bt4c2, NULL, dl_gatedw, dl_gatedb, dresidual, l_fch_act, l_gatedw, scratchF, B, T, 4*C, C, main_stream, model->act_func);
+ matmul_backward(dl_btc, NULL, dl_gatew, dl_gateb, dl_bt4c2, l_ln2, l_fcw, scratchF, B, T, C, 4 * C, main_stream, 0, model->act_func);
}
- matmul_backward(dl_btc, NULL, dl_fcw, dl_fcb, dl_bt4c, l_ln2, l_fcw, scratchF, B, T, C, 4 * C, main_stream, model->act_func);
+ matmul_backward(dl_btc, NULL, dl_fcw, dl_fcb, dl_bt4c, l_ln2, l_fcw, scratchF, B, T, C, 4 * C, main_stream, dl_bt4c2 != NULL, model->act_func);
// layernorm backward does += to the dresidual, so it correctly accumulates grad from the MLP block above
layernorm_backward(dresidual, dl_ln2w, dl_ln2b, scratchF, dl_btc, l_residual2, l_ln2w, l_ln2_mean, l_ln2_rstd, B, T, C, main_stream);
- matmul_backward(dl_btc, NULL, dl_attprojw, dl_attprojb, dresidual, l_atty, l_attprojw, scratchF, B, T, C, C, main_stream, model->act_func);
+ matmul_backward(dl_btc, NULL, dl_attprojw, dl_attprojb, dresidual, l_atty, l_attprojw, scratchF, B, T, C, C, main_stream, 0, model->act_func);

#ifdef ENABLE_CUDNN
float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor
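For reference, a minimal CPU sketch of the set-then-accumulate pattern used above when the FFN is gated: the gate-branch matmul_backward writes dl_btc first (accumulate = 0), and the fc-branch call then passes dl_bt4c2 != NULL so it adds its contribution instead of overwriting it. Names, shapes, and the assumption that each branch backpropagates through its own (4C, C) row-major weight matrix are illustrative, not taken from the repo:

#include <cstddef>

// Gated FFN input gradient: with pre1 = x @ Wfc^T and pre2 = x @ Wgate^T feeding
// the gated activation, x receives one contribution per branch once dpre1/dpre2
// are known. The two matmul_backward calls above split this sum: whichever call
// runs first writes with '=', the second accumulates with '+='.
void gated_ffn_input_grad_ref(float* dx, const float* dpre1, const float* dpre2,
                              const float* Wfc, const float* Wgate,
                              size_t BT, size_t C, size_t H /* hidden size, 4*C here */) {
    for (size_t i = 0; i < BT; i++) {
        for (size_t c = 0; c < C; c++) {
            float g = 0.0f;
            for (size_t h = 0; h < H; h++) {
                g += dpre1[i * H + h] * Wfc[h * C + c]     // contribution through the fc branch
                   + dpre2[i * H + h] * Wgate[h * C + c];  // contribution through the gate branch
            }
            dx[i * C + c] = g; // sum of both branch contributions
        }
    }
}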
@@ -937,7 +937,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C, main_stream);
}
// QKV parameter gradients
- matmul_backward(dl_btc, NULL, dl_qkvw, dl_qkvb, dl_bt4c, l_ln1, l_qkvw, scratchF, B, T, C, 3 * C, main_stream, model->act_func);
+ matmul_backward(dl_btc, NULL, dl_qkvw, dl_qkvb, dl_bt4c, l_ln1, l_qkvw, scratchF, B, T, C, 3 * C, main_stream, 0, model->act_func);
// layernorm backward does += to dresidual, so it correctly accumulates gradient for the Attention block above
layernorm_backward(dresidual, dl_ln1w, dl_ln1b, scratchF, dl_btc, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C, main_stream);

