Commit: add time limit interrupter callback (#725)
Summary:

Adds a callback that interrupts the training loop when a specified duration is exceeded
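A minimal usage sketch, assuming torchtnt's standard train entry point (`unit` and `dataloader` stand in for a user-defined train unit and its data; the numbers are illustrative):

    from datetime import timedelta

    from torchtnt.framework.callbacks import TimeLimitInterrupter
    from torchtnt.framework.train import train

    # Stop training once 20 hours have elapsed, checking at each epoch end.
    # Equivalently: TimeLimitInterrupter(duration=timedelta(hours=20))
    time_limit = TimeLimitInterrupter(duration="00:20:00", interval="epoch")

    # unit and dataloader are assumed to be defined elsewhere.
    train(unit, dataloader, max_epochs=100, callbacks=[time_limit])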

Reviewed By: galrotem

Differential Revision: D54600265
JKSenthil authored and facebook-github-bot committed Mar 7, 2024
1 parent ac10dc8 commit 0472f25
Showing 4 changed files with 195 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/framework/callbacks.rst
@@ -28,6 +28,7 @@ We offer several pre-written callbacks which are ready to be used out of the box
PyTorchProfiler
SystemResourcesMonitor
TensorBoardParameterMonitor
TimeLimitInterrupter
IterationTimeLogger
TorchSnapshotSaver
TQDMProgressBar
96 changes: 96 additions & 0 deletions tests/framework/callbacks/test_time_limit_interrupter.py
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest
from datetime import timedelta
from unittest.mock import MagicMock, Mock, patch

from torchtnt.framework._test_utils import DummyTrainUnit, get_dummy_train_state

from torchtnt.framework.callbacks.time_limit_interrupter import TimeLimitInterrupter


class TimeLimitInterrupterTest(unittest.TestCase):
    def test_str_to_timedelta_conversion(self) -> None:
        tli = TimeLimitInterrupter(duration="02:10:20")
        self.assertEqual(
            tli._duration, timedelta(days=2, hours=10, minutes=20).total_seconds()
        )

        with self.assertRaisesRegex(ValueError, "Invalid duration format"):
            tli = TimeLimitInterrupter(duration="2:10:20")

        with self.assertRaisesRegex(ValueError, "Invalid duration format"):
            tli = TimeLimitInterrupter(duration="02:24:20")

        with self.assertRaisesRegex(ValueError, "Invalid duration format"):
            tli = TimeLimitInterrupter(duration="02:23:60")

    @patch("time.monotonic")
    def test_should_stop(self, mock_time_monotonic: MagicMock) -> None:
        for duration in ("00:00:42", timedelta(minutes=42)):
            tli = TimeLimitInterrupter(duration=duration)
            state = get_dummy_train_state()

            # set up the start time
            mock_time_monotonic.return_value = 0
            tli.on_train_start(state, Mock())

            # check that we don't stop before the duration has elapsed
            mock_time_monotonic.return_value = 41 * 60
            tli._should_stop(state)
            self.assertFalse(state._should_stop)

            # check that we stop after the duration has elapsed
            mock_time_monotonic.return_value = 42 * 60
            tli._should_stop(state)
            self.assertTrue(state._should_stop)

    def test_interval(self) -> None:
        tli = TimeLimitInterrupter(duration="00:00:42", interval="epoch")
        tli._should_stop = Mock()

        state = Mock()
        unit = DummyTrainUnit(input_dim=1)

        tli.on_train_step_end(state, unit)
        tli._should_stop.assert_not_called()

        tli.on_train_epoch_end(state, unit)
        tli._should_stop.assert_called_once()

        tli = TimeLimitInterrupter(duration="00:00:42", interval="step")
        tli._should_stop = Mock()

        tli.on_train_epoch_end(state, unit)
        tli._should_stop.assert_not_called()

        tli.on_train_step_end(state, unit)
        tli._should_stop.assert_called_once()

    def test_interval_freq(self) -> None:
        tli = TimeLimitInterrupter(
            duration="00:00:42", interval="epoch", interval_freq=3
        )
        with patch.object(tli, "_should_stop") as should_stop_mock:
            state = Mock()
            unit = DummyTrainUnit(input_dim=1)

            tli.on_train_epoch_end(state, unit)  # epoch 0
            should_stop_mock.assert_called_once()

            unit.train_progress.increment_epoch()  # epoch 1
            tli.on_train_epoch_end(state, unit)
            should_stop_mock.assert_called_once()

            unit.train_progress.increment_epoch()  # epoch 2
            tli.on_train_epoch_end(state, unit)
            should_stop_mock.assert_called_once()

            unit.train_progress.increment_epoch()  # epoch 3
            tli.on_train_epoch_end(state, unit)
            self.assertEqual(should_stop_mock.call_count, 2)
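The epoch-0 call in the last test is worth noting: the callback checks num_epochs_completed % interval_freq == 0, and epoch 0 satisfies the modulo immediately, which is why call_count only reaches 2 at epoch 3. A one-line illustration with interval_freq=3:

    print([epoch for epoch in range(7) if epoch % 3 == 0])  # [0, 3, 6]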
2 changes: 2 additions & 0 deletions torchtnt/framework/callbacks/__init__.py
@@ -15,6 +15,7 @@
from .pytorch_profiler import PyTorchProfiler
from .system_resources_monitor import SystemResourcesMonitor
from .tensorboard_parameter_monitor import TensorBoardParameterMonitor
from .time_limit_interrupter import TimeLimitInterrupter
from .torch_compile import TorchCompile
from .torchsnapshot_saver import TorchSnapshotSaver
from .tqdm_progress_bar import TQDMProgressBar
@@ -32,6 +33,7 @@
"PyTorchProfiler",
"SystemResourcesMonitor",
"TensorBoardParameterMonitor",
"TimeLimitInterrupter",
"TorchCompile",
"TorchSnapshotSaver",
"TQDMProgressBar",
96 changes: 96 additions & 0 deletions torchtnt/framework/callbacks/time_limit_interrupter.py
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import re
import time
from datetime import timedelta
from typing import Literal, Union

from torchtnt.framework.callback import Callback
from torchtnt.framework.state import State
from torchtnt.framework.unit import TTrainUnit
from torchtnt.utils.distributed import get_global_rank, sync_bool
from torchtnt.utils.rank_zero_log import rank_zero_info


class TimeLimitInterrupter(Callback):
    """
    This callback tracks the time spent in training and stops the training loop when it exceeds the specified duration.

    Args:
        duration: The maximum amount of time to spend in training. Can be specified as a string in the form DD:HH:MM (days, hours, minutes) or as a timedelta.
            For example, 20 hours is specified as "00:20:00".
        interval: Either "epoch" or "step". Determines whether the time limit is checked at the end of every epoch or every step.
        interval_freq: How often to check the time limit. For example, if interval is "epoch" and interval_freq is 2, the callback checks every two epochs.

    Note:
        This callback uses the global process group to communicate between ranks.
    """

    def __init__(
        self,
        duration: Union[str, timedelta],
        interval: Literal["epoch", "step"] = "epoch",
        interval_freq: int = 1,
    ) -> None:
        if isinstance(duration, str):
            # checks that the string matches the DD:HH:MM format and is within the valid range:
            # 00 <= DD <= 99
            # 00 <= HH <= 23
            # 00 <= MM <= 59
            pattern = r"^\d{2}:(0[0-9]|1[0-9]|2[0-3]):([0-5][0-9])$"
            if not re.match(pattern, duration):
                raise ValueError(
                    f"Invalid duration format '{duration}'. Expected format is DD:HH:MM"
                )
            days, hours, minutes = map(int, duration.strip().split(":"))
            duration = timedelta(days=days, hours=hours, minutes=minutes)

        self._duration: float = duration.total_seconds()
        self._interval = interval
        self._interval_freq = interval_freq

        self._rank: int = get_global_rank()
        self._start_time: float = 0

    def on_train_start(self, state: State, unit: TTrainUnit) -> None:
        if self._rank == 0:
            self._start_time = time.monotonic()

    def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
        if self._interval == "step":
            if unit.train_progress.num_steps_completed % self._interval_freq == 0:
                self._should_stop(state)

    def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
        if self._interval == "epoch":
            if unit.train_progress.num_epochs_completed % self._interval_freq == 0:
                self._should_stop(state)

    def _should_stop(self, state: State) -> None:
        """
        All ranks sync with rank 0 to determine whether the time limit has been exceeded.
        If so, signals the training loop to stop.
        """

        if self._rank == 0:
            time_elapsed = time.monotonic() - self._start_time
            should_stop = time_elapsed >= self._duration
        else:
            should_stop = False

        # broadcast rank 0's decision to all ranks
        should_stop = sync_bool(should_stop, coherence_mode="rank_zero")
        if should_stop:
            rank_zero_info(
                f"Training duration of {self._duration} seconds has been exceeded. Time elapsed is {time.monotonic() - self._start_time} seconds. Stopping training."
            )
            state.stop()
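In isolation, the duration parsing above behaves as in this sketch (same regex and error message as the source; parse_duration is an illustrative name, not part of the library):

    import re
    from datetime import timedelta

    # DD:HH:MM with 00 <= DD <= 99, 00 <= HH <= 23, 00 <= MM <= 59
    PATTERN = r"^\d{2}:(0[0-9]|1[0-9]|2[0-3]):([0-5][0-9])$"

    def parse_duration(value: str) -> timedelta:
        if not re.match(PATTERN, value):
            raise ValueError(f"Invalid duration format '{value}'. Expected format is DD:HH:MM")
        days, hours, minutes = map(int, value.split(":"))
        return timedelta(days=days, hours=hours, minutes=minutes)

    print(parse_duration("02:10:20").total_seconds())  # 210000.0 (2 days, 10 hours, 20 minutes)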
