fix: use the optimizer step separately for the Trainer class
haoxiangsnr committed Dec 28, 2023
1 parent 43363b1 commit 79cf94d
Showing 2 changed files with 90 additions and 26 deletions.
115 changes: 89 additions & 26 deletions audiozen/common_trainer.py
@@ -265,16 +265,80 @@ def get_warmup_steps(warmup_steps, max_steps, warmup_ratio):
else:
return math.ceil(max_steps * warmup_ratio)

def create_warmup_scheduler(self, scheduler_name, max_steps: int):
def create_warmup_scheduler(self, optimizer, scheduler_name, max_steps: int):
num_warmup_steps = self.get_warmup_steps(self.warmup_steps, max_steps, self.warmup_ratio)
if scheduler_name == "constant_schedule_with_warmup":
return get_constant_schedule_with_warmup(optimizer=self.optimizer, num_warmup_steps=num_warmup_steps)
return get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=num_warmup_steps)
elif scheduler_name == "linear_schedule_with_warmup":
return get_linear_schedule_with_warmup(
optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=max_steps
optimizer=optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=max_steps
)

def create_schedulers(self, max_steps: int):
"""Create schedulers.
You can override this method to create your own schedulers. For example, in GAN training, you may want to
create two schedulers for the generator and the discriminator.
Args:
max_steps: the maximum number of steps to train.
"""
self.lr_scheduler = self.create_warmup_scheduler(
optimizer=self.optimizer, scheduler_name=self.scheduler_name, max_steps=max_steps
)
self.lr_scheduler = self.accelerator.prepare(self.lr_scheduler)
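# Editor's illustrative sketch (not part of this commit): how a GAN recipe might override
# create_schedulers() to build one scheduler per optimizer. `optimizer_g`, `optimizer_d`,
# `lr_scheduler_g`, and `lr_scheduler_d` are assumed names, not from this repository.
def create_schedulers(self, max_steps: int):
    self.lr_scheduler_g = self.create_warmup_scheduler(
        optimizer=self.optimizer_g, scheduler_name=self.scheduler_name, max_steps=max_steps
    )
    self.lr_scheduler_d = self.create_warmup_scheduler(
        optimizer=self.optimizer_d, scheduler_name=self.scheduler_name, max_steps=max_steps
    )
    self.lr_scheduler_g, self.lr_scheduler_d = self.accelerator.prepare(
        self.lr_scheduler_g, self.lr_scheduler_d
    )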

def set_models_to_train_mode(self):
"""Set models to train mode.
You can override this method to set your own models to train mode. For example, in GAN training, you may want to
set the generator and the discriminator to train mode.
"""
self.model.train()

def set_models_to_eval_mode(self):
self.model.eval()

def lr_scheduler_step(self):
"""Step the lr scheduler.
You can override this method to step your own lr scheduler. For example, in GAN training, you may want to
step the lr scheduler of the generator and the discriminator.
"""
self.lr_scheduler.step(self.state.steps_trained)
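# Editor's illustrative sketch (not part of this commit): GAN-style overrides of the hooks above.
# `generator`, `discriminator`, `lr_scheduler_g`, and `lr_scheduler_d` are assumed attributes
# defined by the subclass, not names from this repository.
def set_models_to_train_mode(self):
    self.generator.train()
    self.discriminator.train()

def set_models_to_eval_mode(self):
    self.generator.eval()
    self.discriminator.eval()

def lr_scheduler_step(self):
    self.lr_scheduler_g.step(self.state.steps_trained)
    self.lr_scheduler_d.step(self.state.steps_trained)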

def create_bar_desc(self, loss_dict, norm):
bar_desc = ""
for k, v in loss_dict.items():
bar_desc += f"{k}: {(v):.4f}, "
bar_desc += f"norm: {norm:.4f}, " f"lr: {self.lr_scheduler.get_last_lr()[-1]:.10f}"
return bar_desc

def train(self, train_dataloader: DataLoader, validation_dataloaders):
"""Train the model.
Args:
train_dataloader: the dataloader to train.
validation_dataloaders: the dataloader(s) to validate.
Notes:
You are responsible for calling `.backward()`, `.step()`, and `.zero_grad()` in your implementation
of `training_step()`. Accelerate will automatically handle gradient accumulation for you.
This means that, under gradient accumulation, the optimizer and scheduler `step()` calls only take effect once `gradient_accumulation_steps` is reached.
The training step is implemented as follows:
.. code-block:: python
self.optimizer.zero_grad()
loss = training_step(batch, batch_idx)
self.accelerator.backward(loss)
self.optimizer.step()
return {
"loss": loss,
}
"""
early_stop_mark = torch.zeros(1, device=self.device)

if self.debug:
@@ -300,14 +364,14 @@ def train(self, train_dataloader: DataLoader, validation_dataloaders):
logger.info(f"`max_steps`: {max_steps}")
logger.info(f"`max_epochs`: {max_epochs}")

self.lr_scheduler = self.create_warmup_scheduler(scheduler_name=self.scheduler_name, max_steps=max_steps)
self.lr_scheduler = self.accelerator.prepare(self.lr_scheduler)
# Create the learning rate scheduler(s)
self.create_schedulers(max_steps=max_steps)

for epoch in range(self.state.epochs_trained + 1, max_epochs + 1):
logger.info(f"{'=' * 9} Epoch {epoch} out of {max_epochs} {'=' * 9}")
logger.info("Begin training...")

self.model.train()
self.set_models_to_train_mode()

training_epoch_output = []

@@ -323,41 +387,39 @@ def train(self, train_dataloader: DataLoader, validation_dataloaders):
)

for batch_idx, batch in enumerate(dataloader_bar):
# accumulate() will automatically skip synchronization if applicable
# loss is linearly scaled with the optimizer.grad
# accumulate() will automatically skip synchronization if applicable; the loss is linearly scaled with optimizer.grad
# accumulate() will automatically divide the loss in backward by the number of gradient accumulation steps
# However, it won't return this loss, so we need to manually divide the loss by the number of gradient accumulation steps.
with self.accelerator.accumulate(self.model):
# You are responsible for calling `.backward()`, `.step()`, and `.zero_grad()` in your implementation
loss_dict = self.training_step(batch, batch_idx)

# We likely don't need to divide the loss by the number of gradient accumulation steps here;
# for visualization, we simply plot the mean of the per-batch mean losses.
training_epoch_output.append(loss_dict)

# =======================================
# If `sync_gradients` is True, the gradients are currently being synced across all processes.
# It means that the gradient accumulation for the current step has finished.
# =======================================
if self.accelerator.sync_gradients:
# The gradients are added across all processes in this cumulative gradient accumulation step.
if self.max_grad_norm > 0:
norm = self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

if self.accelerator.is_local_main_process:
bar_desc = ""
for k, v in loss_dict.items():
bar_desc += f"{k}: {(v):.4f}, "
bar_desc += f"norm: {norm:.4f}, " f"lr: {self.lr_scheduler.get_last_lr()[-1]:.10f}"
dataloader_bar.set_description(bar_desc)

# In gradient accumulation, the step() of optimizer and scheduler is called
# only when gradient_accumulation_steps is reached.
self.optimizer.step()
# AMP: if the gradients are `nan` or `inf` skip the update step
bar_desc = self.create_bar_desc(loss_dict, norm)
dataloader_bar.set_description_str(bar_desc)

if not self.accelerator.optimizer_step_was_skipped:
# https://github.com/huggingface/accelerate/issues/1398
self.lr_scheduler.step(self.state.steps_trained)
self.optimizer.zero_grad()
# We could put scheduler.step() into the training_step() function. However, there are too many
# details to consider there. It's better to keep it here and document the caveats:
#
# 1. Every process steps the lr_scheduler N times, where N is the number of processes. We need to multiply the number of steps
#    by the number of processes before constructing the scheduler to make sure it behaves as we expect. https://github.com/huggingface/accelerate/issues/1398
#
# 2. For AMP, if the gradients are `nan` or `inf`, the optimizer update step is skipped. We should
#    therefore call `scheduler.step()` only after checking `self.accelerator.optimizer_step_was_skipped`;
#    otherwise, scheduler.step() would be called even when the optimizer step was skipped.
self.lr_scheduler_step()

self.state.steps_trained += 1
self.state.epochs_trained += 1
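# Editor's illustrative sketch (not part of this commit), expanding on caveat 1 above: a hypothetical
# variant of create_warmup_scheduler() that scales the step counts by `accelerator.num_processes`,
# so the prepared scheduler, which is stepped once per process, advances at the expected rate.
def create_warmup_scheduler_scaled(self, optimizer, scheduler_name, max_steps: int):
    num_processes = self.accelerator.num_processes
    num_warmup_steps = self.get_warmup_steps(self.warmup_steps, max_steps, self.warmup_ratio)
    if scheduler_name == "constant_schedule_with_warmup":
        return get_constant_schedule_with_warmup(
            optimizer=optimizer, num_warmup_steps=num_warmup_steps * num_processes
        )
    elif scheduler_name == "linear_schedule_with_warmup":
        return get_linear_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps * num_processes,
            num_training_steps=max_steps * num_processes,
        )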
@@ -393,7 +455,7 @@ def train(self, train_dataloader: DataLoader, validation_dataloaders):
def validate(self, dataloaders):
logger.info(f"Begin validation...")

self.model.eval()
self.set_models_to_eval_mode()

if not isinstance(dataloaders, list):
dataloaders = [dataloaders]
Expand Down Expand Up @@ -446,7 +508,7 @@ def test(self, dataloaders, ckpt_path="best"):

self._load_checkpoint(ckpt_path)

self.model.eval()
self.set_models_to_eval_mode()

test_output = []
for dataloader_idx, dataloader in enumerate(dataloaders):
@@ -491,7 +553,8 @@ def predict(self, dataloaders, ckpt_path="best"):
"""
if self.rank == 0:
logger.info(f"Begin predicting...")
self.model.eval()

self.set_models_to_eval_mode()

if not isinstance(dataloaders, list):
dataloaders = [dataloaders]
1 change: 1 addition & 0 deletions recipes/intel_ndns/spiking_fullsubnet/trainer_v2.py
@@ -39,6 +39,7 @@ def training_step(self, batch, batch_idx):
loss = loss_freq_mae + loss_mag_mae + loss_sdr_norm # + loss_g_fake

self.accelerator.backward(loss)
self.optimizer.step()

return {
"loss": loss,
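With this change, the Trainer no longer calls `optimizer.step()` itself; each recipe's `training_step()` is responsible for it, as the `trainer_v2.py` hunk above shows. A minimal sketch of the resulting pattern (the loss computation is a placeholder, `compute_loss` is a hypothetical helper, and only the ordering of `zero_grad()`, `backward()`, and `step()` follows the docstring guidance above):

def training_step(self, batch, batch_idx):
    self.optimizer.zero_grad()

    # Placeholder forward pass; the real recipe combines loss_freq_mae, loss_mag_mae, and loss_sdr_norm.
    loss = self.compute_loss(batch)

    self.accelerator.backward(loss)
    self.optimizer.step()

    return {"loss": loss}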
