Skip to content

Commit

Permalink
Save checkpoint in on_predict_end hook (#973)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #973

Reviewed By: galrotem

Differential Revision: D69865408
  • Loading branch information
diego-urgell authored and facebook-github-bot committed Feb 25, 2025
1 parent cf65385 commit 608114c
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 3 deletions.
5 changes: 5 additions & 0 deletions tests/framework/callbacks/test_base_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,11 @@ def _checkpoint_impl_side_effect(
expected_ckpts = [
f"{temp_dir}/epoch_0_predict_step_{i}" for i in range(1, 11)
]

expected_ckpts.append(
f"{temp_dir}/epoch_1_predict_step_10"
) # We always expect checkpoint on predict end

self.assertEqual(ckpt_container, expected_ckpts)

@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
Expand Down
10 changes: 8 additions & 2 deletions tests/framework/callbacks/test_dcp_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,19 +536,25 @@ def test_save_restore_predict(self) -> None:
expected_ckpts = [
"epoch_0_predict_step_2",
"epoch_0_predict_step_4",
"epoch_1_predict_step_5",
]

self.assertCountEqual(generated_ckpts, expected_ckpts)

ckpt_path = none_throws(get_latest_checkpoint_path(temp_dir))
self.assertEqual(ckpt_path, os.path.join(temp_dir, expected_ckpts[-1]))
latest_ckpt_path = none_throws(get_latest_checkpoint_path(temp_dir))
self.assertEqual(
latest_ckpt_path, os.path.join(temp_dir, expected_ckpts[-1])
)

expected_keys = [
"predict_progress",
"predict_dataloader",
"output_mean",
]

# Check keys on a checkpoint other than the latest since it won't have dataloader state
ckpt_path = f"{temp_dir}/{expected_ckpts[0]}"

storage_reader = FsspecReader(ckpt_path)
metadata = storage_reader.read_metadata()
self.assertCountEqual(
Expand Down
3 changes: 3 additions & 0 deletions torchtnt/framework/callbacks/base_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,9 @@ def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:

self._generate_checkpoint_and_upkeep(state, unit, hook="on_predict_step_end")

def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
self._generate_checkpoint_and_upkeep(state, unit, hook="on_predict_end")

def _disable_ckpt_optimality_tracking(self) -> None:
"""
Disables checkpoint optimality tracking. This means that best_checkpoint and keep_last_n_checkpoints
Expand Down
3 changes: 2 additions & 1 deletion torchtnt/framework/callbacks/dcp_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def _checkpoint_impl(
"on_eval_epoch_end",
"on_eval_step_end",
"on_predict_step_end",
"on_predict_end",
]:
raise RuntimeError(f"Unexpected hook encountered '{hook}'")

Expand All @@ -178,7 +179,7 @@ def _checkpoint_impl(
intra_epoch = "step_end" in hook or (
"on_eval_epoch_end" == hook and state.entry_point == EntryPoint.FIT
)
curr_snapshot_wait = hook == "on_train_end"
curr_snapshot_wait = hook in ("on_train_end", "on_predict_end")

if planner is None:
planner = DefaultSavePlanner()
Expand Down

0 comments on commit 608114c

Please sign in to comment.