Move torchsnapshot_saver GPU test to dedicated file (#760)
Summary: Pull Request resolved: #760

Differential Revision: D55327868
diego-urgell authored and facebook-github-bot committed Mar 25, 2024
1 parent c2dcee9 commit d326923
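
With the GPU-dependent test in its own module, it can be run (or excluded) independently of the CPU-only snapshot saver tests. Below is a hypothetical standalone invocation using the stock unittest loader, assuming the repository root is on sys.path; actual CI wiring may differ.

import unittest

# Load and run only the GPU-gated snapshot saver tests added in this commit.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "tests.framework.callbacks.test_torchsnapshot_saver_gpu"
)
unittest.TextTestRunner(verbosity=2).run(suite)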
Showing 2 changed files with 73 additions and 51 deletions.
52 changes: 1 addition & 51 deletions tests/framework/callbacks/test_torchsnapshot_saver.py
@@ -36,7 +36,7 @@
from torchtnt.framework.train import train
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
from torchtnt.utils.env import seed
-from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
+from torchtnt.utils.test_utils import skip_if_not_distributed


class TorchSnapshotSaverTest(unittest.TestCase):
@@ -227,56 +227,6 @@ def test_save_restore_no_lr_scheduler_restore(
app_state = mock_torchsnapshot.Snapshot().restore.call_args.args[0]
self.assertIn("lr_scheduler", app_state)

@skip_if_not_distributed
@skip_if_not_gpu
def test_save_restore_fsdp(self) -> None:
spawn_multi_process(
2,
"nccl",
self._save_restore_fsdp,
)

@staticmethod
def _save_restore_fsdp() -> None:
input_dim = 2
dataset_len = 10
batch_size = 2
max_epochs = 2
save_every_n_epochs = 1

my_unit = DummyAutoUnit(module=torch.nn.Linear(input_dim, 2), strategy="fsdp")
dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
if get_global_rank() == 0:
temp_dir = tempfile.mkdtemp()
else:
temp_dir = ""

snapshot_cb = TorchSnapshotSaver(
temp_dir,
save_every_n_epochs=save_every_n_epochs,
replicated=["**"],
)
temp_dir = snapshot_cb.dirpath
train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])

tc = unittest.TestCase()
try:
my_new_unit = DummyAutoUnit(
module=torch.nn.Linear(input_dim, 2), strategy="fsdp"
)
tc.assertNotEqual(
my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
)
# get latest checkpoint
ckpt_path = os.path.join(temp_dir, f"epoch_{max_epochs}_step_10")
snapshot_cb.restore(ckpt_path, my_new_unit)
tc.assertEqual(
my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
)
finally:
if get_global_rank() == 0:
shutil.rmtree(temp_dir) # delete temp directory

@skip_if_not_distributed
def test_save_restore_ddp(self) -> None:
spawn_multi_process(
72 changes: 72 additions & 0 deletions tests/framework/callbacks/test_torchsnapshot_saver_gpu.py
@@ -0,0 +1,72 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import os
import shutil
import tempfile
import unittest

import torch
from torchtnt.framework._test_utils import DummyAutoUnit, generate_random_dataloader
from torchtnt.framework.callbacks.torchsnapshot_saver import TorchSnapshotSaver
from torchtnt.framework.train import train
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu


class TorchSnapshotSaverGPUTest(unittest.TestCase):
@skip_if_not_distributed
@skip_if_not_gpu
def test_save_restore_fsdp(self) -> None:
spawn_multi_process(
2,
"nccl",
self._save_restore_fsdp,
)

@staticmethod
def _save_restore_fsdp() -> None:
input_dim = 2
dataset_len = 10
batch_size = 2
max_epochs = 2
save_every_n_epochs = 1

my_unit = DummyAutoUnit(module=torch.nn.Linear(input_dim, 2), strategy="fsdp")
dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
if get_global_rank() == 0:
temp_dir = tempfile.mkdtemp()
else:
temp_dir = ""

snapshot_cb = TorchSnapshotSaver(
temp_dir,
save_every_n_epochs=save_every_n_epochs,
replicated=["**"],
)
temp_dir = snapshot_cb.dirpath
train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])

tc = unittest.TestCase()
try:
my_new_unit = DummyAutoUnit(
module=torch.nn.Linear(input_dim, 2), strategy="fsdp"
)
tc.assertNotEqual(
my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
)
# get latest checkpoint
ckpt_path = os.path.join(temp_dir, f"epoch_{max_epochs}_step_10")
snapshot_cb.restore(ckpt_path, my_new_unit)
tc.assertEqual(
my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
)
finally:
if get_global_rank() == 0:
shutil.rmtree(temp_dir) # delete temp directory
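
For context on the gating used above: skip_if_not_gpu and skip_if_not_distributed ensure the FSDP save/restore test only runs on hosts with CUDA devices and torch.distributed support, which is why it now lives in its own file. The following is a minimal sketch of how such decorators can be built with unittest.skipUnless; the real torchtnt.utils.test_utils implementations may differ.

#!/usr/bin/env python3
# Illustrative sketch only: hypothetical re-implementations of the skip
# decorators referenced in the diff above, assuming they gate on CUDA and
# torch.distributed availability. Not the actual torchtnt definitions.
import unittest

import torch
import torch.distributed as dist

skip_if_not_gpu = unittest.skipUnless(
    torch.cuda.is_available(), "Skipping test: no GPU available"
)
skip_if_not_distributed = unittest.skipUnless(
    dist.is_available(), "Skipping test: torch.distributed not available"
)


class ExampleGPUOnlyTest(unittest.TestCase):
    @skip_if_not_distributed
    @skip_if_not_gpu
    def test_runs_only_on_gpu_hosts(self) -> None:
        # Trivial assertion; real tests would spawn worker processes as above.
        self.assertTrue(torch.cuda.is_available())


if __name__ == "__main__":
    unittest.main()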
