iteration time logger support for MetricLogger
Summary: Add support for generic MetricLogger in IterationTimeLogger

Differential Revision: D55333343
galrotem authored and facebook-github-bot committed Mar 25, 2024
1 parent c2dcee9 commit ae54622
Showing 2 changed files with 53 additions and 45 deletions.
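
In practice, this change lets the callback be constructed with any MetricLogger implementation, not just a TensorBoardLogger or its underlying SummaryWriter. A minimal usage sketch (not part of this commit; the constructor arguments are illustrative, and it assumes torchtnt's TensorBoardLogger is itself a MetricLogger implementation):

    from torchtnt.framework.callbacks.iteration_time_logger import IterationTimeLogger
    from torchtnt.utils.loggers import TensorBoardLogger

    # Any MetricLogger works here; TensorBoardLogger is one concrete implementation.
    logger = TensorBoardLogger(path="/tmp/tb_logs")  # illustrative path
    callback = IterationTimeLogger(
        logger,
        moving_avg_window=5,   # average over the last 5 iteration times
        log_every_n_steps=10,  # emit a value every 10th step
    )
    # pass callbacks=[callback] to train()/evaluate()/predict()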
tests/framework/callbacks/test_iteration_time_logger.py (27 changes: 20 additions & 7 deletions)
@@ -21,13 +21,12 @@
 
 from torchtnt.framework.state import State
 from torchtnt.framework.train import train
-from torchtnt.utils.loggers import TensorBoardLogger
+from torchtnt.utils.loggers.logger import MetricLogger
 
 
 class IterationTimeLoggerTest(unittest.TestCase):
     def test_iteration_time_logger_test_on_train_step_end(self) -> None:
-        logger = MagicMock(spec=TensorBoardLogger)
-        logger.writer = MagicMock(spec=SummaryWriter)
+        logger = MagicMock(spec=MetricLogger)
         state = MagicMock(spec=State)
 
         # Test that the recorded times are tracked separately and that we properly
@@ -64,7 +63,7 @@ def test_iteration_time_logger_test_on_train_step_end(self) -> None:
         callback.on_eval_step_end(state, eval_unit)
         callback.on_predict_step_end(state, predict_unit)
 
-        logger.writer.add_scalar.assert_has_calls(
+        logger.log.assert_has_calls(
             [
                 call(
                     "Train Iteration Time (seconds)",
@@ -85,12 +84,26 @@ def test_with_train_epoch(self) -> None:
         """
 
         my_unit = DummyTrainUnit(input_dim=2)
-        logger = MagicMock(spec=TensorBoardLogger)
-        logger.writer = MagicMock(spec=SummaryWriter)
+        logger = MagicMock(spec=MetricLogger)
         callback = IterationTimeLogger(logger, moving_avg_window=1, log_every_n_steps=3)
         dataloader = generate_random_dataloader(
             num_samples=12, input_dim=2, batch_size=2
         )
         train(my_unit, dataloader, max_epochs=2, callbacks=[callback])
         # 2 epochs, 6 iterations each, logging every third step
-        self.assertEqual(logger.writer.add_scalar.call_count, 4)
+        self.assertEqual(logger.log.call_count, 4)
+
+    def test_with_summary_writer(self) -> None:
+        """
+        Test IterationTimeLogger callback with train entry point and SummaryWriter
+        """
+
+        my_unit = DummyTrainUnit(input_dim=2)
+        logger = MagicMock(spec=SummaryWriter)
+        callback = IterationTimeLogger(logger, moving_avg_window=1, log_every_n_steps=3)
+        dataloader = generate_random_dataloader(
+            num_samples=12, input_dim=2, batch_size=2
+        )
+        train(my_unit, dataloader, max_epochs=2, callbacks=[callback])
+        # 2 epochs, 6 iterations each, logging every third step
+        self.assertEqual(logger.add_scalar.call_count, 4)
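
The updated tests assert against MetricLogger.log rather than SummaryWriter.add_scalar. Because MagicMock(spec=MetricLogger) only exposes attributes that exist on the spec'd class, the tests would fail loudly if the callback called anything outside the MetricLogger interface. A standalone sketch of that pattern, with illustrative values:

    from unittest.mock import MagicMock

    from torchtnt.utils.loggers.logger import MetricLogger

    logger = MagicMock(spec=MetricLogger)

    # the callback is expected to call log(name, value, step) on the logger
    logger.log("Train Iteration Time (seconds)", 0.123, 3)
    logger.log.assert_called_once_with("Train Iteration Time (seconds)", 0.123, 3)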
torchtnt/framework/callbacks/iteration_time_logger.py (71 changes: 33 additions & 38 deletions)
@@ -7,16 +7,16 @@
 # pyre-strict
 
 
-from typing import Optional, Union
+from typing import cast, Optional, Union
 
 from pyre_extensions import none_throws
 from torch.utils.tensorboard import SummaryWriter
 
 from torchtnt.framework.callback import Callback
 from torchtnt.framework.state import State
 from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit
-from torchtnt.utils.distributed import get_global_rank
-from torchtnt.utils.loggers.tensorboard import TensorBoardLogger
+from torchtnt.utils.distributed import rank_zero_fn
+from torchtnt.utils.loggers.logger import MetricLogger
 from torchtnt.utils.timer import TimerProtocol
 
@@ -35,23 +35,17 @@ class IterationTimeLogger(Callback):
 
     def __init__(
        self,
-        logger: Union[TensorBoardLogger, SummaryWriter],
+        logger: Union[MetricLogger, SummaryWriter],
         moving_avg_window: int = 1,
         log_every_n_steps: int = 1,
     ) -> None:
-        if isinstance(logger, TensorBoardLogger):
-            logger = logger.writer
-
-        if get_global_rank() == 0:  # only write from the main rank
-            self._writer = none_throws(
-                logger, "TensorBoardLogger.writer should not be None"
-            )
+        self._logger = logger
         self.moving_avg_window = moving_avg_window
         self.log_every_n_steps = log_every_n_steps
 
+    @rank_zero_fn
     def _log_step_metrics(
         self,
-        writer: SummaryWriter,
         metric_label: str,
         iteration_timer: TimerProtocol,
         step_logging_for: int,
@@ -75,38 +69,39 @@ def _log_step_metrics(
             return
 
         last_n_values = time_list[-self.moving_avg_window :]
-        writer.add_scalar(
-            human_metric_names[metric_label],
-            sum(last_n_values) / len(last_n_values),
-            step_logging_for,
-        )
+        if isinstance(self._logger, SummaryWriter):
+            self._logger.add_scalar(
+                human_metric_names[metric_label],
+                sum(last_n_values) / len(last_n_values),
+                step_logging_for,
+            )
+        else:
+            cast(MetricLogger, self._logger).log(
+                human_metric_names[metric_label],
+                sum(last_n_values) / len(last_n_values),
+                step_logging_for,
+            )
 
     def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
         timer = none_throws(state.train_state).iteration_timer
-        if writer := self._writer:
-            self._log_step_metrics(
-                writer,
-                "train_iteration_time",
-                timer,
-                unit.train_progress.num_steps_completed,
-            )
+        self._log_step_metrics(
+            "train_iteration_time",
+            timer,
+            unit.train_progress.num_steps_completed,
+        )
 
     def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
         timer = none_throws(state.eval_state).iteration_timer
-        if writer := self._writer:
-            self._log_step_metrics(
-                writer,
-                "eval_iteration_time",
-                timer,
-                unit.eval_progress.num_steps_completed,
-            )
+        self._log_step_metrics(
+            "eval_iteration_time",
+            timer,
+            unit.eval_progress.num_steps_completed,
+        )
 
     def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
         timer = none_throws(state.predict_state).iteration_timer
-        if writer := self._writer:
-            self._log_step_metrics(
-                writer,
-                "predict_iteration_time",
-                timer,
-                unit.predict_progress.num_steps_completed,
-            )
+        self._log_step_metrics(
+            "predict_iteration_time",
+            timer,
+            unit.predict_progress.num_steps_completed,
+        )