Add timestamp limit to TimeLimitInterrupter
Differential Revision: D54860033
diego-urgell authored and facebook-github-bot committed Mar 13, 2024
1 parent 67159a4 commit cd0203d
Showing 2 changed files with 87 additions and 14 deletions.
57 changes: 55 additions & 2 deletions tests/framework/callbacks/test_time_limit_interrupter.py
@@ -8,11 +8,12 @@
 
 
 import unittest
-from datetime import timedelta
+from datetime import datetime, timedelta
 from unittest.mock import MagicMock, Mock, patch
 
-from torchtnt.framework._test_utils import DummyTrainUnit, get_dummy_train_state
+import torchtnt.framework.callbacks.time_limit_interrupter as time_limit_interrupter
+
+from torchtnt.framework._test_utils import DummyTrainUnit, get_dummy_train_state
 from torchtnt.framework.callbacks.time_limit_interrupter import TimeLimitInterrupter
 

@@ -52,6 +53,58 @@ def test_should_stop(self, mock_time_monotonic: MagicMock) -> None:
         tli._should_stop(state)
         self.assertTrue(state._should_stop)
 
+    @patch(f"{time_limit_interrupter.__name__}.datetime", wraps=datetime)
+    @patch("time.monotonic")
+    def test_should_stop_with_timestamp_limit(
+        self,
+        mock_time_monotonic: MagicMock,
+        mock_datetime: MagicMock,
+    ) -> None:
+        tli = TimeLimitInterrupter(
+            duration="00:00:25", timestamp=datetime(2024, 3, 12, 15, 25, 0)
+        )
+        state = get_dummy_train_state()
+
+        mock_time_monotonic.return_value = 0
+        tli.on_train_start(state, Mock())
+
+        # Max duration not reached, timestamp limit not reached -> Should not stop
+        mock_datetime.now.return_value = datetime(2024, 3, 12, 15, 0, 0)
+        mock_time_monotonic.return_value = 5 * 60
+        tli._should_stop(state)
+        self.assertFalse(state._should_stop)
+
+        # Max duration reached, timestamp limit not reached -> Should stop
+        mock_datetime.now.return_value = datetime(2024, 3, 12, 15, 0, 0)
+        mock_time_monotonic.return_value = 50 * 60
+        state._should_stop = False
+        tli._should_stop(state)
+        self.assertTrue(state._should_stop)
+
+        # Max duration not reached, timestamp limit reached -> Should stop
+        mock_datetime.now.return_value = datetime(2024, 3, 12, 15, 25, 0)
+        mock_time_monotonic.return_value = 5 * 60
+        state._should_stop = False
+        tli._should_stop(state)
+        self.assertTrue(state._should_stop)
+
+        # Test timestamp limit reached with a timezone aware input -> Should stop
+        tli = TimeLimitInterrupter(
+            duration="00:00:25",
+            timestamp=datetime.strptime(
+                "2024-03-13 9:00:00 -0100", "%Y-%m-%d %H:%M:%S %z"
+            ),
+        )
+        state = get_dummy_train_state()
+        mock_time_monotonic.return_value = 0
+        tli.on_train_start(state, Mock())
+        mock_datetime.now.return_value = datetime.strptime(
+            "2024-03-13 10:00:00 +0000", "%Y-%m-%d %H:%M:%S %z"
+        )
+        mock_time_monotonic.return_value = 5 * 60
+        tli._should_stop(state)
+        self.assertTrue(state._should_stop)
+
     def test_interval(self) -> None:
         tli = TimeLimitInterrupter(duration="00:00:42", interval="epoch")
         tli._should_stop = Mock()
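A note on the mocking pattern in the new test: patch(f"{time_limit_interrupter.__name__}.datetime", wraps=datetime) swaps the datetime name inside the callback module for a MagicMock that delegates to the real class, so datetime.now() can be pinned to a fixed value while constructors like datetime(2024, 3, 12, ...) still return real datetime objects. A minimal standalone sketch of the same technique (not part of this commit):

# Standalone sketch of the wraps= pattern used in the test above.
from datetime import datetime
from unittest.mock import MagicMock

mock_dt = MagicMock(wraps=datetime)
mock_dt.now.return_value = datetime(2024, 3, 12, 15, 0, 0)

assert mock_dt.now() == datetime(2024, 3, 12, 15, 0, 0)  # now() is pinned
assert mock_dt(2024, 1, 1) == datetime(2024, 1, 1)  # construction still delegates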
44 changes: 32 additions & 12 deletions torchtnt/framework/callbacks/time_limit_interrupter.py
@@ -8,8 +8,8 @@
 
 import re
 import time
-from datetime import timedelta
-from typing import Literal, Union
+from datetime import datetime, timedelta
+from typing import Literal, Optional, Union
 
 from torchtnt.framework.callback import Callback
 from torchtnt.framework.state import State
@@ -27,6 +27,8 @@ class TimeLimitInterrupter(Callback):
             For example, to specify 20 hours is "00:20:00".
         interval: Can be either "epoch" or "step". Determines whether to check for time limit exceeding on every epoch or step.
         interval_freq: How often to check for time limit exceeding. For example, if interval is "epoch" and interval_freq is 2, then the callback will check every two epochs.
+        timestamp: Optional datetime object indicating the timestamp at which the training should end. The training will be stopped at the specified timestamp even if the maximum
+            job duration has not been reached yet. Local timezone is assumed if the provided object is not timezone aware.
 
     Note:
         This callback uses the global process group to communicate between ranks.
@@ -38,6 +40,7 @@ def __init__(
         duration: Union[str, timedelta],
         interval: Literal["epoch", "step"] = "epoch",
         interval_freq: int = 1,
+        timestamp: Optional[datetime] = None,
     ) -> None:
         if isinstance(duration, str):
             # checks if string matches DD:HH:MM format and is within valid range
@@ -64,6 +67,10 @@ def __init__(
         self._rank: int = get_global_rank()
         self._start_time: float = 0
 
+        self._timestamp: Optional[datetime] = None
+        if _timestamp := timestamp:
+            self._timestamp = _timestamp.astimezone()
+
     def on_train_start(self, state: State, unit: TTrainUnit) -> None:
         if self._rank == 0:
             self._start_time = time.monotonic()
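The astimezone() call above normalizes the user-supplied timestamp using stdlib behavior: with no arguments, a naive datetime is assumed to be in local time and gets the local timezone attached, while an aware datetime is converted into local time. A standalone illustration (not part of this commit):

# Standalone illustration of datetime.astimezone() with no arguments.
from datetime import datetime, timezone

naive = datetime(2024, 3, 12, 15, 25)  # no tzinfo: assumed to be local time
aware = datetime(2024, 3, 12, 15, 25, tzinfo=timezone.utc)

print(naive.astimezone())  # same wall-clock time, local tzinfo attached
print(aware.astimezone())  # converted from UTC into the local timezone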
@@ -80,19 +87,32 @@ def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
 
     def _should_stop(self, state: State) -> None:
         """
-        All ranks sync with rank 0 determine if time limit has exceeded.
+        Check the max duration and the max timestamp to determine if training should stop.
+        All ranks sync with rank 0 to determine if any of the stop conditions are met.
         If so, indicates the training loop to stop.
         """
+        past_timestamp_limit = False
+        past_duration_limit = False
+
         if self._rank == 0:
+            if timestamp := self._timestamp:
+                past_timestamp_limit = datetime.now().astimezone() >= timestamp
+
             time_elapsed = time.monotonic() - self._start_time
-            should_stop = time_elapsed >= self._duration
-        else:
-            should_stop = False
-
-        should_stop = sync_bool(should_stop, coherence_mode="rank_zero")
-        if should_stop:
-            rank_zero_info(
-                f"Training duration of {self._duration} seconds has exceeded. Time elapsed is {time.monotonic() - self._start_time} seconds. Stopping training."
-            )
+            past_duration_limit = time_elapsed >= self._duration
+
+        local_should_stop = past_timestamp_limit or past_duration_limit
+        global_should_stop = sync_bool(local_should_stop, coherence_mode="rank_zero")
+
+        if global_should_stop:
+            reason = ""
+            if past_timestamp_limit:
+                reason = f"Training timestamp limit {self._timestamp} has been reached."
+            elif past_duration_limit:
+                reason = (
+                    f"Training duration of {self._duration} seconds has exceeded. "
+                    f"Time elapsed is {time.monotonic() - self._start_time} seconds."
+                )
+
+            rank_zero_info(f"{reason} Stopping training.")
             state.stop()

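Taken together, _should_stop now evaluates both limits on rank 0 only and broadcasts the decision to all ranks via sync_bool with rank-zero coherence, so every rank stops in step. A hedged usage sketch of the new parameter (MyTrainUnit and dataloader are placeholders; the callback wiring shown in the final comment follows torchtnt's usual train() entry point):

# Hedged usage sketch: stop at an absolute deadline or after a maximum
# duration, whichever comes first. MyTrainUnit/dataloader are placeholders.
from datetime import datetime, timedelta

from torchtnt.framework.callbacks.time_limit_interrupter import TimeLimitInterrupter

interrupter = TimeLimitInterrupter(
    duration=timedelta(hours=4),  # relative cap on training time
    interval="epoch",  # check once per epoch
    timestamp=datetime(2024, 3, 14, 17, 30),  # naive, so treated as local time
)

# The callback would then be passed to the training entry point, e.g.:
# train(MyTrainUnit(), dataloader, callbacks=[interrupter])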