add time limit interrupter callback (#725)
Summary: Adds callback to interrupt training loop when specified duration is exceeded

Reviewed By: galrotem

Differential Revision: D54600265
1 parent ac10dc8 · commit 0472f25
Showing 4 changed files with 195 additions and 0 deletions.
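The new callback plugs into torchtnt's standard callback mechanism, so wiring it into a training run is a one-liner. Below is a minimal sketch, assuming the torchtnt.framework.train.train entry point and its callbacks argument; my_unit and my_dataloader are placeholders for a user-defined train unit and dataloader and are not part of this commit.

# Minimal sketch (not part of this commit): attach the interrupter to a train loop.
from torchtnt.framework.callbacks.time_limit_interrupter import TimeLimitInterrupter
from torchtnt.framework.train import train  # assumed entry point

# Stop training once 4 hours have elapsed, checking at the end of every epoch.
time_limit = TimeLimitInterrupter(duration="00:04:00", interval="epoch", interval_freq=1)

# my_unit and my_dataloader are placeholders for a user-defined TrainUnit and dataloader.
train(my_unit, my_dataloader, max_epochs=100, callbacks=[time_limit])

Only rank 0 measures the elapsed time; the result is broadcast to all ranks via sync_bool, so every rank calls state.stop() together instead of drifting apart.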
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest
from datetime import timedelta
from unittest.mock import MagicMock, Mock, patch

from torchtnt.framework._test_utils import DummyTrainUnit, get_dummy_train_state

from torchtnt.framework.callbacks.time_limit_interrupter import TimeLimitInterrupter


class TimeLimitInterrupterTest(unittest.TestCase):
    def test_str_to_timedelta_conversion(self) -> None:
        tli = TimeLimitInterrupter(duration="02:10:20")
        self.assertEqual(
            tli._duration, timedelta(days=2, hours=10, minutes=20).total_seconds()
        )

        with self.assertRaisesRegex(ValueError, "Invalid duration format"):
            tli = TimeLimitInterrupter(duration="2:10:20")

        with self.assertRaisesRegex(ValueError, "Invalid duration format"):
            tli = TimeLimitInterrupter(duration="02:24:20")

        with self.assertRaisesRegex(ValueError, "Invalid duration format"):
            tli = TimeLimitInterrupter(duration="02:23:60")

    @patch("time.monotonic")
    def test_should_stop(self, mock_time_monotonic: MagicMock) -> None:
        for duration in ("00:00:42", timedelta(minutes=42)):
            tli = TimeLimitInterrupter(duration=duration)
            state = get_dummy_train_state()

            # setup start time
            mock_time_monotonic.return_value = 0
            tli.on_train_start(state, Mock())

            # check that we don't stop before duration
            mock_time_monotonic.return_value = 41 * 60
            tli._should_stop(state)
            self.assertFalse(state._should_stop)

            # check that we stop after duration
            mock_time_monotonic.return_value = 42 * 60
            tli._should_stop(state)
            self.assertTrue(state._should_stop)

    def test_interval(self) -> None:
        tli = TimeLimitInterrupter(duration="00:00:42", interval="epoch")
        tli._should_stop = Mock()

        state = Mock()
        unit = DummyTrainUnit(input_dim=1)

        tli.on_train_step_end(state, unit)
        tli._should_stop.assert_not_called()

        tli.on_train_epoch_end(state, unit)
        tli._should_stop.assert_called_once()

        tli = TimeLimitInterrupter(duration="00:00:42", interval="step")
        tli._should_stop = Mock()

        tli.on_train_epoch_end(state, unit)
        tli._should_stop.assert_not_called()

        tli.on_train_step_end(state, unit)
        tli._should_stop.assert_called_once()

    def test_interval_freq(self) -> None:
        tli = TimeLimitInterrupter(
            duration="00:00:42", interval="epoch", interval_freq=3
        )
        with patch.object(tli, "_should_stop") as should_stop_mock:
            state = Mock()
            unit = DummyTrainUnit(input_dim=1)

            tli.on_train_epoch_end(state, unit)  # epoch 0
            should_stop_mock.assert_called_once()

            unit.train_progress.increment_epoch()  # epoch 1
            tli.on_train_epoch_end(state, unit)
            should_stop_mock.assert_called_once()

            unit.train_progress.increment_epoch()  # epoch 2
            tli.on_train_epoch_end(state, unit)
            should_stop_mock.assert_called_once()

            unit.train_progress.increment_epoch()  # epoch 3
            tli.on_train_epoch_end(state, unit)
            self.assertEqual(should_stop_mock.call_count, 2)
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import re
import time
from datetime import timedelta
from typing import Literal, Union

from torchtnt.framework.callback import Callback
from torchtnt.framework.state import State
from torchtnt.framework.unit import TTrainUnit
from torchtnt.utils.distributed import get_global_rank, sync_bool
from torchtnt.utils.rank_zero_log import rank_zero_info


class TimeLimitInterrupter(Callback):
    """
    This callback tracks the time spent in training and stops the training loop when it exceeds the specified duration.

    Args:
        duration: The maximum amount of time to spend in training. Can be specified as a string
            in the form of DD:HH:MM (days, hours, minutes) or as a timedelta.
            For example, 20 hours would be specified as "00:20:00".
        interval: Can be either "epoch" or "step". Determines whether to check the time limit
            at the end of every epoch or every step.
        interval_freq: How often to check the time limit. For example, if interval is "epoch"
            and interval_freq is 2, the callback checks every two epochs.

    Note:
        This callback uses the global process group to communicate between ranks.
    """

    def __init__(
        self,
        duration: Union[str, timedelta],
        interval: Literal["epoch", "step"] = "epoch",
        interval_freq: int = 1,
    ) -> None:
        if isinstance(duration, str):
            # checks if string matches DD:HH:MM format and is within valid range
            # 00 <= DD <= 99
            # 00 <= HH <= 23
            # 00 <= MM <= 59
            pattern = r"^\d{2}:(0[0-9]|1[0-9]|2[0-3]):([0-5][0-9])$"
            if not re.match(pattern, duration):
                raise ValueError(
                    f"Invalid duration format '{duration}'. Expected format is DD:HH:MM"
                )
            duration_format = duration.strip().split(":")
            duration_format = list(map(int, duration_format))
            duration = timedelta(
                days=duration_format[0],
                hours=duration_format[1],
                minutes=duration_format[2],
            )

        self._duration: float = duration.total_seconds()
        self._interval = interval
        self._interval_freq = interval_freq

        self._rank: int = get_global_rank()
        self._start_time: float = 0

    def on_train_start(self, state: State, unit: TTrainUnit) -> None:
        if self._rank == 0:
            self._start_time = time.monotonic()

    def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
        if self._interval == "step":
            if unit.train_progress.num_steps_completed % self._interval_freq == 0:
                self._should_stop(state)

    def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
        if self._interval == "epoch":
            if unit.train_progress.num_epochs_completed % self._interval_freq == 0:
                self._should_stop(state)

    def _should_stop(self, state: State) -> None:
        """
        All ranks sync with rank 0 to determine whether the time limit has been exceeded.
        If so, signals the training loop to stop.
        """

        if self._rank == 0:
            time_elapsed = time.monotonic() - self._start_time
            should_stop = time_elapsed >= self._duration
        else:
            should_stop = False

        should_stop = sync_bool(should_stop, coherence_mode="rank_zero")
        if should_stop:
            rank_zero_info(
                f"Training duration of {self._duration} seconds has been exceeded. Time elapsed is {time.monotonic() - self._start_time} seconds. Stopping training."
            )
            state.stop()
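For reference, the DD:HH:MM validation above accepts two-digit fields only, with days 00-99, hours 00-23, and minutes 00-59. The snippet below is an illustration (not part of the commit) that runs a few strings through the same regex and conversion used by the constructor.

# Illustration only: exercise the DD:HH:MM pattern used by TimeLimitInterrupter.
import re
from datetime import timedelta

pattern = r"^\d{2}:(0[0-9]|1[0-9]|2[0-3]):([0-5][0-9])$"

for value in ("00:20:00", "02:10:20", "2:10:20", "02:24:20", "02:23:60"):
    if re.match(pattern, value):
        days, hours, minutes = map(int, value.split(":"))
        total = timedelta(days=days, hours=hours, minutes=minutes).total_seconds()
        print(f"{value} -> {total} seconds")
    else:
        print(f"{value} -> invalid, TimeLimitInterrupter would raise ValueError")

# Expected output:
# 00:20:00 -> 72000.0 seconds
# 02:10:20 -> 210000.0 seconds
# 2:10:20 -> invalid, TimeLimitInterrupter would raise ValueError
# 02:24:20 -> invalid, TimeLimitInterrupter would raise ValueError
# 02:23:60 -> invalid, TimeLimitInterrupter would raise ValueError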