diff --git a/tests/framework/callbacks/test_base_checkpointer.py b/tests/framework/callbacks/test_base_checkpointer.py
index 5f584ba495..a28fae5a65 100644
--- a/tests/framework/callbacks/test_base_checkpointer.py
+++ b/tests/framework/callbacks/test_base_checkpointer.py
@@ -339,6 +339,11 @@ def _checkpoint_impl_side_effect(
         expected_ckpts = [
             f"{temp_dir}/epoch_0_predict_step_{i}" for i in range(1, 11)
         ]
+
+        expected_ckpts.append(
+            f"{temp_dir}/epoch_1_predict_step_10"
+        )  # We always expect checkpoint on predict end
+
         self.assertEqual(ckpt_container, expected_ckpts)

     @unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
diff --git a/tests/framework/callbacks/test_dcp_saver.py b/tests/framework/callbacks/test_dcp_saver.py
index bb520614ad..fc12243229 100644
--- a/tests/framework/callbacks/test_dcp_saver.py
+++ b/tests/framework/callbacks/test_dcp_saver.py
@@ -536,12 +536,15 @@ def test_save_restore_predict(self) -> None:
             expected_ckpts = [
                 "epoch_0_predict_step_2",
                 "epoch_0_predict_step_4",
+                "epoch_1_predict_step_5",
             ]

             self.assertCountEqual(generated_ckpts, expected_ckpts)

-            ckpt_path = none_throws(get_latest_checkpoint_path(temp_dir))
-            self.assertEqual(ckpt_path, os.path.join(temp_dir, expected_ckpts[-1]))
+            latest_ckpt_path = none_throws(get_latest_checkpoint_path(temp_dir))
+            self.assertEqual(
+                latest_ckpt_path, os.path.join(temp_dir, expected_ckpts[-1])
+            )

             expected_keys = [
                 "predict_progress",
@@ -549,6 +552,9 @@
                 "output_mean",
             ]

+            # Check keys on a checkpoint other than the latest since it won't have dataloader state
+            ckpt_path = f"{temp_dir}/{expected_ckpts[0]}"
+
             storage_reader = FsspecReader(ckpt_path)
             metadata = storage_reader.read_metadata()
             self.assertCountEqual(
diff --git a/torchtnt/framework/callbacks/base_checkpointer.py b/torchtnt/framework/callbacks/base_checkpointer.py
index e1a07b8626..6b4e825823 100644
--- a/torchtnt/framework/callbacks/base_checkpointer.py
+++ b/torchtnt/framework/callbacks/base_checkpointer.py
@@ -378,6 +378,9 @@
     def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
         self._generate_checkpoint_and_upkeep(state, unit, hook="on_predict_step_end")

+    def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
+        self._generate_checkpoint_and_upkeep(state, unit, hook="on_predict_end")
+
     def _disable_ckpt_optimality_tracking(self) -> None:
         """
         Disables checkpoint optimality tracking. This means that best_checkpoint and keep_last_n_checkpoints
diff --git a/torchtnt/framework/callbacks/dcp_saver.py b/torchtnt/framework/callbacks/dcp_saver.py
index 17731c251b..4f35894807 100644
--- a/torchtnt/framework/callbacks/dcp_saver.py
+++ b/torchtnt/framework/callbacks/dcp_saver.py
@@ -169,6 +169,7 @@ def _checkpoint_impl(
             "on_eval_epoch_end",
             "on_eval_step_end",
             "on_predict_step_end",
+            "on_predict_end",
         ]:
             raise RuntimeError(f"Unexpected hook encountered '{hook}'")

@@ -178,7 +179,7 @@ def _checkpoint_impl(
         intra_epoch = "step_end" in hook or (
             "on_eval_epoch_end" == hook and state.entry_point == EntryPoint.FIT
         )
-        curr_snapshot_wait = hook == "on_train_end"
+        curr_snapshot_wait = hook in ("on_train_end", "on_predict_end")

         if planner is None:
             planner = DefaultSavePlanner()