From 5188ce051a11722edf6c7b269342c6c530bbe8e3 Mon Sep 17 00:00:00 2001 From: Jason Senthil Date: Wed, 30 Oct 2024 15:59:53 -0700 Subject: [PATCH] extract get_latest_checkpoint_path to a single rank function Summary: # Context `get_latest_checkpoint_path` broadcasts the path from rank 0 to other ranks. However, there may be a scenario where an operation involving checkpoint read + other actions is only done on rank 0, in which case the broadcast will hang, as the other ranks never enter that logic. # This Diff Write a `_get_latest_checkpoint_path` which is unaware of torch.distributed. Reviewed By: anshulverma Differential Revision: D64282217 --- torchtnt/utils/checkpoint.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torchtnt/utils/checkpoint.py b/torchtnt/utils/checkpoint.py index 07449b84f4..b6e1f962ca 100644 --- a/torchtnt/utils/checkpoint.py +++ b/torchtnt/utils/checkpoint.py @@ -658,6 +658,13 @@ def get_latest_checkpoint_path( gloo process groups are recommended over nccl. """ + return _get_latest_checkpoint_path(dirpath, metadata_fname) + + +def _get_latest_checkpoint_path( + dirpath: str, + metadata_fname: Optional[Union[str, List[str]]] = None, +) -> Optional[str]: candidate_dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname) if not candidate_dirpaths: return None