fix the overfit single batch behavior to actually overfit batch, not … #662

Merged · 1 commit · Jul 1, 2024
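This change fixes the overfit_single_batch mode so that it overfits one full (total) batch rather than just the first micro-batch. Previously, data was fetched only once, at step == 0 and micro_step == 0, so every later micro-step silently reused that single micro-batch. Now a fresh micro-batch is fetched on every micro-step, and the train loader is instead reset at the top of each step, so every step replays the identical sequence of grad_accum_steps micro-batches. On the CUDA side, the train loader's permutation is also disabled in this mode (permute_train_loader = 0) so the replayed order stays fixed.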
train_gpt2.cu: 13 changes (7 additions, 6 deletions)
@@ -1580,8 +1580,9 @@ int main(int argc, char *argv[]) {
     printf0("+-----------------------+----------------------------------------------------+\n");

     // build DataLoaders for both train and val
+    int permute_train_loader = (overfit_single_batch == 1) ? 0 : 1;
     DataLoader train_loader, val_loader;
-    dataloader_init(&train_loader, train_data_pattern, B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes, 1);
+    dataloader_init(&train_loader, train_data_pattern, B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes, permute_train_loader);
     dataloader_init(&val_loader, val_data_pattern, B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes, 0);
     // figure out the number of training steps we will run for
     int train_num_batches = max_steps; // passed in from command line
@@ -1781,16 +1782,16 @@ int main(int argc, char *argv[]) {
         if (last_step) { break; }

         // --------------- TRAINING SECTION BEGIN -----------------
+        if (overfit_single_batch == 1) {
+            // if we are trying to overfit a single batch, we reset the loader here
+            dataloader_reset(&train_loader);
+        }
         // do one training step, doing forward/backward/update on total_batch_size tokens
         cudaEventRecord(start);
         // gradient and loss accumulation loop over micro-batches
         for (int micro_step = 0; micro_step < grad_accum_steps; micro_step++) {
             // fetch the next data batch
-            // and if we're overfitting a single batch, we'll only call this a single time
-            if (overfit_single_batch == 0 ||
-                (overfit_single_batch == 1 && step == 0 && micro_step == 0)) {
-                dataloader_next_batch(&train_loader);
-            }
+            dataloader_next_batch(&train_loader);
             // forward pass. note that we pass in grad_accum_steps, which scales down the loss
             gpt2_forward(&model, train_loader.inputs, B, T);
             // backward pass. all model params accumulate gradients with += inside this inner loop
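For intuition, here is a minimal standalone sketch of the loop shape the fix produces. The ToyLoader below is hypothetical; only its reset() and next_batch() mirror the dataloader_reset / dataloader_next_batch calls in the diff. It shows that each step now replays the same grad_accum_steps micro-batches instead of freezing only the first one.

# minimal sketch; ToyLoader is a hypothetical stand-in for the real DataLoader
class ToyLoader:
    def __init__(self, batches):
        self.batches = batches
        self.pos = 0
    def reset(self):
        self.pos = 0  # rewind to the first micro-batch
    def next_batch(self):
        batch = self.batches[self.pos]
        self.pos += 1
        return batch

grad_accum_steps = 4
loader = ToyLoader(list(range(100)))  # integers stand in for micro-batches

for step in range(3):
    loader.reset()  # the fix: rewind once per step...
    seen = [loader.next_batch() for _ in range(grad_accum_steps)]  # ...then fetch on every micro-step
    print(step, seen)  # every step sees the same micro-batches [0, 1, 2, 3]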
train_gpt2.py: 9 changes (5 additions, 4 deletions)
@@ -794,14 +794,15 @@ def get_lr(it):

     # --------------- TRAINING SECTION BEGIN -----------------
     model.train()
+    # if we are trying to overfit a single batch, we reset the loader here
+    if args.overfit_single_batch:
+        train_loader.reset()
     # micro-batch loop where we do gradient accumulation to reach desired total batch size
     lossf = 0.0 # for getting the mean loss (as simple float) over the accumulation steps
     for micro_step in range(grad_accum_steps):
         # fetch a batch
-        if not args.overfit_single_batch \
-                or (args.overfit_single_batch and step == 0 and micro_step == 0):
-            x, y = train_loader.next_batch()
-            x, y = x.to(device), y.to(device)
+        x, y = train_loader.next_batch()
+        x, y = x.to(device), y.to(device)
         if ddp:
             # we want only the last micro-step to sync grads in a DDP model
             # the official way to do this is with model.no_sync(), but that is a
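A quick way to sanity-check the new behavior (assuming the integer-valued flag implied by args.overfit_single_batch in the diff; other defaults may differ):

python train_gpt2.py --overfit_single_batch 1

Because every optimization step now sees the identical total batch, the training loss should fall sharply toward zero within a handful of steps; if it does not, something in the model or optimizer wiring is broken, which is exactly what this debugging mode exists to catch.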