Skip to content

Commit

Permalink
Early exit on predict entrypoint if epoch has completed (#972)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #972

Reviewed By: galrotem, anshulverma

Differential Revision: D69865409
  • Loading branch information
diego-urgell authored and facebook-github-bot committed Feb 25, 2025
1 parent 2e6cd59 commit 2227e82
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tests/framework/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from torchtnt.framework.predict import predict
from torchtnt.framework.state import State
from torchtnt.framework.unit import PredictUnit, TPredictUnit
from torchtnt.utils.progress import Progress
from torchtnt.utils.timer import Timer


Expand Down Expand Up @@ -242,6 +243,16 @@ def test_predict_ckpt_autograd_mode(
predict(unit, dataloader, callbacks=cast(List[Callback], callbacks))
mock_autograd_mode.assert_called_once()

def test_predict_epoch_check(self) -> None:
    """Verify predict() exits early when the unit's epoch is already complete."""
    # Wrap a concrete unit so real attributes resolve while call counts
    # are still recorded by the mock.
    wrapped_unit = MagicMock(wraps=DummyPredictUnit(2))
    # Mark one epoch as finished; the entrypoint should treat this run as a no-op.
    wrapped_unit.predict_progress = Progress(
        num_epochs_completed=1, num_steps_completed=5
    )

    loader = generate_random_dataloader(10, 2, 2)

    predict(wrapped_unit, loader, max_steps_per_epoch=100)

    # An early exit means no lifecycle hook should ever have fired.
    wrapped_unit.on_predict_start.assert_not_called()


Batch = Tuple[torch.Tensor, torch.Tensor]

Expand Down
6 changes: 6 additions & 0 deletions torchtnt/framework/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ def _predict_impl(
# input validation
predict_state = none_throws(state.predict_state)

if predict_unit.predict_progress.num_epochs_completed >= 1:
logger.warning(
"Predict epoch has already been completed. Skipping to avoid duplicate outputs."
)
return

state._active_phase = ActivePhase.PREDICT
logger.info(
f"Started predict with max_steps_per_epoch={predict_state.max_steps_per_epoch}"
Expand Down

0 comments on commit 2227e82

Please sign in to comment.