diff --git a/tests/framework/test_callback_handler.py b/tests/framework/test_callback_handler.py index df99358f39..e506e7a315 100644 --- a/tests/framework/test_callback_handler.py +++ b/tests/framework/test_callback_handler.py @@ -18,7 +18,13 @@ from torchtnt.framework.callback import Callback from torchtnt.framework.callbacks.lambda_callback import Lambda from torchtnt.framework.state import EntryPoint, State -from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit +from torchtnt.framework.unit import ( + TEvalUnit, + TPredictUnit, + TrainUnit, + TTrainData, + TTrainUnit, +) from torchtnt.utils.timer import Timer @@ -231,3 +237,56 @@ def dummy_fn(x, y): self.assertIn(hook, implemented_cbs) self.assertEqual(len(implemented_cbs[hook]), 1) self.assertNotIn("on_exception", implemented_cbs) + + def test_callback_ordering(self) -> None: + # ensure that callbacks are executed in the order they are passed + class DummyUnit(TrainUnit[None]): + def __init__(self) -> None: + self.on_train_start_callback_order = [] + self.on_train_end_callback_order = [] + + def train_step(self, state: State, data: TTrainData) -> None: + pass + + class FirstCallback(Callback): + callback_name = "first_callback" + + def on_train_start(self, state: State, unit: TTrainUnit) -> None: + unit.on_train_start_callback_order.append(self.callback_name) + + def on_train_end(self, state: State, unit: TTrainUnit) -> None: + unit.on_train_end_callback_order.append(self.callback_name) + + class SecondCallback(Callback): + callback_name = "second_callback" + + def on_train_start(self, state: State, unit: TTrainUnit) -> None: + unit.on_train_start_callback_order.append(self.callback_name) + + def on_train_end(self, state: State, unit: TTrainUnit) -> None: + unit.on_train_end_callback_order.append(self.callback_name) + + unit = DummyUnit() + state = MagicMock(spec=State) + first_callback = FirstCallback() + second_callback = SecondCallback() + callback_handler = CallbackHandler([first_callback, second_callback]) + + callback_handler.on_train_start(state, unit) + callback_handler.on_train_end(state, unit) + self.assertEqual( + unit.on_train_start_callback_order, + [first_callback.callback_name, second_callback.callback_name], + ) + self.assertEqual( + unit.on_train_end_callback_order, + [first_callback.callback_name, second_callback.callback_name], + ) + + self.assertEqual( + callback_handler._callbacks, + { + "on_train_start": [first_callback, second_callback], + "on_train_end": [first_callback, second_callback], + }, + )