From 5a9c4a531f9b01bc8d5ac4998860178484cfe00d Mon Sep 17 00:00:00 2001 From: Gal Rotem Date: Mon, 25 Mar 2024 13:13:04 -0700 Subject: [PATCH] iteration time logger support for MetricLogger (#759) Summary: Add support for generic MetricLogger in IterationTimeLogger Differential Revision: D55333343 --- .../callbacks/test_iteration_time_logger.py | 27 +++++-- .../callbacks/iteration_time_logger.py | 71 +++++++++---------- 2 files changed, 53 insertions(+), 45 deletions(-) diff --git a/tests/framework/callbacks/test_iteration_time_logger.py b/tests/framework/callbacks/test_iteration_time_logger.py index 4a29b4707b..fae7d69ecb 100644 --- a/tests/framework/callbacks/test_iteration_time_logger.py +++ b/tests/framework/callbacks/test_iteration_time_logger.py @@ -21,13 +21,12 @@ from torchtnt.framework.state import State from torchtnt.framework.train import train -from torchtnt.utils.loggers import TensorBoardLogger +from torchtnt.utils.loggers.logger import MetricLogger class IterationTimeLoggerTest(unittest.TestCase): def test_iteration_time_logger_test_on_train_step_end(self) -> None: - logger = MagicMock(spec=TensorBoardLogger) - logger.writer = MagicMock(spec=SummaryWriter) + logger = MagicMock(spec=MetricLogger) state = MagicMock(spec=State) # Test that the recorded times are tracked separately and that we properly @@ -64,7 +63,7 @@ def test_iteration_time_logger_test_on_train_step_end(self) -> None: callback.on_eval_step_end(state, eval_unit) callback.on_predict_step_end(state, predict_unit) - logger.writer.add_scalar.assert_has_calls( + logger.log.assert_has_calls( [ call( "Train Iteration Time (seconds)", @@ -85,12 +84,26 @@ def test_with_train_epoch(self) -> None: """ my_unit = DummyTrainUnit(input_dim=2) - logger = MagicMock(spec=TensorBoardLogger) - logger.writer = MagicMock(spec=SummaryWriter) + logger = MagicMock(spec=MetricLogger) callback = IterationTimeLogger(logger, moving_avg_window=1, log_every_n_steps=3) dataloader = generate_random_dataloader( num_samples=12, input_dim=2, batch_size=2 ) train(my_unit, dataloader, max_epochs=2, callbacks=[callback]) # 2 epochs, 6 iterations each, logging every third step - self.assertEqual(logger.writer.add_scalar.call_count, 4) + self.assertEqual(logger.log.call_count, 4) + + def test_with_summary_writer(self) -> None: + """ + Test IterationTimeLogger callback with train entry point and SummaryWriter + """ + + my_unit = DummyTrainUnit(input_dim=2) + logger = MagicMock(spec=SummaryWriter) + callback = IterationTimeLogger(logger, moving_avg_window=1, log_every_n_steps=3) + dataloader = generate_random_dataloader( + num_samples=12, input_dim=2, batch_size=2 + ) + train(my_unit, dataloader, max_epochs=2, callbacks=[callback]) + # 2 epochs, 6 iterations each, logging every third step + self.assertEqual(logger.add_scalar.call_count, 4) diff --git a/torchtnt/framework/callbacks/iteration_time_logger.py b/torchtnt/framework/callbacks/iteration_time_logger.py index 541ec8b1a7..4d93b43f7c 100644 --- a/torchtnt/framework/callbacks/iteration_time_logger.py +++ b/torchtnt/framework/callbacks/iteration_time_logger.py @@ -7,7 +7,7 @@ # pyre-strict -from typing import Optional, Union +from typing import cast, Optional, Union from pyre_extensions import none_throws from torch.utils.tensorboard import SummaryWriter @@ -15,8 +15,8 @@ from torchtnt.framework.callback import Callback from torchtnt.framework.state import State from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit -from torchtnt.utils.distributed import get_global_rank -from torchtnt.utils.loggers.tensorboard import TensorBoardLogger +from torchtnt.utils.distributed import rank_zero_fn +from torchtnt.utils.loggers.logger import MetricLogger from torchtnt.utils.timer import TimerProtocol @@ -35,23 +35,17 @@ class IterationTimeLogger(Callback): def __init__( self, - logger: Union[TensorBoardLogger, SummaryWriter], + logger: Union[MetricLogger, SummaryWriter], moving_avg_window: int = 1, log_every_n_steps: int = 1, ) -> None: - if isinstance(logger, TensorBoardLogger): - logger = logger.writer - - if get_global_rank() == 0: # only write from the main rank - self._writer = none_throws( - logger, "TensorBoardLogger.writer should not be None" - ) + self._logger = logger self.moving_avg_window = moving_avg_window self.log_every_n_steps = log_every_n_steps + @rank_zero_fn def _log_step_metrics( self, - writer: SummaryWriter, metric_label: str, iteration_timer: TimerProtocol, step_logging_for: int, @@ -75,38 +69,39 @@ def _log_step_metrics( return last_n_values = time_list[-self.moving_avg_window :] - writer.add_scalar( - human_metric_names[metric_label], - sum(last_n_values) / len(last_n_values), - step_logging_for, - ) + if isinstance(self._logger, SummaryWriter): + self._logger.add_scalar( + human_metric_names[metric_label], + sum(last_n_values) / len(last_n_values), + step_logging_for, + ) + else: + cast(MetricLogger, self._logger).log( + human_metric_names[metric_label], + sum(last_n_values) / len(last_n_values), + step_logging_for, + ) def on_train_step_end(self, state: State, unit: TTrainUnit) -> None: timer = none_throws(state.train_state).iteration_timer - if writer := self._writer: - self._log_step_metrics( - writer, - "train_iteration_time", - timer, - unit.train_progress.num_steps_completed, - ) + self._log_step_metrics( + "train_iteration_time", + timer, + unit.train_progress.num_steps_completed, + ) def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None: timer = none_throws(state.eval_state).iteration_timer - if writer := self._writer: - self._log_step_metrics( - writer, - "eval_iteration_time", - timer, - unit.eval_progress.num_steps_completed, - ) + self._log_step_metrics( + "eval_iteration_time", + timer, + unit.eval_progress.num_steps_completed, + ) def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None: timer = none_throws(state.predict_state).iteration_timer - if writer := self._writer: - self._log_step_metrics( - writer, - "predict_iteration_time", - timer, - unit.predict_progress.num_steps_completed, - ) + self._log_step_metrics( + "predict_iteration_time", + timer, + unit.predict_progress.num_steps_completed, + )