Skip to content

Commit

Permalink
add warmup steps to time wait for batch logger (#838)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #838

Adding warmup steps to different performance loggers

Reviewed By: diego-urgell

Differential Revision: D57595989

fbshipit-source-id: 496edd9c3c3f92a2454eb9ae9c9e3bf7496d670c
  • Loading branch information
galrotem authored and facebook-github-bot committed May 21, 2024
1 parent 41b9918 commit ba310cf
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
28 changes: 23 additions & 5 deletions tests/framework/callbacks/test_time_wait_for_batch_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from torchtnt.framework.state import EntryPoint, PhaseState, State
from torchtnt.framework.train import _train_impl
from torchtnt.utils.loggers.logger import MetricLogger
from torchtnt.utils.timer import TimerProtocol
from torchtnt.utils.timer import Timer, TimerProtocol


class TimeWaitForBatchLoggerTest(unittest.TestCase):
Expand Down Expand Up @@ -119,10 +119,28 @@ def test_with_predict(self) -> None:
],
)

def test_invalid_log_every_n_steps(self) -> None:
def test_warmup_steps(self) -> None:
    """Steps inside the warmup window are skipped; the first step after it is logged."""
    mock_logger = MagicMock(spec=MetricLogger)
    wait_logger = TimeWaitForBatchLogger(logger=mock_logger, warmup_steps=1)

    # Pre-populate the timer so there is a recorded wait time available to log.
    step_timer = Timer()
    step_timer.recorded_durations = {"data_wait_time": [1, 2]}

    # Step 1 falls within the single warmup step, so nothing may be logged.
    wait_logger._log_step_metrics(timer=step_timer, label="foo", step=1)
    self.assertEqual(mock_logger.log.call_count, 0)

    # Step 2 is past the warmup window and must produce exactly one log call.
    wait_logger._log_step_metrics(timer=step_timer, label="foo", step=2)
    mock_logger.log.assert_called_once()

def test_invalid_params(self) -> None:
    """Constructor raises ``ValueError`` for out-of-range ``log_every_n_steps`` and ``warmup_steps``."""
    logger_mock = MagicMock(spec=MetricLogger)

    # log_every_n_steps=0 is below the minimum of 1.
    # Only one constructor call belongs in this block: a second call after
    # the raising one would be unreachable and never exercised.
    with self.assertRaisesRegex(
        ValueError, "log_every_n_steps must be at least 1, got 0"
    ):
        TimeWaitForBatchLogger(logger=logger_mock, log_every_n_steps=0)

    # warmup_steps=-1 is below the minimum of 0.
    with self.assertRaisesRegex(
        ValueError, "warmup_steps must be at least 0, got -1"
    ):
        TimeWaitForBatchLogger(logger=logger_mock, warmup_steps=-1)
12 changes: 11 additions & 1 deletion torchtnt/framework/callbacks/time_wait_for_batch_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,16 @@ class TimeWaitForBatchLogger(Callback):
Args:
logger: Either a subclass of :class:`torchtnt.utils.loggers.logger.MetricLogger`
or a :class:`torch.utils.tensorboard.SummaryWriter` instance.
log_every_n_steps: an optional int to control the log frequency
log_every_n_steps: an int to control the log frequency. Default is 1.
warmup_steps: an int to control the number of warmup steps. Logging starts only after this many steps have been completed. Default is 0.
"""

def __init__(
self,
logger: Union[MetricLogger, SummaryWriter],
*,
log_every_n_steps: int = 1,
warmup_steps: int = 0,
) -> None:
self._logger = logger
if log_every_n_steps < 1:
Expand All @@ -40,6 +43,10 @@ def __init__(
)
self._log_every_n_steps = log_every_n_steps

if warmup_steps < 0:
raise ValueError(f"warmup_steps must be at least 0, got {warmup_steps}")
self._warmup_steps = warmup_steps

@rank_zero_fn
def _log_step_metrics(
self,
Expand All @@ -48,6 +55,9 @@ def _log_step_metrics(
label: str,
step: int,
) -> None:
if step <= self._warmup_steps:
return

if step % self._log_every_n_steps != 0:
return

Expand Down

0 comments on commit ba310cf

Please sign in to comment.