Skip to content

Commit

Permalink
Move DCP Saver GPU test to a dedicated test file (#752)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #752

Reviewed By: JKSenthil

Differential Revision: D55151767

fbshipit-source-id: 127c2b2e4c5a5b086ad45534d1c6ef2f97493c34
  • Loading branch information
diego-urgell authored and facebook-github-bot committed Mar 21, 2024
1 parent e6b7933 commit bc2bf15
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 51 deletions.
52 changes: 1 addition & 51 deletions tests/framework/callbacks/test_dcp_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from torch import nn
from torch.utils.data import DataLoader
from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq

from torchtnt.framework._test_utils import (
DummyAutoUnit,
DummyTrainUnit,
Expand All @@ -31,7 +30,7 @@
from torchtnt.framework.train import train
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
from torchtnt.utils.env import seed
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
from torchtnt.utils.test_utils import skip_if_not_distributed


class DistributedCheckpointSaverTest(unittest.TestCase):
Expand Down Expand Up @@ -222,55 +221,6 @@ def test_save_restore_no_lr_scheduler_restore(
app_state = mock_dist_cp.load_state_dict.call_args.args[0]["app_state"]
self.assertIn("lr_scheduler", app_state)

@skip_if_not_distributed
@skip_if_not_gpu
def test_save_restore_fsdp(self) -> None:
spawn_multi_process(
2,
"nccl",
self._save_restore_fsdp,
)

@staticmethod
def _save_restore_fsdp() -> None:
input_dim = 2
dataset_len = 10
batch_size = 2
max_epochs = 2
save_every_n_epochs = 1

my_unit = DummyAutoUnit(module=torch.nn.Linear(input_dim, 2), strategy="fsdp")
dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
if get_global_rank() == 0:
temp_dir = tempfile.mkdtemp()
else:
temp_dir = ""

dcp_cb = DistributedCheckpointSaver(
temp_dir,
save_every_n_epochs=save_every_n_epochs,
)
temp_dir = dcp_cb.dirpath
train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[dcp_cb])

tc = unittest.TestCase()
try:
my_new_unit = DummyAutoUnit(
module=torch.nn.Linear(input_dim, 2), strategy="fsdp"
)
tc.assertNotEqual(
my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
)
# get latest checkpoint
ckpt_path = os.path.join(temp_dir, f"epoch_{max_epochs}_step_10")
dcp_cb.restore(ckpt_path, my_new_unit)
tc.assertEqual(
my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
)
finally:
if get_global_rank() == 0:
shutil.rmtree(temp_dir) # delete temp directory

@skip_if_not_distributed
def test_save_restore_ddp(self) -> None:
spawn_multi_process(
Expand Down
72 changes: 72 additions & 0 deletions tests/framework/callbacks/test_dcp_saver_gpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import os
import shutil
import tempfile
import unittest

import torch

from torchtnt.framework._test_utils import DummyAutoUnit, generate_random_dataloader
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
from torchtnt.framework.train import train
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu


class DistributedCheckpointSaverGPUTest(unittest.TestCase):
    """GPU/distributed tests for DistributedCheckpointSaver (DCP backend)."""

    @skip_if_not_distributed
    @skip_if_not_gpu
    def test_save_restore_fsdp(self) -> None:
        """Spawn two NCCL ranks and run the FSDP save/restore round-trip."""
        spawn_multi_process(2, "nccl", self._save_restore_fsdp)

    @staticmethod
    def _save_restore_fsdp() -> None:
        """Train an FSDP-wrapped unit with periodic DCP checkpoints, then
        restore the latest checkpoint into a fresh unit and verify the
        optimizer state matches.

        Runs inside each spawned process; rank 0 owns the temp directory.
        """
        input_dim = 2
        dataset_len = 10
        batch_size = 2
        max_epochs = 2
        save_every_n_epochs = 1

        unit = DummyAutoUnit(module=torch.nn.Linear(input_dim, 2), strategy="fsdp")
        loader = generate_random_dataloader(dataset_len, input_dim, batch_size)

        # Only rank 0 creates a real directory; the saver synchronizes its
        # dirpath across ranks, so other ranks pass an empty string.
        temp_dir = tempfile.mkdtemp() if get_global_rank() == 0 else ""

        checkpointer = DistributedCheckpointSaver(
            temp_dir,
            save_every_n_epochs=save_every_n_epochs,
        )
        temp_dir = checkpointer.dirpath
        train(unit, loader, max_epochs=max_epochs, callbacks=[checkpointer])

        tc = unittest.TestCase()
        try:
            restored_unit = DummyAutoUnit(
                module=torch.nn.Linear(input_dim, 2), strategy="fsdp"
            )
            # A freshly constructed unit must NOT already match the trained
            # optimizer state — otherwise the restore below proves nothing.
            tc.assertNotEqual(
                restored_unit.optimizer.state_dict(), unit.optimizer.state_dict()
            )
            # get latest checkpoint
            ckpt_path = os.path.join(temp_dir, f"epoch_{max_epochs}_step_10")
            checkpointer.restore(ckpt_path, restored_unit)
            tc.assertEqual(
                restored_unit.optimizer.state_dict(), unit.optimizer.state_dict()
            )
        finally:
            if get_global_rank() == 0:
                shutil.rmtree(temp_dir)  # delete temp directory

0 comments on commit bc2bf15

Please sign in to comment.