Skip to content

Commit

Permalink
sync epoch number print in train.py (#905)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #905

# Context

In train.py, we log the number of steps underwent in a epoch here:
https://www.internalfb.com/code/fbsource/[9ef12f3ec86db52d063788e881d9f9d3f1209b7e]/fbcode/torchtnt/framework/train.py?lines=276-279&base=d2dcd7009ea1ca4cfb2a5ba2025caa2ec04a7e7a
The printed epoch is intentionally +1 since internally torchtnt starts the epoch at 0.

However, in a prior print which logs reason why epoch finished, the epoch is not +1.

https://www.internalfb.com/code/fbsource/[9ef12f3ec86db52d063788e881d9f9d3f1209b7e]/fbcode/torchtnt/framework/train.py?lines=258-264&base=d2dcd7009ea1ca4cfb2a5ba2025caa2ec04a7e7a

# This Diff
Add +1 so both prints are synced

Reviewed By: anshulverma, vbourgin

Differential Revision: D63557368

fbshipit-source-id: 48e9a4a1ebe63abd07354402aefa6c750245bfda
  • Loading branch information
JKSenthil authored and facebook-github-bot committed Oct 2, 2024
1 parent d86828b commit a577dd4
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions tests/framework/test_loop_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,17 +199,17 @@ def test_log_reason_epoch_completed(self) -> None:
p, max_steps_per_epoch=5, max_steps=None, stop_iteration_reached=False
)
self.assertEqual(
reason, "Train epoch 2 ended as max steps per epoch reached: 5"
reason, "Train epoch 3 ended as max steps per epoch reached: 5"
)

reason = _reason_epoch_completed(
p, max_steps_per_epoch=6, max_steps=100, stop_iteration_reached=False
)
self.assertEqual(reason, "Train epoch 2 ended as max steps reached: 100")
self.assertEqual(reason, "Train epoch 3 ended as max steps reached: 100")

reason = _reason_epoch_completed(
p, max_steps_per_epoch=5, max_steps=None, stop_iteration_reached=True
)
self.assertEqual(
reason, "Train epoch 2 ended as it reached end of train dataloader"
reason, "Train epoch 3 ended as it reached end of train dataloader"
)
2 changes: 1 addition & 1 deletion torchtnt/framework/_loop_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _reason_epoch_completed(
max_steps: Optional[int],
stop_iteration_reached: bool,
) -> str:
current_epoch = progress.num_epochs_completed
current_epoch = progress.num_epochs_completed + 1
if stop_iteration_reached:
return (
f"Train epoch {current_epoch} ended as it reached end of train dataloader"
Expand Down

0 comments on commit a577dd4

Please sign in to comment.