diff --git a/audiozen/trainer/base_trainer_gan_accelerate.py b/audiozen/trainer/base_trainer_gan_accelerate.py index f4f0414..5eb9741 100644 --- a/audiozen/trainer/base_trainer_gan_accelerate.py +++ b/audiozen/trainer/base_trainer_gan_accelerate.py @@ -67,9 +67,7 @@ def __init__( self.plot_norm = self.trainer_config.get("plot_norm", True) self.validation_interval = self.trainer_config.get("validation_interval", 1) self.max_num_checkpoints = self.trainer_config.get("max_num_checkpoints", 10) - assert ( - self.validation_interval >= 1 - ), "'validation_interval' should be large than one." + assert self.validation_interval >= 1, "'validation_interval' should be large than one." # Count Variables self.total_norm = -1 @@ -147,9 +145,7 @@ def _initialize_exp_dirs_and_paths(self, config): # Each run will have a unique source code, config, and log file. time_now = self._get_time_now() - self.source_code_dir = ( - Path(__file__).expanduser().absolute().parent.parent.parent - ) + self.source_code_dir = Path(__file__).expanduser().absolute().parent.parent.parent self.source_code_backup_dir = self.exp_dir / f"source_code__{time_now}" self.config_path = self.exp_dir / f"config__{time_now}.toml" @@ -241,9 +237,7 @@ def _run_early_stop_check(self, score: float, epoch: int): else: logger.info(f"Score did not improve from {self.best_score.value:.4f}.") self.wait_count.value += 1 - logger.info( - f"Early stopping counter: {self.wait_count.value} out of {self.patience}" - ) + logger.info(f"Early stopping counter: {self.wait_count.value} out of {self.patience}") if self.wait_count.value >= self.patience: logger.info(f"Early stopping triggered, stopping training...")