From a577dd4e2a423214853de431a6f5318d351c75a1 Mon Sep 17 00:00:00 2001 From: Jason Senthil Date: Tue, 1 Oct 2024 21:03:52 -0700 Subject: [PATCH] sync epoch number print in train.py (#905) Summary: Pull Request resolved: https://github.com/pytorch/tnt/pull/905 # Context In train.py, we log the number of steps completed in an epoch here: https://www.internalfb.com/code/fbsource/[9ef12f3ec86db52d063788e881d9f9d3f1209b7e]/fbcode/torchtnt/framework/train.py?lines=276-279&base=d2dcd7009ea1ca4cfb2a5ba2025caa2ec04a7e7a The printed epoch is intentionally +1 since internally torchtnt starts the epoch at 0. However, in an earlier print, which logs the reason why the epoch finished, the epoch is not +1. https://www.internalfb.com/code/fbsource/[9ef12f3ec86db52d063788e881d9f9d3f1209b7e]/fbcode/torchtnt/framework/train.py?lines=258-264&base=d2dcd7009ea1ca4cfb2a5ba2025caa2ec04a7e7a # This Diff Add +1 so both prints are synced Reviewed By: anshulverma, vbourgin Differential Revision: D63557368 fbshipit-source-id: 48e9a4a1ebe63abd07354402aefa6c750245bfda --- tests/framework/test_loop_utils.py | 6 +++--- torchtnt/framework/_loop_utils.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/framework/test_loop_utils.py b/tests/framework/test_loop_utils.py index 9be63bb113..c7ad1febac 100644 --- a/tests/framework/test_loop_utils.py +++ b/tests/framework/test_loop_utils.py @@ -199,17 +199,17 @@ def test_log_reason_epoch_completed(self) -> None: p, max_steps_per_epoch=5, max_steps=None, stop_iteration_reached=False ) self.assertEqual( - reason, "Train epoch 2 ended as max steps per epoch reached: 5" + reason, "Train epoch 3 ended as max steps per epoch reached: 5" ) reason = _reason_epoch_completed( p, max_steps_per_epoch=6, max_steps=100, stop_iteration_reached=False ) - self.assertEqual(reason, "Train epoch 2 ended as max steps reached: 100") + self.assertEqual(reason, "Train epoch 3 ended as max steps reached: 100") reason = _reason_epoch_completed( p, max_steps_per_epoch=5, 
max_steps=None, stop_iteration_reached=True ) self.assertEqual( - reason, "Train epoch 2 ended as it reached end of train dataloader" + reason, "Train epoch 3 ended as it reached end of train dataloader" ) diff --git a/torchtnt/framework/_loop_utils.py b/torchtnt/framework/_loop_utils.py index 74a139aec3..2316da9185 100644 --- a/torchtnt/framework/_loop_utils.py +++ b/torchtnt/framework/_loop_utils.py @@ -49,7 +49,7 @@ def _reason_epoch_completed( max_steps: Optional[int], stop_iteration_reached: bool, ) -> str: - current_epoch = progress.num_epochs_completed + current_epoch = progress.num_epochs_completed + 1 if stop_iteration_reached: return ( f"Train epoch {current_epoch} ended as it reached end of train dataloader"