diff --git a/tests/framework/test_loop_utils.py b/tests/framework/test_loop_utils.py index 9be63bb113..c7ad1febac 100644 --- a/tests/framework/test_loop_utils.py +++ b/tests/framework/test_loop_utils.py @@ -199,17 +199,17 @@ def test_log_reason_epoch_completed(self) -> None: p, max_steps_per_epoch=5, max_steps=None, stop_iteration_reached=False ) self.assertEqual( - reason, "Train epoch 2 ended as max steps per epoch reached: 5" + reason, "Train epoch 3 ended as max steps per epoch reached: 5" ) reason = _reason_epoch_completed( p, max_steps_per_epoch=6, max_steps=100, stop_iteration_reached=False ) - self.assertEqual(reason, "Train epoch 2 ended as max steps reached: 100") + self.assertEqual(reason, "Train epoch 3 ended as max steps reached: 100") reason = _reason_epoch_completed( p, max_steps_per_epoch=5, max_steps=None, stop_iteration_reached=True ) self.assertEqual( - reason, "Train epoch 2 ended as it reached end of train dataloader" + reason, "Train epoch 3 ended as it reached end of train dataloader" ) diff --git a/torchtnt/framework/_loop_utils.py b/torchtnt/framework/_loop_utils.py index 74a139aec3..2316da9185 100644 --- a/torchtnt/framework/_loop_utils.py +++ b/torchtnt/framework/_loop_utils.py @@ -49,7 +49,7 @@ def _reason_epoch_completed( max_steps: Optional[int], stop_iteration_reached: bool, ) -> str: - current_epoch = progress.num_epochs_completed + current_epoch = progress.num_epochs_completed + 1 if stop_iteration_reached: return ( f"Train epoch {current_epoch} ended as it reached end of train dataloader"