Skip to content

Commit

Permalink
Merge pull request #665 from ngc92/zero-grad
Browse files Browse the repository at this point in the history
zero-grad is async and part of backward call
  • Loading branch information
karpathy authored Jul 1, 2024
2 parents 79223b9 + 486f98e commit 39270cc
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 21 deletions.
1 change: 0 additions & 1 deletion profile_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ int main(int argc, char *argv[]) {

// do a training step
gpt2_forward(&model, x, B, T);
gpt2_zero_grad(&model);
gpt2_backward_and_reduce(&model, x, y, 1, 0);
float grad_norm = gpt2_calculate_grad_norm(&model, &multi_gpu_config);
float grad_scale = (grad_norm > 1.0f) ? 1.0f / grad_norm : 1.0f;
Expand Down
3 changes: 0 additions & 3 deletions test_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,6 @@ int main(int argc, char *argv[]) {
struct timespec start, end;
clock_gettime(CLOCK_MONOTONIC, &start);
gpt2_forward(&model, x, B, T);
gpt2_zero_grad(&model);
gpt2_backward_and_reduce(&model, x, y, 1, 0);
clock_gettime(CLOCK_MONOTONIC, &end);
double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;
Expand Down Expand Up @@ -337,7 +336,6 @@ int main(int argc, char *argv[]) {
for (int step = 0; step < 10; step++) {
dataloader_next_batch(&loader);
gpt2_forward(&model, loader.inputs, B, T);
gpt2_zero_grad(&model);
gpt2_backward_and_reduce(&model, loader.inputs, loader.targets, 1, 0);
gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+11, &multi_gpu_config);
losses[step] = model.mean_loss;
Expand All @@ -352,7 +350,6 @@ int main(int argc, char *argv[]) {
for (int step = 0; step < 10; step++) {
dataloader_next_batch(&loader);
gpt2_forward(&model, loader.inputs, B, T);
gpt2_zero_grad(&model);
gpt2_backward_and_reduce(&model, loader.inputs, loader.targets, 1, 0);
gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+11, &multi_gpu_config);

Expand Down
26 changes: 9 additions & 17 deletions train_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -732,27 +732,10 @@ float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B
return mean_loss;
}

void gpt2_zero_grad(GPT2 *model) {
    NVTX_RANGE_FN();
    // Reset the two accumulators used by the gradient-accumulation inner loop:
    // 1) per-token losses, which are += accumulated into acts.losses
    // 2) weight gradients, which are += accumulated into grads_memory
    //    (grads_memory is lazily allocated, so it may still be NULL here)
    size_t losses_bytes = model->batch_size * model->seq_len * sizeof(float);
    cudaCheck(cudaMemset(model->acts.losses, 0, losses_bytes));
    if (model->grads_memory != NULL) {
        size_t grads_bytes = model->num_parameters * sizeof(floatX);
        cudaCheck(cudaMemset(model->grads_memory, 0, grads_bytes));
    }
    // block until the memsets have completed on the device
    cudaCheck(cudaDeviceSynchronize());
}

void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int grad_accum_steps, int micro_step) {
NVTX_RANGE_FN();
bool last_step = micro_step == grad_accum_steps - 1;

// on the first micro-step zero the gradients, as we're about to += accumulate into them
if (micro_step == 0) {
gpt2_zero_grad(model);
}

// lazily allocate the memory for gradients of the weights and activations, if needed
if (model->grads_memory == NULL) {
NvtxRange rng("InitGrads");
Expand All @@ -766,6 +749,15 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
model->bucket_info = (int4*)mallocCheck(sizeof(int4) * model->batch_size * model->seq_len * num_c_groups);
}

// on the first micro-step zero the gradients, as we're about to += accumulate into them
if (micro_step == 0) {
// there are currently two state vars during the gradient accumulation inner loop:
// 1) the losses accumulate += into acts.losses, reset here
// 2) the gradients accumulate += into grads_memory, reset here
cudaCheck(cudaMemsetAsync(model->acts.losses, 0, model->batch_size * model->seq_len * sizeof(float), main_stream));
cudaCheck(cudaMemsetAsync(model->grads_memory, 0, model->num_parameters * sizeof(floatX), main_stream));
}

// convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow
const size_t B = model->batch_size;
const size_t T = model->seq_len;
Expand Down

0 comments on commit 39270cc

Please sign in to comment.