Skip to content

Commit

Permalink
chore: refactor code for validation interval and logging in BaseTrain…
Browse files Browse the repository at this point in the history
…er class
  • Loading branch information
haoxiangsnr committed Dec 21, 2023
1 parent 0fb6d44 commit 7241dff
Showing 1 changed file with 3 additions and 9 deletions.
12 changes: 3 additions & 9 deletions audiozen/trainer/base_trainer_gan_accelerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ def __init__(
self.plot_norm = self.trainer_config.get("plot_norm", True)
self.validation_interval = self.trainer_config.get("validation_interval", 1)
self.max_num_checkpoints = self.trainer_config.get("max_num_checkpoints", 10)
assert (
self.validation_interval >= 1
), "'validation_interval' should be large than one."
assert self.validation_interval >= 1, "'validation_interval' should be large than one."

# Count Variables
self.total_norm = -1
Expand Down Expand Up @@ -147,9 +145,7 @@ def _initialize_exp_dirs_and_paths(self, config):

# Each run will have a unique source code, config, and log file.
time_now = self._get_time_now()
self.source_code_dir = (
Path(__file__).expanduser().absolute().parent.parent.parent
)
self.source_code_dir = Path(__file__).expanduser().absolute().parent.parent.parent
self.source_code_backup_dir = self.exp_dir / f"source_code__{time_now}"
self.config_path = self.exp_dir / f"config__{time_now}.toml"

Expand Down Expand Up @@ -241,9 +237,7 @@ def _run_early_stop_check(self, score: float, epoch: int):
else:
logger.info(f"Score did not improve from {self.best_score.value:.4f}.")
self.wait_count.value += 1
logger.info(
f"Early stopping counter: {self.wait_count.value} out of {self.patience}"
)
logger.info(f"Early stopping counter: {self.wait_count.value} out of {self.patience}")

if self.wait_count.value >= self.patience:
logger.info(f"Early stopping triggered, stopping training...")
Expand Down

0 comments on commit 7241dff

Please sign in to comment.