diff --git a/tests/framework/callbacks/test_dcp_saver.py b/tests/framework/callbacks/test_dcp_saver.py
index e1f0df4b70..d31c6f5d3c 100644
--- a/tests/framework/callbacks/test_dcp_saver.py
+++ b/tests/framework/callbacks/test_dcp_saver.py
@@ -20,7 +20,6 @@
 from torch import nn
 from torch.utils.data import DataLoader
 from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq
-
 from torchtnt.framework._test_utils import (
     DummyAutoUnit,
     DummyTrainUnit,
@@ -31,7 +30,7 @@
 from torchtnt.framework.train import train
 from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
 from torchtnt.utils.env import seed
-from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
+from torchtnt.utils.test_utils import skip_if_not_distributed
 
 
 class DistributedCheckpointSaverTest(unittest.TestCase):
@@ -222,55 +221,6 @@ def test_save_restore_no_lr_scheduler_restore(
         app_state = mock_dist_cp.load_state_dict.call_args.args[0]["app_state"]
         self.assertIn("lr_scheduler", app_state)
 
-    @skip_if_not_distributed
-    @skip_if_not_gpu
-    def test_save_restore_fsdp(self) -> None:
-        spawn_multi_process(
-            2,
-            "nccl",
-            self._save_restore_fsdp,
-        )
-
-    @staticmethod
-    def _save_restore_fsdp() -> None:
-        input_dim = 2
-        dataset_len = 10
-        batch_size = 2
-        max_epochs = 2
-        save_every_n_epochs = 1
-
-        my_unit = DummyAutoUnit(module=torch.nn.Linear(input_dim, 2), strategy="fsdp")
-        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
-        if get_global_rank() == 0:
-            temp_dir = tempfile.mkdtemp()
-        else:
-            temp_dir = ""
-
-        dcp_cb = DistributedCheckpointSaver(
-            temp_dir,
-            save_every_n_epochs=save_every_n_epochs,
-        )
-        temp_dir = dcp_cb.dirpath
-        train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[dcp_cb])
-
-        tc = unittest.TestCase()
-        try:
-            my_new_unit = DummyAutoUnit(
-                module=torch.nn.Linear(input_dim, 2), strategy="fsdp"
-            )
-            tc.assertNotEqual(
-                my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
-            )
-            # get latest checkpoint
-            ckpt_path = os.path.join(temp_dir, f"epoch_{max_epochs}_step_10")
-            dcp_cb.restore(ckpt_path, my_new_unit)
-            tc.assertEqual(
-                my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
-            )
-        finally:
-            if get_global_rank() == 0:
-                shutil.rmtree(temp_dir)  # delete temp directory
-
     @skip_if_not_distributed
     def test_save_restore_ddp(self) -> None:
         spawn_multi_process(
diff --git a/tests/framework/callbacks/test_dcp_saver_gpu.py b/tests/framework/callbacks/test_dcp_saver_gpu.py
new file mode 100644
index 0000000000..2f57b7a9bc
--- /dev/null
+++ b/tests/framework/callbacks/test_dcp_saver_gpu.py
@@ -0,0 +1,72 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import os
+import shutil
+import tempfile
+import unittest
+
+import torch
+
+from torchtnt.framework._test_utils import DummyAutoUnit, generate_random_dataloader
+from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
+from torchtnt.framework.train import train
+from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
+from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
+
+
+class DistributedCheckpointSaverGPUTest(unittest.TestCase):
+    @skip_if_not_distributed
+    @skip_if_not_gpu
+    def test_save_restore_fsdp(self) -> None:
+        spawn_multi_process(
+            2,
+            "nccl",
+            self._save_restore_fsdp,
+        )
+
+    @staticmethod
+    def _save_restore_fsdp() -> None:
+        input_dim = 2
+        dataset_len = 10
+        batch_size = 2
+        max_epochs = 2
+        save_every_n_epochs = 1
+
+        my_unit = DummyAutoUnit(module=torch.nn.Linear(input_dim, 2), strategy="fsdp")
+        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
+        if get_global_rank() == 0:
+            temp_dir = tempfile.mkdtemp()
+        else:
+            temp_dir = ""
+
+        dcp_cb = DistributedCheckpointSaver(
+            temp_dir,
+            save_every_n_epochs=save_every_n_epochs,
+        )
+        temp_dir = dcp_cb.dirpath
+        train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[dcp_cb])
+
+        tc = unittest.TestCase()
+        try:
+            my_new_unit = DummyAutoUnit(
+                module=torch.nn.Linear(input_dim, 2), strategy="fsdp"
+            )
+            tc.assertNotEqual(
+                my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
+            )
+            # get latest checkpoint
+            ckpt_path = os.path.join(temp_dir, f"epoch_{max_epochs}_step_10")
+            dcp_cb.restore(ckpt_path, my_new_unit)
+            tc.assertEqual(
+                my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
+            )
+        finally:
+            if get_global_rank() == 0:
+                shutil.rmtree(temp_dir)  # delete temp directory