Skip to content

Commit

Permalink
Set duration parameter as optional in TimeLimitInterrupter callback (#744)
Browse files Browse the repository at this point in the history

Summary: Pull Request resolved: #744

Reviewed By: JKSenthil

Differential Revision: D55028165

fbshipit-source-id: 0b9e6ad94e295d1853d5675ca50fcfcbda1c67cd
  • Loading branch information
diego-urgell authored and facebook-github-bot committed Mar 18, 2024
1 parent 5cef47a commit c69f3d5
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 11 deletions.
59 changes: 56 additions & 3 deletions tests/framework/callbacks/test_time_limit_interrupter.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,8 @@ def test_should_stop_with_timestamp_limit(
tli._should_stop(state)
self.assertTrue(state._should_stop)

# Test timestamp limit reached with a different timezone -> Should stop
# Test timestamp limit reached with a different timezone, no duration -> Should stop
tli = TimeLimitInterrupter(
duration="00:00:25",
timestamp=datetime.strptime(
"2024-03-13 10:00:00 +0000", "%Y-%m-%d %H:%M:%S %z"
),
Expand All @@ -101,7 +100,6 @@ def test_should_stop_with_timestamp_limit(
mock_datetime.now.return_value = datetime.strptime(
"2024-03-13 9:00:00 -0100", "%Y-%m-%d %H:%M:%S %z"
)
mock_time_monotonic.return_value = 5 * 60
tli._should_stop(state)
self.assertTrue(state._should_stop)

Expand All @@ -111,6 +109,61 @@ def test_should_stop_with_timestamp_limit(
):
tli = TimeLimitInterrupter(duration="00:00:25", timestamp=datetime.now())

@patch(f"{time_limit_interrupter.__name__}.datetime", wraps=datetime)
@patch("time.monotonic")
def test_should_stop_optional_params(
    self,
    mock_time_monotonic: MagicMock,
    mock_datetime: MagicMock,
) -> None:
    # Exercises the constructor contract where `duration` and `timestamp` are
    # each optional on their own, but at least one of them must be supplied.
    missing_limit_message = (
        "Invalid parameters. Expected at least one of duration or timestamp to be specified."
    )

    # --- Scenario 1: only a duration is given ---
    duration_only = TimeLimitInterrupter(duration="00:00:42")
    self.assertEqual(duration_only._duration, 42 * 60)
    self.assertIsNone(duration_only._timestamp)

    duration_state = get_dummy_train_state()
    # Training starts at monotonic time 0 ...
    mock_time_monotonic.return_value = 0
    duration_only.on_train_start(duration_state, Mock())

    # ... and the clock then reads exactly the configured limit, so the
    # callback is expected to flag the state to stop.
    mock_time_monotonic.return_value = 42 * 60
    duration_only._should_stop(duration_state)
    self.assertTrue(duration_state._should_stop)

    # --- Scenario 2: only a timestamp is given ---
    deadline = datetime(2024, 3, 12, 15, 25, 0).astimezone()
    timestamp_only = TimeLimitInterrupter(timestamp=deadline)
    self.assertEqual(timestamp_only._timestamp, deadline)
    self.assertIsNone(timestamp_only._duration)

    timestamp_state = get_dummy_train_state()
    mock_time_monotonic.return_value = 0
    timestamp_only.on_train_start(timestamp_state, Mock())

    # The patched `now()` returns the deadline itself, i.e. the limit is reached.
    mock_datetime.now.return_value = deadline
    timestamp_only._should_stop(timestamp_state)
    self.assertTrue(timestamp_state._should_stop)

    # --- Scenario 3: duration and timestamp given together ---
    mock_time_monotonic.return_value = 0
    current_time = datetime.now().astimezone()
    both_limits = TimeLimitInterrupter(timestamp=current_time, duration="00:00:42")
    self.assertEqual(both_limits._timestamp, current_time)
    self.assertEqual(both_limits._duration, 42 * 60)

    # --- Scenario 4: neither limit given ---
    # Covers both passing no arguments at all and passing an empty duration
    # string; each must be rejected with the same error.
    for invalid_kwargs in ({}, {"duration": ""}):
        with self.assertRaisesRegex(ValueError, missing_limit_message):
            TimeLimitInterrupter(**invalid_kwargs)


def test_interval(self) -> None:
tli = TimeLimitInterrupter(duration="00:00:42", interval="epoch")
tli._should_stop = Mock()
Expand Down
28 changes: 20 additions & 8 deletions torchtnt/framework/callbacks/time_limit_interrupter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,22 @@

class TimeLimitInterrupter(Callback):
"""
This callback tracks the time spent in training and stops the training loop when it exceeds the specified duration.
This callback tracks the time spent in training and stops the training loop when a time limit is reached. It is possible to define a maximum duration for the training job,
and/or an absolute timestamp limit. At least one of them should be provided. If both are provided, the callback will stop the training loop when the first condition is met.
Args:
duration: The maximum amount of time to spend in training. Can be specified as a string in the form of DD:HH:MM (days, hours, minutes) or as a timedelta.
duration: Optional, the maximum amount of time to spend in training. Can be specified as a string in the form of DD:HH:MM (days, hours, minutes) or as a timedelta.
For example, to specify 20 hours is "00:20:00".
interval: Can be either "epoch" or "step". Determines whether to check for time limit exceeding on every epoch or step.
interval_freq: How often to check for time limit exceeding. For example, if interval is "epoch" and interval_freq is 2, then the callback will check every two epochs.
timestamp: Optional datetime object indicating the timestamp at which the training should end. The training will be stopped even if the maximum
job duration has not been reached yet. Object must be timezone aware.
Raises:
ValueError: If the duration is not specified as a string in the form of DD:HH:MM or as a timedelta.
Or if the timestamp datetime object is not timezone aware.
ValueError:
- If the duration is not specified as a string in the form of DD:HH:MM or as a timedelta.
- If the timestamp datetime object is not timezone aware.
- If both duration and timestamp are None (i.e. at least one must be specified).
Note:
This callback uses the global process group to communicate between ranks.
Expand All @@ -41,11 +44,16 @@ class TimeLimitInterrupter(Callback):

def __init__(
self,
duration: Union[str, timedelta],
duration: Optional[Union[str, timedelta]] = None,
interval: Literal["epoch", "step"] = "epoch",
interval_freq: int = 1,
timestamp: Optional[datetime] = None,
) -> None:
if not (duration or timestamp):
raise ValueError(
"Invalid parameters. Expected at least one of duration or timestamp to be specified."
)

if isinstance(duration, str):
# checks if string matches DD:HH:MM format and is within valid range
# 00 <= DD <= 99
Expand All @@ -64,7 +72,10 @@ def __init__(
minutes=duration_format[2],
)

self._duration: float = duration.total_seconds()
self._duration: Optional[float] = None
if duration:
self._duration = duration.total_seconds()

self._interval = interval
self._interval_freq = interval_freq

Expand Down Expand Up @@ -104,8 +115,9 @@ def _should_stop(self, state: State) -> None:
if timestamp := self._timestamp:
past_timestamp_limit = datetime.now().astimezone() >= timestamp

time_elapsed = time.monotonic() - self._start_time
past_duration_limit = time_elapsed >= self._duration
if duration := self._duration:
time_elapsed = time.monotonic() - self._start_time
past_duration_limit = time_elapsed >= duration

local_should_stop = past_timestamp_limit or past_duration_limit
global_should_stop = sync_bool(local_should_stop, coherence_mode="rank_zero")
Expand Down

0 comments on commit c69f3d5

Please sign in to comment.