Skip to content

Commit

Permalink
Use TNT's ManifoldPathHandler for listing checkpoints internally
Browse files Browse the repository at this point in the history
Summary:
We've faced multiple issues in the past where users registered incompatible implementations of Manifold path handlers with fsspec, causing errors when listing and even when loading checkpoints.

Currently, [one user is facing an error because of this](https://fb.workplace.com/groups/277527419809135/permalink/1625534775008386/), and it's tricky to debug because the error is not reproducible outside of that particular project (due to the specific dependencies being used).

Most internal customers should be using storage optimizations, and thus modelstore components within the DCP APIs, but we still use fsspec for listing the latest and best checkpoints.

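We need to make sure that our own implementation is used to list the checkpoints. Note that we can't make this change directly in the checkpoint utils because they are OSS, while the filesystem implementation is internal. Instead, the checkpoint utilities accept an optional `file_system` argument and fall back to resolving the file system from the path via fsspec when it is not provided.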

Differential Revision: D65370757
  • Loading branch information
diego-urgell authored and facebook-github-bot committed Nov 4, 2024
1 parent 8150bcc commit 21cf6aa
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 23 deletions.
12 changes: 2 additions & 10 deletions tests/utils/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1422,18 +1422,10 @@ def test_does_checkpoint_metadata_exist(self) -> None:
dirpath = os.path.join(temp_dir, "checkpoint")
Snapshot.take(dirpath, app_state=app_state)

self.assertTrue(
CheckpointManager.does_checkpoint_metadata_exist(
dirpath, SNAPSHOT_METADATA_FNAME
)
)
self.assertTrue(does_checkpoint_exist(dirpath, SNAPSHOT_METADATA_FNAME))

os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME))
self.assertFalse(
CheckpointManager.does_checkpoint_metadata_exist(
dirpath, SNAPSHOT_METADATA_FNAME
)
)
self.assertFalse(does_checkpoint_exist(dirpath, SNAPSHOT_METADATA_FNAME))

def test_does_checkpoint_exist(self) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
Expand Down
14 changes: 13 additions & 1 deletion torchtnt/framework/callbacks/base_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from datetime import timedelta
from typing import Any, cast, Iterable, List, Literal, Optional, Union

import fsspec

import torch.distributed as dist
from pyre_extensions import none_throws
from torchtnt.framework.callback import Callback
Expand Down Expand Up @@ -449,6 +451,7 @@ def restore_from_latest(
train_dataloader: Optional[Iterable[TTrainData]] = None,
process_group: Optional[dist.ProcessGroup] = None,
restore_options: Optional[RestoreOptions] = None,
file_system: Optional[fsspec.AbstractFileSystem] = None,
**kwargs: Any,
) -> bool:
"""
Expand All @@ -463,12 +466,17 @@ def restore_from_latest(
train_dataloader: An optional train dataloader to restore.
process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world)
restore_options: Controls what to filter when restoring the state.
file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be
used to match the file system of the dirpath.
Returns:
True if the latest checkpoint directory was found and successfully restored, otherwise False.
"""
path = get_latest_checkpoint_path(
dirpath, metadata_fname=cls.metadata_fnames, process_group=process_group
dirpath,
metadata_fname=cls.metadata_fnames,
process_group=process_group,
file_system=file_system,
)
if path is None:
logger.info(
Expand Down Expand Up @@ -497,6 +505,7 @@ def restore_from_best(
train_dataloader: Optional[Iterable[TTrainData]] = None,
process_group: Optional[dist.ProcessGroup] = None,
restore_options: Optional[RestoreOptions] = None,
file_system: Optional[fsspec.AbstractFileSystem] = None,
**kwargs: Any,
) -> bool:
"""
Expand All @@ -512,6 +521,8 @@ def restore_from_best(
mode: Either 'min' or 'max'. If 'min', finds and loads the lowest value metric checkpoint. If 'max', finds and loads the largest.
train_dataloader: An optional train dataloader to restore.
process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world)
file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be
used to match the file system of the dirpath.
restore_options: Controls what to filter when restoring the state.
Returns:
Expand All @@ -522,6 +533,7 @@ def restore_from_best(
metric_name=metric_name,
mode=mode,
metadata_fname=cls.metadata_fnames,
file_system=file_system,
process_group=process_group,
)

Expand Down
54 changes: 42 additions & 12 deletions torchtnt/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ def __init__(
keep_last_n_checkpoints: Optional[int] = None,
metadata_fnames: Optional[List[str]] = None,
process_group: Optional[dist.ProcessGroup] = None,
file_system: Optional[fsspec.AbstractFileSystem] = None,
) -> None:
"""
Initialize a checkpoint manager. If a `keep_last_n_checkpoints` value is provided, this will read the
Expand All @@ -389,6 +390,11 @@ def __init__(
self._keep_last_n_checkpoints = keep_last_n_checkpoints
self._pg_wrapper = PGWrapper(process_group)

if file_system is None:
file_system, _ = url_to_fs(self.dirpath)

self._file_system: fsspec.AbstractFileSystem = file_system

if metadata_fnames is None:
self._metadata_fnames: List[str] = []
else:
Expand Down Expand Up @@ -568,17 +574,16 @@ def does_checkpoint_exist(
ckpt.path, self._metadata_fnames, process_group=process_group
)

@staticmethod
def does_checkpoint_metadata_exist(
self,
checkpoint_path: str,
metadata_fname: str,
) -> bool:
"""
Checking whether a checkpoint metadata file exists in the directory.
If the checkpointer has that metadata file, this function will returns True. Returns False otherwise.
"""
fs, _ = url_to_fs(checkpoint_path)
return _metadata_exists(fs, checkpoint_path, metadata_fname)
return _metadata_exists(self._file_system, checkpoint_path, metadata_fname)

@staticmethod
@rank_zero_read_and_broadcast
Expand All @@ -596,9 +601,8 @@ def remove_checkpoint(self) -> None:
"""
worst_ckpt_path = self._ckpt_paths.pop(0)
if self._pg_wrapper.get_rank() == 0:
fs, _ = url_to_fs(self.dirpath)
try:
fs.rm(worst_ckpt_path.path, recursive=True)
self._file_system.rm(worst_ckpt_path.path, recursive=True)
except Exception as exc:
logger.error(
(
Expand All @@ -612,6 +616,7 @@ def remove_checkpoint(self) -> None:
def does_checkpoint_exist(
ckpt_path: str,
metadata_fname: Union[str, List[str]],
file_system: Optional[fsspec.AbstractFileSystem] = None,
process_group: Optional[dist.ProcessGroup] = None,
) -> bool:
"""
Expand All @@ -622,6 +627,8 @@ def does_checkpoint_exist(
Args:
ckpt: The checkpoint to check.
metadata_fname: File to check for existence. If a list is provided, it will check that at least one of the files is present.
file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be
used to match the file system of the dirpath.
process_group: Optional process group on which the ranks will communicate on. By default, the entire world is used.
"""
if not metadata_fname:
Expand All @@ -631,14 +638,18 @@ def does_checkpoint_exist(
[metadata_fname] if isinstance(metadata_fname, str) else metadata_fname
)

fs, _ = url_to_fs(ckpt_path)
fs = file_system
if fs is None:
fs, _ = url_to_fs(ckpt_path)

return any(_metadata_exists(fs, ckpt_path, fname) for fname in metadata_fnames)


@rank_zero_read_and_broadcast
def get_latest_checkpoint_path(
dirpath: str,
metadata_fname: Optional[Union[str, List[str]]] = None,
file_system: Optional[fsspec.AbstractFileSystem] = None,
process_group: Optional[dist.ProcessGroup] = None,
) -> Optional[str]:
"""
Expand All @@ -648,6 +659,8 @@ def get_latest_checkpoint_path(
dirpath: parent directory where checkpoints are saved.
metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
If a list is provided, it will check that at least one of the files is present.
file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be
used to match the file system of the dirpath.
process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)
Raises:
Expand All @@ -658,14 +671,17 @@ def get_latest_checkpoint_path(
gloo process groups are recommended over nccl.
"""

return _get_latest_checkpoint_path(dirpath, metadata_fname)
return _get_latest_checkpoint_path(dirpath, metadata_fname, file_system)


def _get_latest_checkpoint_path(
dirpath: str,
metadata_fname: Optional[Union[str, List[str]]] = None,
file_system: Optional[fsspec.AbstractFileSystem] = None,
) -> Optional[str]:
candidate_dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname)
candidate_dirpaths = _retrieve_checkpoint_dirpaths(
dirpath, metadata_fname, file_system=file_system
)
if not candidate_dirpaths:
return None

Expand All @@ -683,6 +699,7 @@ def get_best_checkpoint_path(
metric_name: str,
mode: Literal["min", "max"],
metadata_fname: Optional[Union[str, List[str]]] = None,
file_system: Optional[fsspec.AbstractFileSystem] = None,
process_group: Optional[dist.ProcessGroup] = None,
) -> Optional[str]:
"""
Expand All @@ -697,14 +714,18 @@ def get_best_checkpoint_path(
mode: Either 'min' or 'max'. If 'min', finds and loads the lowest value metric checkpoint. If 'max', finds and loads the largest.
metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
If a list is provided, it will check that at least one of the files is present.
file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be
used to match the file system of the dirpath.
process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)
Note:
When doing distributed training, only rank 0 will read the file system. The result will be broadcasted to all ranks.
gloo process groups are recommended over nccl.
"""

dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name)
dirpaths = _retrieve_checkpoint_dirpaths(
dirpath, metadata_fname, metric_name, file_system=file_system
)
if not dirpaths:
return None

Expand All @@ -721,6 +742,7 @@ def get_checkpoint_dirpaths(
dirpath: str,
metadata_fname: Optional[Union[str, List[str]]] = None,
metric_name: Optional[str] = None,
file_system: Optional[fsspec.AbstractFileSystem] = None,
process_group: Optional[dist.ProcessGroup] = None,
) -> List[CheckpointPath]:
"""
Expand All @@ -736,20 +758,25 @@ def get_checkpoint_dirpaths(
metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
If a list is provided, it will check that at least one of the files is present.
metric_name: fetches all the checkpoint directories containing the metric name only.
file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be
used to match the file system of the dirpath.
process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)
Note:
When doing distributed training, only rank 0 will read the file system. The result will be broadcasted to all ranks.
gloo process groups are recommended over nccl.
"""

return _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name)
return _retrieve_checkpoint_dirpaths(
dirpath, metadata_fname, metric_name, file_system=file_system
)


def _retrieve_checkpoint_dirpaths(
dirpath: str,
metadata_fname: Optional[Union[str, List[str]]],
metric_name: Optional[str] = None,
file_system: Optional[fsspec.AbstractFileSystem] = None,
) -> List[CheckpointPath]:
"""
Given a parent directory where checkpoints are saved, return the unsorted checkpoint subdirectories
Expand All @@ -759,9 +786,12 @@ def _retrieve_checkpoint_dirpaths(
metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
If a list is provided, it will check that at least one of the files is present.
metric_name: Name of the metric that must exist in checkpoint name.
file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be
used to match the file system of the dirpath.
"""

fs, _ = url_to_fs(dirpath)
fs = file_system
if fs is None:
fs, _ = url_to_fs(dirpath)

if not fs.exists(dirpath):
logger.warning(f"Input dirpath doesn't exist: {dirpath}")
Expand Down

0 comments on commit 21cf6aa

Please sign in to comment.