Skip to content

Commit

Permalink
Add timestamp limit to TimeLimitInterrupter (#738)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #738

Reviewed By: JKSenthil

Differential Revision: D54860033

fbshipit-source-id: e8002410d9664fd5ac84a13810df707e0848ffbc
  • Loading branch information
diego-urgell authored and facebook-github-bot committed Mar 15, 2024
1 parent e81d637 commit 52424ce
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 14 deletions.
63 changes: 61 additions & 2 deletions tests/framework/callbacks/test_time_limit_interrupter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@


import unittest
from datetime import timedelta
from datetime import datetime, timedelta
from unittest.mock import MagicMock, Mock, patch

from torchtnt.framework._test_utils import DummyTrainUnit, get_dummy_train_state
import torchtnt.framework.callbacks.time_limit_interrupter as time_limit_interrupter

from torchtnt.framework._test_utils import DummyTrainUnit, get_dummy_train_state
from torchtnt.framework.callbacks.time_limit_interrupter import TimeLimitInterrupter


Expand Down Expand Up @@ -52,6 +53,64 @@ def test_should_stop(self, mock_time_monotonic: MagicMock) -> None:
tli._should_stop(state)
self.assertTrue(state._should_stop)

@patch(f"{time_limit_interrupter.__name__}.datetime", wraps=datetime)
@patch("time.monotonic")
def test_should_stop_with_timestamp_limit(
    self,
    mock_time_monotonic: MagicMock,
    mock_datetime: MagicMock,
) -> None:
    """
    _should_stop should flag the state to stop when EITHER the max duration
    OR the wall-clock timestamp limit is reached, and constructing the
    callback with a naive (non timezone aware) timestamp should raise
    ValueError.

    Note: patch decorators apply bottom-up, so the innermost patch
    ("time.monotonic") binds to the first mock argument.
    """
    # Duration format is DD:HH:MM, so "00:00:25" is a 25 minute budget.
    # The timestamp limit is 2024-03-12 15:25:00 in the local timezone.
    tli = TimeLimitInterrupter(
        duration="00:00:25", timestamp=datetime(2024, 3, 12, 15, 25, 0).astimezone()
    )
    state = get_dummy_train_state()

    # Pin the monotonic clock to 0 at train start so the next mocked value
    # is exactly the elapsed time in seconds.
    mock_time_monotonic.return_value = 0
    tli.on_train_start(state, Mock())

    # Max duration not reached, timestamp limit not reached -> Should not stop
    mock_datetime.now.return_value = datetime(2024, 3, 12, 15, 0, 0)
    mock_time_monotonic.return_value = 5 * 60
    tli._should_stop(state)
    self.assertFalse(state._should_stop)

    # Max duration reached, timestamp limit not reached -> Should stop
    # (50 minutes elapsed > 25 minute budget)
    mock_datetime.now.return_value = datetime(2024, 3, 12, 15, 0, 0)
    mock_time_monotonic.return_value = 50 * 60
    state._should_stop = False
    tli._should_stop(state)
    self.assertTrue(state._should_stop)

    # Max duration not reached, timestamp limit reached -> Should stop
    mock_datetime.now.return_value = datetime(2024, 3, 12, 15, 25, 0)
    mock_time_monotonic.return_value = 5 * 60
    state._should_stop = False
    tli._should_stop(state)
    self.assertTrue(state._should_stop)

    # Test timestamp limit reached with a different timezone -> Should stop
    # (the mocked "now" of 9:00 at UTC-1 is 10:00 UTC, exactly the limit)
    tli = TimeLimitInterrupter(
        duration="00:00:25",
        timestamp=datetime.strptime(
            "2024-03-13 10:00:00 +0000", "%Y-%m-%d %H:%M:%S %z"
        ),
    )
    state = get_dummy_train_state()
    mock_time_monotonic.return_value = 0
    tli.on_train_start(state, Mock())
    mock_datetime.now.return_value = datetime.strptime(
        "2024-03-13 9:00:00 -0100", "%Y-%m-%d %H:%M:%S %z"
    )
    mock_time_monotonic.return_value = 5 * 60
    tli._should_stop(state)
    self.assertTrue(state._should_stop)

    # Test not timezone aware datetime -> Expected error
    with self.assertRaisesRegex(
        ValueError, "Invalid timestamp. Expected a timezone aware datetime object."
    ):
        tli = TimeLimitInterrupter(duration="00:00:25", timestamp=datetime.now())

def test_interval(self) -> None:
tli = TimeLimitInterrupter(duration="00:00:42", interval="epoch")
tli._should_stop = Mock()
Expand Down
50 changes: 38 additions & 12 deletions torchtnt/framework/callbacks/time_limit_interrupter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

import re
import time
from datetime import timedelta
from typing import Literal, Union
from datetime import datetime, timedelta
from typing import Literal, Optional, Union

from torchtnt.framework.callback import Callback
from torchtnt.framework.state import State
Expand All @@ -27,6 +27,12 @@ class TimeLimitInterrupter(Callback):
For example, 20 hours is specified as "00:20:00".
interval: Can be either "epoch" or "step". Determines whether to check for time limit exceeding on every epoch or step.
interval_freq: How often to check for time limit exceeding. For example, if interval is "epoch" and interval_freq is 2, then the callback will check every two epochs.
timestamp: Optional datetime object indicating the timestamp at which the training should end. The training will be stopped even if the maximum
job duration has not been reached yet. Object must be timezone aware.
Raises:
ValueError: If the duration is not specified as a string in the form of DD:HH:MM or as a timedelta,
    or if the timestamp datetime object is not timezone aware.
Note:
This callback uses the global process group to communicate between ranks.
Expand All @@ -38,6 +44,7 @@ def __init__(
duration: Union[str, timedelta],
interval: Literal["epoch", "step"] = "epoch",
interval_freq: int = 1,
timestamp: Optional[datetime] = None,
) -> None:
if isinstance(duration, str):
# checks if string matches DD:HH:MM format and is within valid range
Expand All @@ -64,6 +71,12 @@ def __init__(
self._rank: int = get_global_rank()
self._start_time: float = 0

self._timestamp = timestamp
if timestamp and not timestamp.tzinfo:
raise ValueError(
"Invalid timestamp. Expected a timezone aware datetime object."
)

def on_train_start(self, state: State, unit: TTrainUnit) -> None:
if self._rank == 0:
self._start_time = time.monotonic()
Expand All @@ -80,19 +93,32 @@ def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:

def _should_stop(self, state: State) -> None:
    """
    Check the max duration and the max timestamp to determine if training should stop.
    All ranks sync with rank 0 to determine if any of the stop conditions are met.
    If so, indicates the training loop to stop.

    Args:
        state: the current training State. ``state.stop()`` is called when
            either limit has been reached.
    """
    past_timestamp_limit = False
    past_duration_limit = False

    # Only rank 0 evaluates the limits; all other ranks receive the
    # decision through the sync below.
    if self._rank == 0:
        if timestamp := self._timestamp:
            # self._timestamp is guaranteed timezone aware by __init__, so
            # comparing against an aware "now" is safe.
            past_timestamp_limit = datetime.now().astimezone() >= timestamp

        time_elapsed = time.monotonic() - self._start_time
        past_duration_limit = time_elapsed >= self._duration

    local_should_stop = past_timestamp_limit or past_duration_limit
    # Broadcast rank 0's decision so that every rank stops together.
    global_should_stop = sync_bool(local_should_stop, coherence_mode="rank_zero")

    if global_should_stop:
        reason = ""
        if past_timestamp_limit:
            reason = f"Training timestamp limit {self._timestamp} has been reached."
        elif past_duration_limit:
            reason = (
                f"Training duration of {self._duration} seconds has exceeded. "
                f"Time elapsed is {time.monotonic() - self._start_time} seconds."
            )

        rank_zero_info(f"{reason} Stopping training.")
        state.stop()

0 comments on commit 52424ce

Please sign in to comment.