diff --git a/tests/framework/callbacks/test_time_limit_interrupter.py b/tests/framework/callbacks/test_time_limit_interrupter.py index 2198d48c82..e8ab05a80c 100644 --- a/tests/framework/callbacks/test_time_limit_interrupter.py +++ b/tests/framework/callbacks/test_time_limit_interrupter.py @@ -8,11 +8,12 @@ import unittest -from datetime import timedelta +from datetime import datetime, timedelta from unittest.mock import MagicMock, Mock, patch -from torchtnt.framework._test_utils import DummyTrainUnit, get_dummy_train_state +import torchtnt.framework.callbacks.time_limit_interrupter as time_limit_interrupter +from torchtnt.framework._test_utils import DummyTrainUnit, get_dummy_train_state from torchtnt.framework.callbacks.time_limit_interrupter import TimeLimitInterrupter @@ -52,6 +53,64 @@ def test_should_stop(self, mock_time_monotonic: MagicMock) -> None: tli._should_stop(state) self.assertTrue(state._should_stop) + @patch(f"{time_limit_interrupter.__name__}.datetime", wraps=datetime) + @patch("time.monotonic") + def test_should_stop_with_timestamp_limit( + self, + mock_time_monotonic: MagicMock, + mock_datetime: MagicMock, + ) -> None: + tli = TimeLimitInterrupter( + duration="00:00:25", timestamp=datetime(2024, 3, 12, 15, 25, 0).astimezone() + ) + state = get_dummy_train_state() + + mock_time_monotonic.return_value = 0 + tli.on_train_start(state, Mock()) + + # Max duration not reached, timestamp limit not reached -> Should not stop + mock_datetime.now.return_value = datetime(2024, 3, 12, 15, 0, 0) + mock_time_monotonic.return_value = 5 * 60 + tli._should_stop(state) + self.assertFalse(state._should_stop) + + # Max duration reached, timestamp limit not reached -> Should stop + mock_datetime.now.return_value = datetime(2024, 3, 12, 15, 0, 0) + mock_time_monotonic.return_value = 50 * 60 + state._should_stop = False + tli._should_stop(state) + self.assertTrue(state._should_stop) + + # Max duration not reached, timestamp limit reached -> Should stop + 
mock_datetime.now.return_value = datetime(2024, 3, 12, 15, 25, 0) + mock_time_monotonic.return_value = 5 * 60 + state._should_stop = False + tli._should_stop(state) + self.assertTrue(state._should_stop) + + # Test timestamp limit reached with a different timezone -> Should stop + tli = TimeLimitInterrupter( + duration="00:00:25", + timestamp=datetime.strptime( + "2024-03-13 10:00:00 +0000", "%Y-%m-%d %H:%M:%S %z" + ), + ) + state = get_dummy_train_state() + mock_time_monotonic.return_value = 0 + tli.on_train_start(state, Mock()) + mock_datetime.now.return_value = datetime.strptime( + "2024-03-13 9:00:00 -0100", "%Y-%m-%d %H:%M:%S %z" + ) + mock_time_monotonic.return_value = 5 * 60 + tli._should_stop(state) + self.assertTrue(state._should_stop) + + # Test not timezone aware datetime -> Expected error + with self.assertRaisesRegex( + ValueError, "Invalid timestamp. Expected a timezone aware datetime object." + ): + tli = TimeLimitInterrupter(duration="00:00:25", timestamp=datetime.now()) + def test_interval(self) -> None: tli = TimeLimitInterrupter(duration="00:00:42", interval="epoch") tli._should_stop = Mock() diff --git a/torchtnt/framework/callbacks/time_limit_interrupter.py b/torchtnt/framework/callbacks/time_limit_interrupter.py index 7409f7ab4f..786c6b4d0c 100644 --- a/torchtnt/framework/callbacks/time_limit_interrupter.py +++ b/torchtnt/framework/callbacks/time_limit_interrupter.py @@ -8,8 +8,8 @@ import re import time -from datetime import timedelta -from typing import Literal, Union +from datetime import datetime, timedelta +from typing import Literal, Optional, Union from torchtnt.framework.callback import Callback from torchtnt.framework.state import State @@ -27,6 +27,12 @@ class TimeLimitInterrupter(Callback): For example, to specify 20 hours is "00:20:00". interval: Can be either "epoch" or "step". Determines whether to check for time limit exceeding on every epoch or step. interval_freq: How often to check for time limit exceeding. 
For example, if interval is "epoch" and interval_freq is 2, then the callback will check every two epochs. + timestamp: Optional datetime object indicating the timestamp at which the training should end. The training will be stopped even if the maximum + job duration has not been reached yet. Object must be timezone aware. + + Raises: + ValueError: If the duration is not specified as a string in the form of DD:HH:MM or as a timedelta. + Or if the timestamp datetime object is not timezone aware. Note: This callback uses the global process group to communicate between ranks. @@ -38,6 +44,7 @@ def __init__( duration: Union[str, timedelta], interval: Literal["epoch", "step"] = "epoch", interval_freq: int = 1, + timestamp: Optional[datetime] = None, ) -> None: if isinstance(duration, str): # checks if string matches DD:HH:MM format and is within valid range @@ -64,6 +71,12 @@ def __init__( self._rank: int = get_global_rank() self._start_time: float = 0 + self._timestamp = timestamp + if timestamp and not timestamp.tzinfo: + raise ValueError( + "Invalid timestamp. Expected a timezone aware datetime object." + ) + def on_train_start(self, state: State, unit: TTrainUnit) -> None: if self._rank == 0: self._start_time = time.monotonic() @@ -80,19 +93,32 @@ def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None: def _should_stop(self, state: State) -> None: """ - All ranks sync with rank 0 determine if time limit has exceeded. + Check the max duration and the max timestamp to determine if training should stop. + All ranks sync with rank 0 to determine if any of the stop conditions are met. If so, signals the training loop to stop. 
""" + past_timestamp_limit = False + past_duration_limit = False if self._rank == 0: + if timestamp := self._timestamp: + past_timestamp_limit = datetime.now().astimezone() >= timestamp + time_elapsed = time.monotonic() - self._start_time - should_stop = time_elapsed >= self._duration - else: - should_stop = False - - should_stop = sync_bool(should_stop, coherence_mode="rank_zero") - if should_stop: - rank_zero_info( - f"Training duration of {self._duration} seconds has exceeded. Time elapsed is {time.monotonic() - self._start_time} seconds. Stopping training." - ) + past_duration_limit = time_elapsed >= self._duration + + local_should_stop = past_timestamp_limit or past_duration_limit + global_should_stop = sync_bool(local_should_stop, coherence_mode="rank_zero") + + if global_should_stop: + reason = "" + if past_timestamp_limit: + reason = f"Training timestamp limit {self._timestamp} has been reached." + elif past_duration_limit: + reason = ( + f"Training duration of {self._duration} seconds has been exceeded. " + f"Time elapsed is {time.monotonic() - self._start_time} seconds." + ) + + rank_zero_info(f"{reason} Stopping training.") state.stop()