Skip to content

Commit

Permalink
Recover logging dead code in DCP _async_save (#889)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #889

Reviewed By: JKSenthil

Differential Revision: D61739504

fbshipit-source-id: 1f71578ffd7c29dcbb1b74d8e051e13d5d458fba
  • Loading branch information
diego-urgell authored and facebook-github-bot committed Aug 28, 2024
1 parent a6d3d91 commit 3345ed9
Showing 1 changed file with 43 additions and 23 deletions.
66 changes: 43 additions & 23 deletions torchtnt/framework/callbacks/dcp_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import torch
import torch.distributed as dist
from pyre_extensions import none_throws
from torch.distributed import checkpoint as dcp
from torch.distributed.checkpoint.default_planner import (
DefaultLoadPlanner,
Expand Down Expand Up @@ -160,7 +161,7 @@ def _checkpoint_impl(
checkpoint_id, app_state, planner, storage_writer
)
if curr_snapshot_wait:
self._wait()
self._wait(log_warning=False)
else:
with get_timing_context(state, f"{self.__class__.__name__}.save"):
checkpoint_success = self._save(
Expand All @@ -169,9 +170,42 @@ def _checkpoint_impl(

return checkpoint_success

def _wait(self) -> None:
if self._prev_snapshot is not None:
self._prev_snapshot.result()
def _wait(self, log_warning: bool = True) -> None:
"""
If the previous async checkpoint is still running, wait for it to finish before continuing. Otherwise,
distributed collectives that use the checkpointing process group will result in a stuck job. This also
computes and logs the time spent waiting on the previous checkpoint to finish, and a toggable warning
for the user to modify checkpointing frequency.
If the previous checkpoing has already finished, this is a no-op.
Args:
log_warning: Toggle for logging a warning to the user to modify checkpointing frequency. Sometimes
this is not up to the user (e.g. on_exception, on_train_end).
"""
if self._prev_snapshot is None:
return

if self._prev_snapshot.done():
none_throws(self._prev_snapshot).result()
return

if log_warning:
rank_zero_warn(
(
"Waiting on previous checkpoint to finish... Consider modifying checkpointing "
f"frequency if this is an issue. Current value (current {self._save_every_n_train_steps})"
),
logger=logger,
)

t0 = time.monotonic()
none_throws(self._prev_snapshot).result()

rank_zero_warn(
f"Waiting on previous checkpoint for {time.monotonic()-t0:.3f} seconds",
logger=logger,
)

def _async_save(
self,
Expand All @@ -187,24 +221,8 @@ def _async_save(
if storage_writer is None:
storage_writer = Writer(checkpoint_id, **self.default_writer_options)

if self._prev_snapshot is not None:
if not self._prev_snapshot.done():
# TODO this is unreachable at this point, since we are waiting on other functions called before _checkpoint_impl.
rank_zero_warn(
(
"Waiting on previous checkpoint to finish... Consider modifying checkpointing "
f"frequency if this is an issue. Current value (current {self._save_every_n_train_steps})"
),
logger=logger,
)
t0 = time.monotonic()
self._wait()
rank_zero_warn(
f"Waiting on previous checkpoint for {time.monotonic()-t0:.3f} seconds",
logger=logger,
)
else:
self._wait()
# Redundant check for safety
self._wait(log_warning=True)

self._prev_snapshot = dcp.async_save(
state_dict={"app_state": MultiStateful(app_state)},
Expand Down Expand Up @@ -257,7 +275,8 @@ def on_exception(
unit: Union[TTrainUnit, TEvalUnit, TPredictUnit],
exc: BaseException,
) -> None:
self._wait()
rank_zero_info("Ensuring previous async checkpoint finished before exiting.")
self._wait(log_warning=False)

@staticmethod
def restore(
Expand Down Expand Up @@ -404,6 +423,7 @@ def _generate_checkpoint_and_upkeep(
# operations in the base class use the process group. So wait here instead.
self._wait()

# Note that every async checkpoint will be completed at this point.
return super()._generate_checkpoint_and_upkeep(state, unit, hook)

@property
Expand Down

0 comments on commit 3345ed9

Please sign in to comment.