From 8456d0c55af796b555a2471769f09290cb97fef3 Mon Sep 17 00:00:00 2001
From: Jason Senthil
Date: Tue, 14 May 2024 13:15:11 -0700
Subject: [PATCH] fix dcp pyre errors (#832)

Summary:
Pull Request resolved: https://github.com/pytorch/tnt/pull/832

Fixes pyre error in dcp tests

Also moves latest pytorch check to the top, otherwise import error is raised on StorageMeta in stable unit tests

Reviewed By: diego-urgell

Differential Revision: D57342263

fbshipit-source-id: 79d66a5637313eb399d995158baa0ae5c1821d27
---
 tests/framework/callbacks/test_dcp_saver.py | 31 ++++++++++++---------
 1 file changed, 18 insertions(+), 13 deletions(-)

diff --git a/tests/framework/callbacks/test_dcp_saver.py b/tests/framework/callbacks/test_dcp_saver.py
index 04b697bb09..a56bef385b 100644
--- a/tests/framework/callbacks/test_dcp_saver.py
+++ b/tests/framework/callbacks/test_dcp_saver.py
@@ -7,12 +7,18 @@
 
 # pyre-strict
 
+import unittest
+
+from torchtnt.framework.callbacks.dcp_saver import _LATEST_DCP_AVAIL
+
+if not _LATEST_DCP_AVAIL:
+    raise unittest.SkipTest("Latest Pytorch is required to run DCP tests")
+
 import math
 import os
 import shutil
 import tempfile
-import unittest
-from typing import Any, Dict, Iterator, List
+from typing import Any, Dict, Iterator, List, Optional
 from unittest import mock
 from unittest.mock import MagicMock, patch
 
@@ -23,7 +29,7 @@
     DefaultLoadPlanner,
     DefaultSavePlanner,
 )
-from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE
+from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE, StorageMeta
 from torch.utils.data import DataLoader
 from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq
 from torchtnt.framework._test_utils import (
@@ -32,18 +38,12 @@
     generate_random_dataloader,
 )
 from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
-from torchtnt.framework.callbacks.dcp_saver import (
-    _LATEST_DCP_AVAIL,
-    DistributedCheckpointSaver,
-)
+from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
 from torchtnt.framework.train import train
 from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
 from torchtnt.utils.env import seed
 from torchtnt.utils.test_utils import skip_if_not_distributed
 
-if not _LATEST_DCP_AVAIL:
-    raise unittest.SkipTest("Latest Pytorch is required to run DCP tests")
-
 
 class DistributedCheckpointSaverTest(unittest.TestCase):
     def test_save_restore(self) -> None:
@@ -410,8 +410,13 @@ class DummySavePlanner(DefaultSavePlanner):
     def __init__(self) -> None:
         super().__init__()
 
-    def set_up_planner(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> None:
-        super().set_up_planner(state_dict, is_coordinator)
+    def set_up_planner(
+        self,
+        state_dict: STATE_DICT_TYPE,
+        storage_meta: Optional[StorageMeta],
+        is_coordinator: bool,
+    ) -> None:
+        super().set_up_planner(state_dict, storage_meta, is_coordinator)
 
 
 class DummyLoadPlanner(DefaultLoadPlanner):
@@ -421,7 +426,7 @@ def __init__(self) -> None:
     def set_up_planner(
         self,
         state_dict: STATE_DICT_TYPE,
-        metadata: Metadata,
+        metadata: Optional[Metadata],
         is_coordinator: bool,
    ) -> None:
         super().set_up_planner(state_dict, metadata, is_coordinator)
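
Note on the reordering in the first hunk: raising unittest.SkipTest at module scope marks every test in the file as skipped, and because Python executes a module top to bottom, the guard must run before any import that touches symbols existing only in newer PyTorch (here, StorageMeta). Below is a minimal, self-contained sketch of the same pattern; the try/except probe is an assumption about how a flag like _LATEST_DCP_AVAIL might be derived, not TorchTNT's actual implementation.

    import unittest

    # Probe for a symbol that only newer PyTorch builds provide. This
    # try/except is illustrative; torchtnt computes _LATEST_DCP_AVAIL its
    # own way inside dcp_saver.py.
    try:
        from torch.distributed.checkpoint.metadata import StorageMeta  # noqa: F401

        _LATEST_DCP_AVAIL = True
    except ImportError:
        _LATEST_DCP_AVAIL = False

    if not _LATEST_DCP_AVAIL:
        # Executes before the imports further down the module, so stable
        # PyTorch never hits an ImportError; the module is reported skipped.
        raise unittest.SkipTest("Latest Pytorch is required to run DCP tests")

    # Only below this point is it safe to import test dependencies that
    # assume the newer DCP APIs.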
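
The signature changes in the last two hunks track the updated DCP planner API: DefaultSavePlanner.set_up_planner now takes an Optional[StorageMeta] argument and DefaultLoadPlanner.set_up_planner takes Optional[Metadata], so overrides keeping the old, narrower signatures are what pyre flags (an override may not narrow the base class's parameter types). A minimal standalone sketch of a correctly typed override, assuming the post-change API shown in the diff; LoggingLoadPlanner is a hypothetical name used only for illustration:

    from typing import Optional

    from torch.distributed.checkpoint import DefaultLoadPlanner
    from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE

    class LoggingLoadPlanner(DefaultLoadPlanner):
        def set_up_planner(
            self,
            state_dict: STATE_DICT_TYPE,
            metadata: Optional[Metadata],  # Optional, matching the base class
            is_coordinator: bool,
        ) -> None:
            # The override must accept everything the base method accepts;
            # typing metadata as a bare Metadata here is what pyre rejects.
            if is_coordinator:
                print("coordinator rank is setting up the load plan")
            super().set_up_planner(state_dict, metadata, is_coordinator)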