diff --git a/audiozen/common_trainer.py b/audiozen/common_trainer.py index 5f07d5c..38016f9 100644 --- a/audiozen/common_trainer.py +++ b/audiozen/common_trainer.py @@ -265,16 +265,80 @@ def get_warmup_steps(warmup_steps, max_steps, warmup_ratio): else: return math.ceil(max_steps * warmup_ratio) - def create_warmup_scheduler(self, scheduler_name, max_steps: int): + def create_warmup_scheduler(self, optimizer, scheduler_name, max_steps: int): num_warmup_steps = self.get_warmup_steps(self.warmup_steps, max_steps, self.warmup_ratio) if scheduler_name == "constant_schedule_with_warmup": - return get_constant_schedule_with_warmup(optimizer=self.optimizer, num_warmup_steps=num_warmup_steps) + return get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=num_warmup_steps) elif scheduler_name == "linear_schedule_with_warmup": return get_linear_schedule_with_warmup( - optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=max_steps + optimizer=optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=max_steps ) + def create_schedulers(self, max_steps: int): + """Create schedulers. + + You can override this method to create your own schedulers. For example, in GAN training, you may want to + create two schedulers for the generator and the discriminator. + + Args: + max_steps: the maximum number of steps to train. + """ + self.lr_scheduler = self.create_warmup_scheduler( + optimizer=self.optimizer, scheduler_name=self.scheduler_name, max_steps=max_steps + ) + self.lr_scheduler = self.accelerator.prepare(self.lr_scheduler) + + def set_models_to_train_mode(self): + """Set models to train mode. + + You can override this method to set your own models to train mode. For example, in GAN training, you may want to + set the generator and the discriminator to train mode. + """ + self.model.train() + + def set_models_to_eval_mode(self): + self.model.eval() + + def lr_scheduler_step(self): + """Step the lr scheduler. 
+ + You can override this method to step your own lr scheduler. For example, in GAN training, you may want to + step the lr scheduler of the generator and the discriminator. + """ + self.lr_scheduler.step(self.state.steps_trained) + + def create_bar_desc(self, loss_dict, norm): + bar_desc = "" + for k, v in loss_dict.items(): + bar_desc += f"{k}: {(v):.4f}, " + bar_desc += f"norm: {norm:.4f}, " f"lr: {self.lr_scheduler.get_last_lr()[-1]:.10f}" + return bar_desc + def train(self, train_dataloader: DataLoader, validation_dataloaders): + """Train the model. + + Args: + train_dataloader: the dataloader to train. + validation_dataloaders: the dataloader(s) to validate. + + Notes: + You are responsible for calling `.backward()`, `.step()`, and `.zero_grad()` in your implementation + of `training_step()`. Accelerate will automatically handle the gradient accumulation for you. + It means that in gradient accumulation, the step() of optimizer and scheduler is called only when gradient_accumulation_steps is reached. + + The training step is implemented as follows: + + .. 
code-block:: python + + self.optimizer.zero_grad() + loss = training_step(batch, batch_idx) + self.accelerator.backward(loss) + self.optimizer.step() + + return { + "loss": loss, + } + """ early_stop_mark = torch.zeros(1, device=self.device) if self.debug: @@ -300,14 +364,14 @@ def train(self, train_dataloader: DataLoader, validation_dataloaders): logger.info(f"`max_steps`: {max_steps}") logger.info(f"`max_epochs`: {max_epochs}") - self.lr_scheduler = self.create_warmup_scheduler(scheduler_name=self.scheduler_name, max_steps=max_steps) - self.lr_scheduler = self.accelerator.prepare(self.lr_scheduler) + # Create the learning rate scheduler(s) + self.create_schedulers(max_steps=max_steps) for epoch in range(self.state.epochs_trained + 1, max_epochs + 1): logger.info(f"{'=' * 9} Epoch {epoch} out of {max_epochs} {'=' * 9}") logger.info("Begin training...") - self.model.train() + self.set_models_to_train_mode() training_epoch_output = [] @@ -323,41 +387,39 @@ def train(self, train_dataloader: DataLoader, validation_dataloaders): ) for batch_idx, batch in enumerate(dataloader_bar): - # accumulate() will automatically skip synchronization if applicable - # loss is linearly scaled with the optimizer.grad + # accumulate() will automatically skip synchronization if applicable; the loss is linearly scaled with the optimizer.grad # accumulate() will automatically divide the loss in backward by the number of gradient accumulation steps # However, it won't return this loss, so we need to manually divide the loss by the number of gradient accumulation steps. 
with self.accelerator.accumulate(self.model): + # You are responsible for calling `.backward()`, `.step()`, and `.zero_grad()` in your implementation loss_dict = self.training_step(batch, batch_idx) # I guess we don't need to divide the loss by the number of gradient accumulation steps here # for visualization, we just plot the mean of mean of the loss of each batch training_epoch_output.append(loss_dict) - # ======================================= # If `sync_gradients` is True, the gradients are currently being synced across all processes. # It means that the current step we have finished the cumulative gradient accumulation. - # ======================================= if self.accelerator.sync_gradients: # The gradients are added across all processes in this cumulative gradient accumulation step. if self.max_grad_norm > 0: norm = self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) if self.accelerator.is_local_main_process: - bar_desc = "" - for k, v in loss_dict.items(): - bar_desc += f"{k}: {(v):.4f}, " - bar_desc += f"norm: {norm:.4f}, " f"lr: {self.lr_scheduler.get_last_lr()[-1]:.10f}" - dataloader_bar.set_description(bar_desc) - - # In gradient accumulation, the step() of optimizer and scheduler is called - # only when gradient_accumulation_steps is reached. - self.optimizer.step() - # AMP: if the gradients are `nan` or `inf` skip the update step + bar_desc = self.create_bar_desc(loss_dict, norm) + dataloader_bar.set_description_str(bar_desc) + if not self.accelerator.optimizer_step_was_skipped: - # https://github.com/huggingface/accelerate/issues/1398 - self.lr_scheduler.step(self.state.steps_trained) - self.optimizer.zero_grad() + # We can put the scheduler.step() into the training_step() function. However, there are **too many + # details to consider**. It's better to put it here and add some comments. + # + # 1. In every process, the lr_scheduler steps N times, where N is the number of processes. 
We need to multiply the number of steps by the number of processes before constructing the + # scheduler to make sure it behaves as we expect it to do. https://github.com/huggingface/accelerate/issues/1398 + # + # 2. For AMP, if the gradients are `nan` or `inf` skip the update step, we should call the + # `scheduler.step()` after checking `self.accelerator.optimizer_step_was_skipped`. + # Otherwise, the scheduler.step() will be called even if the optimizer step is skipped. + self.lr_scheduler_step() self.state.steps_trained += 1 self.state.epochs_trained += 1 @@ -393,7 +455,7 @@ def train(self, train_dataloader: DataLoader, validation_dataloaders): def validate(self, dataloaders): logger.info(f"Begin validation...") - self.model.eval() + self.set_models_to_eval_mode() if not isinstance(dataloaders, list): dataloaders = [dataloaders] @@ -446,7 +508,7 @@ def test(self, dataloaders, ckpt_path="best"): self._load_checkpoint(ckpt_path) - self.model.eval() + self.set_models_to_eval_mode() test_output = [] for dataloader_idx, dataloader in enumerate(dataloaders): @@ -491,7 +553,8 @@ def predict(self, dataloaders, ckpt_path="best"): """ if self.rank == 0: logger.info(f"Begin predicting...") - self.model.eval() + + self.set_models_to_eval_mode() if not isinstance(dataloaders, list): dataloaders = [dataloaders] diff --git a/recipes/intel_ndns/spiking_fullsubnet/trainer_v2.py b/recipes/intel_ndns/spiking_fullsubnet/trainer_v2.py index 3b22361..668e9f6 100644 --- a/recipes/intel_ndns/spiking_fullsubnet/trainer_v2.py +++ b/recipes/intel_ndns/spiking_fullsubnet/trainer_v2.py @@ -39,6 +39,7 @@ def training_step(self, batch, batch_idx): loss = loss_freq_mae + loss_mag_mae + loss_sdr_norm # + loss_g_fake self.accelerator.backward(loss) + self.optimizer.step() return { "loss": loss,