fix dcp pyre errors (#832)
Summary:
Pull Request resolved: #832

Fixes Pyre errors in the DCP tests.

Also moves the latest-PyTorch check to the top of the test module; otherwise an ImportError is raised on StorageMeta in the stable unit tests.
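
For reference, the gate works roughly like the sketch below. The exact symbols guarded inside torchtnt/framework/callbacks/dcp_saver.py are an assumption here; the point is that the flag is computed in a try/except at import time, so importing the flag itself never fails on stable PyTorch:

    # Illustrative sketch only -- the real dcp_saver.py may guard different imports.
    _LATEST_DCP_AVAIL: bool = True
    try:
        # Present only in recent PyTorch builds.
        from torch.distributed.checkpoint.metadata import StorageMeta  # noqa: F401
    except ImportError:
        _LATEST_DCP_AVAIL = False

With the check at the top of the test module, unittest.SkipTest is raised before the later import of StorageMeta executes, so stable PyTorch skips the whole module instead of crashing on the import.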

Reviewed By: diego-urgell

Differential Revision: D57342263

fbshipit-source-id: 79d66a5637313eb399d995158baa0ae5c1821d27
JKSenthil authored and facebook-github-bot committed May 14, 2024
1 parent 5e1db7f commit 8456d0c
Showing 1 changed file with 18 additions and 13 deletions.
tests/framework/callbacks/test_dcp_saver.py (18 additions, 13 deletions)

--- a/tests/framework/callbacks/test_dcp_saver.py
+++ b/tests/framework/callbacks/test_dcp_saver.py
@@ -7,12 +7,18 @@
 
 # pyre-strict
 
+import unittest
+
+from torchtnt.framework.callbacks.dcp_saver import _LATEST_DCP_AVAIL
+
+if not _LATEST_DCP_AVAIL:
+    raise unittest.SkipTest("Latest Pytorch is required to run DCP tests")
+
 import math
 import os
 import shutil
 import tempfile
-import unittest
-from typing import Any, Dict, Iterator, List
+from typing import Any, Dict, Iterator, List, Optional
 from unittest import mock
 from unittest.mock import MagicMock, patch
 
@@ -23,7 +29,7 @@
     DefaultLoadPlanner,
     DefaultSavePlanner,
 )
-from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE
+from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE, StorageMeta
 from torch.utils.data import DataLoader
 from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq
 from torchtnt.framework._test_utils import (
@@ -32,18 +38,12 @@
     generate_random_dataloader,
 )
 from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
-from torchtnt.framework.callbacks.dcp_saver import (
-    _LATEST_DCP_AVAIL,
-    DistributedCheckpointSaver,
-)
+from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
 from torchtnt.framework.train import train
 from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
 from torchtnt.utils.env import seed
 from torchtnt.utils.test_utils import skip_if_not_distributed
 
-if not _LATEST_DCP_AVAIL:
-    raise unittest.SkipTest("Latest Pytorch is required to run DCP tests")
-
 
 class DistributedCheckpointSaverTest(unittest.TestCase):
     def test_save_restore(self) -> None:
@@ -410,8 +410,13 @@ class DummySavePlanner(DefaultSavePlanner):
     def __init__(self) -> None:
         super().__init__()
 
-    def set_up_planner(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> None:
-        super().set_up_planner(state_dict, is_coordinator)
+    def set_up_planner(
+        self,
+        state_dict: STATE_DICT_TYPE,
+        storage_meta: Optional[StorageMeta],
+        is_coordinator: bool,
+    ) -> None:
+        super().set_up_planner(state_dict, storage_meta, is_coordinator)
 
 
 class DummyLoadPlanner(DefaultLoadPlanner):
@@ -421,7 +426,7 @@ def __init__(self) -> None:
     def set_up_planner(
         self,
         state_dict: STATE_DICT_TYPE,
-        metadata: Metadata,
+        metadata: Optional[Metadata],
         is_coordinator: bool,
     ) -> None:
         super().set_up_planner(state_dict, metadata, is_coordinator)
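
The planner changes above track the current set_up_planner signatures on DefaultSavePlanner and DefaultLoadPlanner, where storage_meta and metadata are optional. A minimal sketch, assuming a recent PyTorch where set_up_planner accepts storage_meta (nightly at the time of this commit, stable in later releases); NoOpSavePlanner is a hypothetical name for illustration:

    # Minimal sketch, assuming a PyTorch build where set_up_planner accepts
    # storage_meta. NoOpSavePlanner is a hypothetical example class.
    from typing import Optional

    import torch
    from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
    from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE, StorageMeta


    class NoOpSavePlanner(DefaultSavePlanner):
        def set_up_planner(
            self,
            state_dict: STATE_DICT_TYPE,
            storage_meta: Optional[StorageMeta],
            is_coordinator: bool,
        ) -> None:
            # storage_meta may be None; forward all three arguments to the parent.
            super().set_up_planner(state_dict, storage_meta, is_coordinator)


    # Direct call just to exercise the override; no process group is needed here.
    planner = NoOpSavePlanner()
    planner.set_up_planner({"weights": torch.ones(2, 2)}, None, is_coordinator=True)

An override that kept the old two-argument signature would no longer match the base class, which is the inconsistent-override error Pyre reported in these tests.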
