diff --git a/tests/framework/callbacks/test_torchsnapshot_saver.py b/tests/framework/callbacks/test_torchsnapshot_saver.py
index 6a9af77791..e7ba1bd5d4 100644
--- a/tests/framework/callbacks/test_torchsnapshot_saver.py
+++ b/tests/framework/callbacks/test_torchsnapshot_saver.py
@@ -27,6 +27,7 @@
 )
 from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
 from torchtnt.framework.callbacks.torchsnapshot_saver import (
+    _exclude_progress_from_replicated,
     _override_knobs,
     TorchSnapshotSaver,
 )
@@ -363,6 +364,68 @@ def test_sync_checkpoint(self, _: MagicMock) -> None:
         snapshot_cb.on_train_step_end(state, my_unit)
         snapshot_cb._sync_snapshot.assert_called_once()
 
+    def test_exclude_progress_from_replicated(self) -> None:
+        """
+        Tests that replicated is populated correctly with progress excluded
+        """
+
+        module = nn.Linear(2, 3)
+        my_unit = DummyAutoUnit(module=module)
+        keys = my_unit.app_state().keys()
+
+        progress_keys = {"train_progress", "eval_progress", "predict_progress"}
+
+        replicated = _exclude_progress_from_replicated(my_unit.app_state())
+        for key in keys:
+            if key not in progress_keys:
+                self.assertIn(f"{key}/**", replicated)
+
+        # since we exclude 3 keys (train, eval, predict)
+        self.assertEqual(len(keys) - 3, len(replicated))
+
+        # check that progress is not included
+        for progress_key in progress_keys:
+            self.assertNotIn(f"{progress_key}/", replicated)
+
+    @patch("torchtnt.framework.callbacks.torchsnapshot_saver.Snapshot.take")
+    def test_exclude_progress_from_replicated_e2e(self, mock_take: MagicMock) -> None:
+        """
+        Tests that replicated is populated correctly during snapshotting
+        """
+
+        module = nn.Linear(2, 3)
+        my_unit = DummyAutoUnit(module=module)
+        state = get_dummy_train_state()
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            for replicated_value in (None, ["optimizer/**"], ["**"]):
+                tss = TorchSnapshotSaver(
+                    dirpath=temp_dir,
+                    save_every_n_train_steps=1,
+                    async_checkpoint=False,
+                    replicated=replicated_value,
+                )
+
+                progress_keys = {"train_progress", "eval_progress", "predict_progress"}
+
+                tss.on_train_step_end(state, my_unit)
+                replicated = mock_take.call_args.kwargs["replicated"]
+
+                if replicated_value is None:
+                    self.assertEqual(replicated, [])
+                elif replicated_value == ["optimizer/**"]:
+                    self.assertEqual(replicated, ["optimizer/**"])
+                elif replicated_value == ["**"]:
+                    expected_replicated = [
+                        f"{key}/**"
+                        for key in my_unit.app_state().keys()
+                        if key not in progress_keys
+                    ]
+                    # this is added outside of the unit's app_state so it should be included
+                    expected_replicated.append("rng_state/**")
+
+                    self.assertEqual(set(replicated), set(expected_replicated))
+
 
 class DummyStatefulDataLoader:
     def __init__(self, dataloader: DataLoader) -> None:
diff --git a/torchtnt/framework/callbacks/torchsnapshot_saver.py b/torchtnt/framework/callbacks/torchsnapshot_saver.py
index 4675603c84..1c854db881 100644
--- a/torchtnt/framework/callbacks/torchsnapshot_saver.py
+++ b/torchtnt/framework/callbacks/torchsnapshot_saver.py
@@ -216,12 +216,16 @@ def _async_snapshot(
                 )
                 return False
 
+        replicated = self._replicated
+        if self._replicated == {"**"}:
+            replicated = _exclude_progress_from_replicated(app_state)
+
         with _override_knobs(self._knob_options):
             self._prev_snapshot = Snapshot.async_take(
                 str(snapshot_path),
                 app_state=app_state,
                 pg=self._process_group,
-                replicated=list(self._replicated),
+                replicated=list(replicated),
                 storage_options=self._storage_options,
             )
         rank_zero_info(f"Saving snapshot to path: {snapshot_path}", logger=logger)
@@ -232,6 +236,10 @@ def _sync_snapshot(
         snapshot_path: str,
         app_state: Dict[str, _TStateful],
     ) -> bool:
+        replicated = self._replicated
+        if self._replicated == {"**"}:
+            replicated = _exclude_progress_from_replicated(app_state)
+
         with _override_knobs(self._knob_options):
             rank_zero_info(
                 f"Started saving snapshot to path: {snapshot_path}", logger=logger
@@ -240,7 +248,7 @@ def _sync_snapshot(
                 str(snapshot_path),
                 app_state=app_state,
                 pg=self._process_group,
-                replicated=list(self._replicated),
+                replicated=list(replicated),
                 storage_options=self._storage_options,
             )
         rank_zero_info(
@@ -316,6 +324,22 @@ def restore(
         rank_zero_info(f"Restored snapshot from path: {path}", logger=logger)
 
 
+def _exclude_progress_from_replicated(app_state: Dict[str, _TStateful]) -> Set[str]:
+    """
+    Excludes progress state from being replicated. Called if replicated=["**"] is passed in.
+    Works by populating replicated with all possible keys from app_state, except for
+    the keys that match the "{train,eval,predict}_progress/**" pattern.
+    """
+
+    filtered_replicated = set()
+    progress_keys = {"train_progress", "eval_progress", "predict_progress"}
+    for key in app_state.keys():
+        if key in progress_keys:
+            continue
+        filtered_replicated.add(f"{key}/**")
+    return filtered_replicated
+
+
 def _validate_snapshot_available() -> None:
     if not _TORCHSNAPSHOT_AVAILABLE:
         raise RuntimeError(
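
Usage note (illustrative sketch, not part of the patch; the dirpath and save frequency values below are placeholders): with this change, passing replicated=["**"] to TorchSnapshotSaver is expanded at save time into one "<key>/**" glob per top-level app_state key, with train_progress, eval_progress, and predict_progress dropped, so progress counters are saved per rank instead of being deduplicated as replicated state.

    from torchtnt.framework.callbacks.torchsnapshot_saver import TorchSnapshotSaver

    saver = TorchSnapshotSaver(
        dirpath="/tmp/checkpoints",      # placeholder path
        save_every_n_train_steps=100,    # placeholder frequency
        replicated=["**"],               # expanded per key at save time, minus *_progress
    )
    # For an AutoUnit with a module and optimizer, Snapshot.take / Snapshot.async_take
    # then receives something like replicated=["module/**", "optimizer/**", ..., "rng_state/**"],
    # with no progress keys included.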