diff --git a/tests/utils/test_checkpoint.py b/tests/utils/test_checkpoint.py index cc2a7c0c20..cba800a986 100644 --- a/tests/utils/test_checkpoint.py +++ b/tests/utils/test_checkpoint.py @@ -1422,18 +1422,10 @@ def test_does_checkpoint_metadata_exist(self) -> None: dirpath = os.path.join(temp_dir, "checkpoint") Snapshot.take(dirpath, app_state=app_state) - self.assertTrue( - CheckpointManager.does_checkpoint_metadata_exist( - dirpath, SNAPSHOT_METADATA_FNAME - ) - ) + self.assertTrue(does_checkpoint_exist(dirpath, SNAPSHOT_METADATA_FNAME)) os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME)) - self.assertFalse( - CheckpointManager.does_checkpoint_metadata_exist( - dirpath, SNAPSHOT_METADATA_FNAME - ) - ) + self.assertFalse(does_checkpoint_exist(dirpath, SNAPSHOT_METADATA_FNAME)) def test_does_checkpoint_exist(self) -> None: with tempfile.TemporaryDirectory() as temp_dir: diff --git a/torchtnt/framework/callbacks/base_checkpointer.py b/torchtnt/framework/callbacks/base_checkpointer.py index 43f3657551..3a77fd11fa 100644 --- a/torchtnt/framework/callbacks/base_checkpointer.py +++ b/torchtnt/framework/callbacks/base_checkpointer.py @@ -12,6 +12,8 @@ from datetime import timedelta from typing import Any, cast, Iterable, List, Literal, Optional, Union +import fsspec + import torch.distributed as dist from pyre_extensions import none_throws from torchtnt.framework.callback import Callback @@ -449,6 +451,7 @@ def restore_from_latest( train_dataloader: Optional[Iterable[TTrainData]] = None, process_group: Optional[dist.ProcessGroup] = None, restore_options: Optional[RestoreOptions] = None, + file_system: Optional[fsspec.AbstractFileSystem] = None, **kwargs: Any, ) -> bool: """ @@ -463,12 +466,17 @@ def restore_from_latest( train_dataloader: An optional train dataloader to restore. process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world) restore_options: Controls what to filter when restoring the state. + file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be + used to match the file system of the dirpath. Returns: True if the latest checkpoint directory was found and successfully restored, otherwise False. """ path = get_latest_checkpoint_path( - dirpath, metadata_fname=cls.metadata_fnames, process_group=process_group + dirpath, + metadata_fname=cls.metadata_fnames, + process_group=process_group, + file_system=file_system, ) if path is None: logger.info( @@ -497,6 +505,7 @@ def restore_from_best( train_dataloader: Optional[Iterable[TTrainData]] = None, process_group: Optional[dist.ProcessGroup] = None, restore_options: Optional[RestoreOptions] = None, + file_system: Optional[fsspec.AbstractFileSystem] = None, **kwargs: Any, ) -> bool: """ @@ -512,6 +521,8 @@ def restore_from_best( mode: Either 'min' or 'max'. If 'min', finds and loads the lowest value metric checkpoint. If 'max', finds and loads the largest. train_dataloader: An optional train dataloader to restore. process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world) + file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be + used to match the file system of the dirpath. restore_options: Controls what to filter when restoring the state. Returns: @@ -522,6 +533,7 @@ def restore_from_best( metric_name=metric_name, mode=mode, metadata_fname=cls.metadata_fnames, + file_system=file_system, process_group=process_group, ) diff --git a/torchtnt/utils/checkpoint.py b/torchtnt/utils/checkpoint.py index b6e1f962ca..c92c4e00ee 100644 --- a/torchtnt/utils/checkpoint.py +++ b/torchtnt/utils/checkpoint.py @@ -366,6 +366,7 @@ def __init__( keep_last_n_checkpoints: Optional[int] = None, metadata_fnames: Optional[List[str]] = None, process_group: Optional[dist.ProcessGroup] = None, + file_system: Optional[fsspec.AbstractFileSystem] = None, ) -> None: """ Initialize a checkpoint manager. If a `keep_last_n_checkpoints` value is provided, this will read the @@ -389,6 +390,11 @@ def __init__( self._keep_last_n_checkpoints = keep_last_n_checkpoints self._pg_wrapper = PGWrapper(process_group) + if file_system is None: + file_system, _ = url_to_fs(self.dirpath) + + self._file_system: fsspec.AbstractFileSystem = file_system + if metadata_fnames is None: self._metadata_fnames: List[str] = [] else: @@ -568,8 +574,8 @@ def does_checkpoint_exist( ckpt.path, self._metadata_fnames, process_group=process_group ) - @staticmethod def does_checkpoint_metadata_exist( + self, checkpoint_path: str, metadata_fname: str, ) -> bool: @@ -577,8 +583,7 @@ def does_checkpoint_metadata_exist( Checking whether a checkpoint metadata file exists in the directory. If the checkpointer has that metadata file, this function will returns True. Returns False otherwise. """ - fs, _ = url_to_fs(checkpoint_path) - return _metadata_exists(fs, checkpoint_path, metadata_fname) + return _metadata_exists(self._file_system, checkpoint_path, metadata_fname) @staticmethod @rank_zero_read_and_broadcast @@ -596,9 +601,8 @@ def remove_checkpoint(self) -> None: """ worst_ckpt_path = self._ckpt_paths.pop(0) if self._pg_wrapper.get_rank() == 0: - fs, _ = url_to_fs(self.dirpath) try: - fs.rm(worst_ckpt_path.path, recursive=True) + self._file_system.rm(worst_ckpt_path.path, recursive=True) except Exception as exc: logger.error( ( @@ -612,6 +616,7 @@ def remove_checkpoint(self) -> None: def does_checkpoint_exist( ckpt_path: str, metadata_fname: Union[str, List[str]], + file_system: Optional[fsspec.AbstractFileSystem] = None, process_group: Optional[dist.ProcessGroup] = None, ) -> bool: """ @@ -622,6 +627,8 @@ def does_checkpoint_exist( Args: ckpt: The checkpoint to check. metadata_fname: File to check for existence. If a list is provided, it will check that at least one of the files is present. + file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be + used to match the file system of the dirpath. process_group: Optional process group on which the ranks will communicate on. By default, the entire world is used. """ if not metadata_fname: @@ -631,7 +638,10 @@ def does_checkpoint_exist( [metadata_fname] if isinstance(metadata_fname, str) else metadata_fname ) - fs, _ = url_to_fs(ckpt_path) + fs = file_system + if fs is None: + fs, _ = url_to_fs(ckpt_path) + return any(_metadata_exists(fs, ckpt_path, fname) for fname in metadata_fnames) @@ -639,6 +649,7 @@ def does_checkpoint_exist( def get_latest_checkpoint_path( dirpath: str, metadata_fname: Optional[Union[str, List[str]]] = None, + file_system: Optional[fsspec.AbstractFileSystem] = None, process_group: Optional[dist.ProcessGroup] = None, ) -> Optional[str]: """ @@ -648,6 +659,8 @@ def get_latest_checkpoint_path( dirpath: parent directory where checkpoints are saved. metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. If a list is provided, it will check that at least one of the files is present. + file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be + used to match the file system of the dirpath. process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world) Raises: @@ -658,14 +671,17 @@ def get_latest_checkpoint_path( gloo process groups are recommended over nccl. """ - return _get_latest_checkpoint_path(dirpath, metadata_fname) + return _get_latest_checkpoint_path(dirpath, metadata_fname, file_system) def _get_latest_checkpoint_path( dirpath: str, metadata_fname: Optional[Union[str, List[str]]] = None, + file_system: Optional[fsspec.AbstractFileSystem] = None, ) -> Optional[str]: - candidate_dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname) + candidate_dirpaths = _retrieve_checkpoint_dirpaths( + dirpath, metadata_fname, file_system=file_system + ) if not candidate_dirpaths: return None @@ -683,6 +699,7 @@ def get_best_checkpoint_path( metric_name: str, mode: Literal["min", "max"], metadata_fname: Optional[Union[str, List[str]]] = None, + file_system: Optional[fsspec.AbstractFileSystem] = None, process_group: Optional[dist.ProcessGroup] = None, ) -> Optional[str]: """ @@ -697,6 +714,8 @@ def get_best_checkpoint_path( mode: Either 'min' or 'max'. If 'min', finds and loads the lowest value metric checkpoint. If 'max', finds and loads the largest. metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. If a list is provided, it will check that at least one of the files is present. + file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be + used to match the file system of the dirpath. process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world) Note: @@ -704,7 +723,9 @@ def get_best_checkpoint_path( gloo process groups are recommended over nccl. """ - dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name) + dirpaths = _retrieve_checkpoint_dirpaths( + dirpath, metadata_fname, metric_name, file_system=file_system + ) if not dirpaths: return None @@ -721,6 +742,7 @@ def get_checkpoint_dirpaths( dirpath: str, metadata_fname: Optional[Union[str, List[str]]] = None, metric_name: Optional[str] = None, + file_system: Optional[fsspec.AbstractFileSystem] = None, process_group: Optional[dist.ProcessGroup] = None, ) -> List[CheckpointPath]: """ @@ -736,6 +758,8 @@ def get_checkpoint_dirpaths( metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. If a list is provided, it will check that at least one of the files is present. metric_name: fetches all the checkpoint directories containing the metric name only. + file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be + used to match the file system of the dirpath. process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world) Note: @@ -743,13 +767,16 @@ def get_checkpoint_dirpaths( gloo process groups are recommended over nccl. """ - return _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name) + return _retrieve_checkpoint_dirpaths( + dirpath, metadata_fname, metric_name, file_system=file_system + ) def _retrieve_checkpoint_dirpaths( dirpath: str, metadata_fname: Optional[Union[str, List[str]]], metric_name: Optional[str] = None, + file_system: Optional[fsspec.AbstractFileSystem] = None, ) -> List[CheckpointPath]: """ Given a parent directory where checkpoints are saved, return the unsorted checkpoint subdirectories @@ -759,9 +786,12 @@ def _retrieve_checkpoint_dirpaths( metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. If a list is provided, it will check that at least one of the files is present. metric_name: Name of the metric that must exist in checkpoint name. + file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be + used to match the file system of the dirpath. """ - - fs, _ = url_to_fs(dirpath) + fs = file_system + if fs is None: + fs, _ = url_to_fs(dirpath) if not fs.exists(dirpath): logger.warning(f"Input dirpath doesn't exist: {dirpath}")