use gloo in checkpointers (#686)
Summary:
Pull Request resolved: #686

BaseCheckpointer (and consequently all subclassed checkpointers, e.g. TorchSnapshotSaver and DistributedCheckpointSaver) now ensures a GLOO-based process group is used for checkpointing, creating one whenever the provided (or global) process group uses a different backend.
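
For illustration, here is a minimal sketch of the resulting behavior (an assumption based on the diff below, not code from this commit): constructing a checkpointer while the default process group is NCCL now yields a gloo group for checkpointing collectives.

```python
# Sketch: assumes torch.distributed is initialized with the NCCL backend
# and uses TorchSnapshotSaver as a concrete BaseCheckpointer subclass.
import torch.distributed as dist
from torchtnt.framework.callbacks import TorchSnapshotSaver

saver = TorchSnapshotSaver("/tmp/checkpoints")  # process_group defaults to None
# BaseCheckpointer detects the non-gloo default backend and creates a gloo group:
assert dist.get_backend(saver._process_group) == dist.Backend.GLOO
```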

Reviewed By: galrotem

Differential Revision: D52916978

fbshipit-source-id: 717c178aa66bf372959ba83790068ea86b86a99d
JKSenthil authored and facebook-github-bot committed Jan 20, 2024
1 parent b036390 commit fe2620d
Showing 2 changed files with 41 additions and 17 deletions.
29 changes: 14 additions & 15 deletions tests/framework/callbacks/test_base_checkpointer.py
@@ -416,33 +416,32 @@ def test_invalid_args(self) -> None:
     @skip_if_not_gpu
     def test_process_group_plumbing(self) -> None:
         """
-        Creates a new process group and verifies that it's passed through correctly
+        Creates a new process group and verifies GLOO group is created accordingly
         """
         spawn_multi_process(
             2,
             "nccl",
             self._test_process_group_plumbing,
         )
+        spawn_multi_process(
+            2,
+            "gloo",
+            self._test_process_group_plumbing,
+        )
 
     @staticmethod
     def _test_process_group_plumbing() -> None:
-        new_pg = dist.new_group(backend="gloo")
-
-        if get_global_rank() == 0:
-            temp_dir = tempfile.mkdtemp()
-        else:
-            temp_dir = ""
-
         checkpoint_cb = BaseCheckpointSaver(
-            temp_dir,
-            process_group=new_pg,
+            "foo",
+            process_group=None,
         )
         tc = unittest.TestCase()
-        try:
-            tc.assertEqual(checkpoint_cb._process_group, new_pg)
-        finally:
-            if get_global_rank() == 0:
-                shutil.rmtree(temp_dir)  # delete temp directory
+        tc.assertEqual(
+            dist.get_backend(checkpoint_cb._process_group), dist.Backend.GLOO
+        )
+        if dist.get_backend(dist.group.WORLD) == dist.Backend.GLOO:
+            # verify no new process group was created
+            tc.assertEqual(checkpoint_cb._process_group, dist.group.WORLD)
 
     @patch(
         "torchtnt.framework.callbacks.base_checkpointer.get_checkpoint_dirpaths",
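
Since `test_process_group_plumbing` is gated by `@skip_if_not_gpu`, the new gloo assertion can also be exercised on CPU-only machines; below is a hypothetical standalone runner (the test class name `BaseCheckpointerTest` and its import path are assumptions, not shown in this hunk).

```python
# Hypothetical CPU-only runner for the gloo path of the plumbing test.
# Assumes the enclosing test class is named BaseCheckpointerTest.
from torchtnt.utils.distributed import spawn_multi_process

from tests.framework.callbacks.test_base_checkpointer import BaseCheckpointerTest

if __name__ == "__main__":
    # spawn 2 workers on a gloo world; the checkpointer should reuse it
    spawn_multi_process(
        2,
        "gloo",
        BaseCheckpointerTest._test_process_group_plumbing,
    )
```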
29 changes: 27 additions & 2 deletions torchtnt/framework/callbacks/base_checkpointer.py
@@ -8,6 +8,7 @@
 import bisect
 import logging
 import os
+from datetime import timedelta
 from typing import Any, cast, Iterable, List, Literal, Optional, Union
 
 import torch.distributed as dist
@@ -57,7 +58,7 @@ class BaseCheckpointer(Callback, metaclass=abc.ABCMeta):
         save_every_n_eval_epochs: Frequency of evaluation epochs with which to save checkpoints during training. Use this if wanting to save checkpoints after every eval epoch during fit.
         keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted to clean the difference. If best checkpoint config is enabled, this param will manage the top n checkpoints instead.
         best_checkpoint_config: Configuration for saving the best checkpoint based on a monitored metric. The metric is read off the attribute of the unit prior to checkpoint.
-        process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world)
+        process_group: The process group on which the ranks will communicate on. If the process group is not gloo-based, a new gloo-based process group will be created.
 
     Note:
         If torch.distributed is available and default process group is initialized, the constructor will call a collective operation for rank 0 to broadcast the dirpath to all other ranks
@@ -130,12 +131,36 @@ def __init__(
         else:
             self._ckpt_dirpaths = _sort_by_recency(ckpt_dirpaths)
 
-        self._process_group = process_group
+        self._process_group: Optional[dist.ProcessGroup] = None
+        self._setup_gloo_pg(process_group)
         self._pg_wrapper = PGWrapper(process_group)
 
         # sync dirpaths from rank 0
         self._sync_dirpath_to_all_ranks(dirpath)
 
+    def _setup_gloo_pg(self, process_group: Optional[dist.ProcessGroup]) -> None:
+        """
+        Setups gloo process group to be used for any collectives called during
+        checkpointing. If global process group is already gloo, no action is required.
+        Gloo is used over nccl for better compatibility.
+        """
+        if not dist.is_initialized():
+            # there can be no process group
+            return
+
+        if process_group is None:
+            # use global process group
+            process_group = dist.group.WORLD
+
+        # we create a new gloo process group if different backend is being used
+        if dist.get_backend(process_group) != dist.Backend.GLOO:
+            rank_zero_info("Creating new gloo process group for checkpointing.")
+            self._process_group = dist.new_group(
+                timeout=timedelta(seconds=3600), backend=dist.Backend.GLOO
+            )
+        else:
+            self._process_group = process_group
+
     def _sync_dirpath_to_all_ranks(self, dirpath: str) -> None:
         if not (dist.is_available() and dist.is_initialized()):
             self._dirpath: str = dirpath
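
The gloo preference matters because checkpointing collectives, such as broadcasting the dirpath from rank 0, operate on CPU-side Python objects. The sketch below illustrates that pattern; `sync_dirpath` is a hypothetical helper for illustration, not the library's API.

```python
# Illustration of why a gloo group suits checkpointing collectives:
# broadcast_object_list on a gloo group stays on CPU, whereas NCCL
# requires CUDA tensors and devices.
import torch.distributed as dist

def sync_dirpath(dirpath: str, pg: dist.ProcessGroup) -> str:
    # hypothetical helper mirroring the rank-0 dirpath sync described
    # in the BaseCheckpointer docstring above
    container = [dirpath if dist.get_rank(pg) == 0 else ""]
    dist.broadcast_object_list(container, src=0, group=pg)
    return container[0]
```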
Expand Down
