diff --git a/tests/framework/test_predict.py b/tests/framework/test_predict.py index b33ba30128..dc12f3c71f 100644 --- a/tests/framework/test_predict.py +++ b/tests/framework/test_predict.py @@ -21,6 +21,7 @@ from torchtnt.framework.predict import predict from torchtnt.framework.state import State from torchtnt.framework.unit import PredictUnit, TPredictUnit +from torchtnt.utils.progress import Progress from torchtnt.utils.timer import Timer @@ -242,6 +243,16 @@ def test_predict_ckpt_autograd_mode( predict(unit, dataloader, callbacks=cast(List[Callback], callbacks)) mock_autograd_mode.assert_called_once() + def test_predict_epoch_check(self) -> None: + unit = MagicMock(wraps=DummyPredictUnit(2)) + unit.predict_progress = Progress(num_epochs_completed=1, num_steps_completed=5) + + dataloader = generate_random_dataloader(10, 2, 2) + + predict(unit, dataloader, max_steps_per_epoch=100) + + unit.on_predict_start.assert_not_called() + Batch = Tuple[torch.Tensor, torch.Tensor] diff --git a/torchtnt/framework/predict.py b/torchtnt/framework/predict.py index 33207309bc..5332e71877 100644 --- a/torchtnt/framework/predict.py +++ b/torchtnt/framework/predict.py @@ -122,6 +122,12 @@ def _predict_impl( # input validation predict_state = none_throws(state.predict_state) + if predict_unit.predict_progress.num_epochs_completed >= 1: + logger.warning( + "Predict epoch has already been completed. Skipping to avoid duplicate outputs." + ) + return + state._active_phase = ActivePhase.PREDICT logger.info( f"Started predict with max_steps_per_epoch={predict_state.max_steps_per_epoch}"