Skip to content

Commit

Permalink
add more tests to test_callback_handler (#741)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #741

Add test case to `test_callback_handler` which ensures the order of execution of callbacks.

Reviewed By: JKSenthil

Differential Revision: D54908994

fbshipit-source-id: a5a0969acf1824f0b8b0e2c5c758baad7c672d50
  • Loading branch information
galrotem authored and facebook-github-bot committed Mar 15, 2024
1 parent 52424ce commit dbf80fd
Showing 1 changed file with 60 additions and 1 deletion.
61 changes: 60 additions & 1 deletion tests/framework/test_callback_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@
from torchtnt.framework.callback import Callback
from torchtnt.framework.callbacks.lambda_callback import Lambda
from torchtnt.framework.state import EntryPoint, State
from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit
from torchtnt.framework.unit import (
TEvalUnit,
TPredictUnit,
TrainUnit,
TTrainData,
TTrainUnit,
)
from torchtnt.utils.timer import Timer


Expand Down Expand Up @@ -231,3 +237,56 @@ def dummy_fn(x, y):
self.assertIn(hook, implemented_cbs)
self.assertEqual(len(implemented_cbs[hook]), 1)
self.assertNotIn("on_exception", implemented_cbs)

def test_callback_ordering(self) -> None:
# ensure that callbacks are executed in the order they are passed
class DummyUnit(TrainUnit[None]):
def __init__(self) -> None:
self.on_train_start_callback_order = []
self.on_train_end_callback_order = []

def train_step(self, state: State, data: TTrainData) -> None:
pass

class FirstCallback(Callback):
callback_name = "first_callback"

def on_train_start(self, state: State, unit: TTrainUnit) -> None:
unit.on_train_start_callback_order.append(self.callback_name)

def on_train_end(self, state: State, unit: TTrainUnit) -> None:
unit.on_train_end_callback_order.append(self.callback_name)

class SecondCallback(Callback):
callback_name = "second_callback"

def on_train_start(self, state: State, unit: TTrainUnit) -> None:
unit.on_train_start_callback_order.append(self.callback_name)

def on_train_end(self, state: State, unit: TTrainUnit) -> None:
unit.on_train_end_callback_order.append(self.callback_name)

unit = DummyUnit()
state = MagicMock(spec=State)
first_callback = FirstCallback()
second_callback = SecondCallback()
callback_handler = CallbackHandler([first_callback, second_callback])

callback_handler.on_train_start(state, unit)
callback_handler.on_train_end(state, unit)
self.assertEqual(
unit.on_train_start_callback_order,
[first_callback.callback_name, second_callback.callback_name],
)
self.assertEqual(
unit.on_train_end_callback_order,
[first_callback.callback_name, second_callback.callback_name],
)

self.assertEqual(
callback_handler._callbacks,
{
"on_train_start": [first_callback, second_callback],
"on_train_end": [first_callback, second_callback],
},
)

0 comments on commit dbf80fd

Please sign in to comment.