diff --git a/tests/framework/test_callback_handler.py b/tests/framework/test_callback_handler.py index e506e7a315..c52013b314 100644 --- a/tests/framework/test_callback_handler.py +++ b/tests/framework/test_callback_handler.py @@ -46,6 +46,19 @@ def on_train_start(self, state: State, unit: TTrainUnit) -> None: def on_train_epoch_start(self, state: State, unit: TTrainUnit) -> None: self.called_hooks.add("on_train_epoch_start") + def on_train_dataloader_iter_creation_start( + self, state: State, unit: TTrainUnit + ) -> None: + self.called_hooks.add("on_train_dataloader_iter_creation_start") + + def on_train_dataloader_iter_creation_end( + self, state: State, unit: TTrainUnit + ) -> None: + self.called_hooks.add("on_train_dataloader_iter_creation_end") + + def on_train_get_next_batch_start(self, state: State, unit: TTrainUnit) -> None: + self.called_hooks.add("on_train_get_next_batch_start") + def on_train_get_next_batch_end(self, state: State, unit: TTrainUnit) -> None: self.called_hooks.add("on_train_get_next_batch_end") @@ -67,6 +80,19 @@ def on_eval_start(self, state: State, unit: TEvalUnit) -> None: def on_eval_epoch_start(self, state: State, unit: TEvalUnit) -> None: self.called_hooks.add("on_eval_epoch_start") + def on_eval_dataloader_iter_creation_start( + self, state: State, unit: TEvalUnit + ) -> None: + self.called_hooks.add("on_eval_dataloader_iter_creation_start") + + def on_eval_dataloader_iter_creation_end( + self, state: State, unit: TEvalUnit + ) -> None: + self.called_hooks.add("on_eval_dataloader_iter_creation_end") + + def on_eval_get_next_batch_start(self, state: State, unit: TEvalUnit) -> None: + self.called_hooks.add("on_eval_get_next_batch_start") + def on_eval_get_next_batch_end(self, state: State, unit: TEvalUnit) -> None: self.called_hooks.add("on_eval_get_next_batch_end") @@ -85,6 +111,19 @@ def on_eval_end(self, state: State, unit: TEvalUnit) -> None: def on_predict_start(self, state: State, unit: TPredictUnit) -> None: self.called_hooks.add("on_predict_start") + def on_predict_dataloader_iter_creation_start( + self, state: State, unit: TPredictUnit + ) -> None: + self.called_hooks.add("on_predict_dataloader_iter_creation_start") + + def on_predict_dataloader_iter_creation_end( + self, state: State, unit: TPredictUnit + ) -> None: + self.called_hooks.add("on_predict_dataloader_iter_creation_end") + + def on_predict_get_next_batch_start(self, state: State, unit: TPredictUnit) -> None: + self.called_hooks.add("on_predict_get_next_batch_start") + def on_predict_epoch_start(self, state: State, unit: TPredictUnit) -> None: self.called_hooks.add("on_predict_epoch_start") @@ -129,6 +168,15 @@ def test_callback_handler(self) -> None: cb_handler.on_train_epoch_start(state, unit) self.assertIn("on_train_epoch_start", called_hooks) + cb_handler.on_train_dataloader_iter_creation_start(state, unit) + self.assertIn("on_train_dataloader_iter_creation_start", called_hooks) + + cb_handler.on_train_dataloader_iter_creation_end(state, unit) + self.assertIn("on_train_dataloader_iter_creation_end", called_hooks) + + cb_handler.on_train_get_next_batch_start(state, unit) + self.assertIn("on_train_get_next_batch_start", called_hooks) + cb_handler.on_train_get_next_batch_end(state, unit) self.assertIn("on_train_get_next_batch_end", called_hooks) @@ -154,6 +202,15 @@ def test_callback_handler(self) -> None: cb_handler.on_eval_epoch_start(state, unit) self.assertIn("on_eval_epoch_start", called_hooks) + cb_handler.on_eval_dataloader_iter_creation_start(state, unit) + self.assertIn("on_eval_dataloader_iter_creation_start", called_hooks) + + cb_handler.on_eval_dataloader_iter_creation_end(state, unit) + self.assertIn("on_eval_dataloader_iter_creation_end", called_hooks) + + cb_handler.on_eval_get_next_batch_start(state, unit) + self.assertIn("on_eval_get_next_batch_start", called_hooks) + cb_handler.on_eval_get_next_batch_end(state, unit) self.assertIn("on_eval_get_next_batch_end", called_hooks) @@ -179,6 +236,15 @@ def test_callback_handler(self) -> None: cb_handler.on_predict_epoch_start(state, unit) self.assertIn("on_predict_epoch_start", called_hooks) + cb_handler.on_predict_dataloader_iter_creation_start(state, unit) + self.assertIn("on_predict_dataloader_iter_creation_start", called_hooks) + + cb_handler.on_predict_dataloader_iter_creation_end(state, unit) + self.assertIn("on_predict_dataloader_iter_creation_end", called_hooks) + + cb_handler.on_predict_get_next_batch_start(state, unit) + self.assertIn("on_predict_get_next_batch_start", called_hooks) + cb_handler.on_predict_get_next_batch_end(state, unit) self.assertIn("on_predict_get_next_batch_end", called_hooks) @@ -202,6 +268,9 @@ def test_get_implemented_callback_mapping(self) -> None: remaining_callback_hooks = ( "on_train_start", "on_train_epoch_start", + "on_train_dataloader_iter_creation_start", + "on_train_dataloader_iter_creation_end", + "on_train_get_next_batch_start", "on_train_get_next_batch_end", "on_train_step_start", "on_train_step_end", @@ -209,6 +278,9 @@ def test_get_implemented_callback_mapping(self) -> None: "on_train_end", "on_eval_start", "on_eval_epoch_start", + "on_eval_dataloader_iter_creation_start", + "on_eval_dataloader_iter_creation_end", + "on_eval_get_next_batch_start", "on_eval_get_next_batch_end", "on_eval_step_start", "on_eval_step_end", @@ -216,6 +288,9 @@ def test_get_implemented_callback_mapping(self) -> None: "on_eval_end", "on_predict_start", "on_predict_epoch_start", + "on_predict_dataloader_iter_creation_start", + "on_predict_dataloader_iter_creation_end", + "on_predict_get_next_batch_start", "on_predict_get_next_batch_end", "on_predict_step_start", "on_predict_step_end", diff --git a/torchtnt/framework/_callback_handler.py b/torchtnt/framework/_callback_handler.py index 4d242d78d0..bab699282c 100644 --- a/torchtnt/framework/_callback_handler.py +++ b/torchtnt/framework/_callback_handler.py @@ -63,6 +63,9 @@ def _get_implemented_callback_mapping( "on_exception", "on_train_start", "on_train_epoch_start", + "on_train_dataloader_iter_creation_start", + "on_train_dataloader_iter_creation_end", + "on_train_get_next_batch_start", "on_train_get_next_batch_end", "on_train_step_start", "on_train_step_end", @@ -70,6 +73,9 @@ def _get_implemented_callback_mapping( "on_train_end", "on_eval_start", "on_eval_epoch_start", + "on_eval_dataloader_iter_creation_start", + "on_eval_dataloader_iter_creation_end", + "on_eval_get_next_batch_start", "on_eval_get_next_batch_end", "on_eval_step_start", "on_eval_step_end", @@ -77,6 +83,9 @@ def _get_implemented_callback_mapping( "on_eval_end", "on_predict_start", "on_predict_epoch_start", + "on_predict_dataloader_iter_creation_start", + "on_predict_dataloader_iter_creation_end", + "on_predict_get_next_batch_start", "on_predict_get_next_batch_end", "on_predict_step_start", "on_predict_step_end", @@ -127,6 +136,28 @@ def on_train_epoch_start(self, state: State, unit: TTrainUnit) -> None: for cb in callbacks: cb.on_train_epoch_start(state, unit) + def on_train_dataloader_iter_creation_start( + self, state: State, unit: TTrainUnit + ) -> None: + fn_name = "on_train_dataloader_iter_creation_start" + callbacks = self._callbacks.get(fn_name, []) + for cb in callbacks: + cb.on_train_dataloader_iter_creation_start(state, unit) + + def on_train_dataloader_iter_creation_end( + self, state: State, unit: TTrainUnit + ) -> None: + fn_name = "on_train_dataloader_iter_creation_end" + callbacks = self._callbacks.get(fn_name, []) + for cb in callbacks: + cb.on_train_dataloader_iter_creation_end(state, unit) + + def on_train_get_next_batch_start(self, state: State, unit: TTrainUnit) -> None: + fn_name = "on_train_get_next_batch_start" + callbacks = self._callbacks.get(fn_name, []) + for cb in callbacks: + cb.on_train_get_next_batch_start(state, unit) + def on_train_get_next_batch_end(self, state: State, unit: TTrainUnit) -> None: fn_name = "on_train_get_next_batch_end" callbacks = self._callbacks.get(fn_name, []) @@ -169,6 +200,28 @@ def on_eval_epoch_start(self, state: State, unit: TEvalUnit) -> None: for cb in callbacks: cb.on_eval_epoch_start(state, unit) + def on_eval_dataloader_iter_creation_start( + self, state: State, unit: TEvalUnit + ) -> None: + fn_name = "on_eval_dataloader_iter_creation_start" + callbacks = self._callbacks.get(fn_name, []) + for cb in callbacks: + cb.on_eval_dataloader_iter_creation_start(state, unit) + + def on_eval_dataloader_iter_creation_end( + self, state: State, unit: TEvalUnit + ) -> None: + fn_name = "on_eval_dataloader_iter_creation_end" + callbacks = self._callbacks.get(fn_name, []) + for cb in callbacks: + cb.on_eval_dataloader_iter_creation_end(state, unit) + + def on_eval_get_next_batch_start(self, state: State, unit: TEvalUnit) -> None: + fn_name = "on_eval_get_next_batch_start" + callbacks = self._callbacks.get(fn_name, []) + for cb in callbacks: + cb.on_eval_get_next_batch_start(state, unit) + def on_eval_get_next_batch_end(self, state: State, unit: TEvalUnit) -> None: fn_name = "on_eval_get_next_batch_end" callbacks = self._callbacks.get(fn_name, []) @@ -211,6 +264,28 @@ def on_predict_epoch_start(self, state: State, unit: TPredictUnit) -> None: for cb in callbacks: cb.on_predict_epoch_start(state, unit) + def on_predict_dataloader_iter_creation_start( + self, state: State, unit: TPredictUnit + ) -> None: + fn_name = "on_predict_dataloader_iter_creation_start" + callbacks = self._callbacks.get(fn_name, []) + for cb in callbacks: + cb.on_predict_dataloader_iter_creation_start(state, unit) + + def on_predict_dataloader_iter_creation_end( + self, state: State, unit: TPredictUnit + ) -> None: + fn_name = "on_predict_dataloader_iter_creation_end" + callbacks = self._callbacks.get(fn_name, []) + for cb in callbacks: + cb.on_predict_dataloader_iter_creation_end(state, unit) + + def on_predict_get_next_batch_start(self, state: State, unit: TPredictUnit) -> None: + fn_name = "on_predict_get_next_batch_start" + callbacks = self._callbacks.get(fn_name, []) + for cb in callbacks: + cb.on_predict_get_next_batch_start(state, unit) + def on_predict_get_next_batch_end(self, state: State, unit: TPredictUnit) -> None: fn_name = "on_predict_get_next_batch_end" callbacks = self._callbacks.get(fn_name, []) diff --git a/torchtnt/framework/callback.py b/torchtnt/framework/callback.py index 6920d2f976..f2e67d4ae9 100644 --- a/torchtnt/framework/callback.py +++ b/torchtnt/framework/callback.py @@ -77,6 +77,22 @@ def on_train_epoch_start(self, state: State, unit: TTrainUnit) -> None: """Hook called before a new train epoch starts.""" pass + def on_train_dataloader_iter_creation_start( + self, state: State, unit: TTrainUnit + ) -> None: + """Hook called before the dataloader iterator is created.""" + pass + + def on_train_dataloader_iter_creation_end( + self, state: State, unit: TTrainUnit + ) -> None: + """Hook called after the dataloader iterator is created.""" + pass + + def on_train_get_next_batch_start(self, state: State, unit: TTrainUnit) -> None: + """Hook called before getting the data batch for the next train step.""" + pass + def on_train_get_next_batch_end(self, state: State, unit: TTrainUnit) -> None: """Hook called after getting the data batch for the next train step.""" pass @@ -105,6 +121,22 @@ def on_eval_epoch_start(self, state: State, unit: TEvalUnit) -> None: """Hook called before a new eval epoch starts.""" pass + def on_eval_dataloader_iter_creation_start( + self, state: State, unit: TEvalUnit + ) -> None: + """Hook called before the dataloader iterator is created.""" + pass + + def on_eval_dataloader_iter_creation_end( + self, state: State, unit: TEvalUnit + ) -> None: + """Hook called after the dataloader iterator is created.""" + pass + + def on_eval_get_next_batch_start(self, state: State, unit: TEvalUnit) -> None: + """Hook called before getting the data batch for the next eval step.""" + pass + def on_eval_get_next_batch_end(self, state: State, unit: TEvalUnit) -> None: """Hook called after getting the data batch for the next eval step.""" pass @@ -133,6 +165,22 @@ def on_predict_epoch_start(self, state: State, unit: TPredictUnit) -> None: """Hook called before a new predict epoch starts.""" pass + def on_predict_dataloader_iter_creation_start( + self, state: State, unit: TPredictUnit + ) -> None: + """Hook called before the dataloader iterator is created.""" + pass + + def on_predict_dataloader_iter_creation_end( + self, state: State, unit: TPredictUnit + ) -> None: + """Hook called after the dataloader iterator is created.""" + pass + + def on_predict_get_next_batch_start(self, state: State, unit: TPredictUnit) -> None: + """Hook called before getting the data batch for the next predict step.""" + pass + def on_predict_get_next_batch_end(self, state: State, unit: TPredictUnit) -> None: """Hook called after getting the data batch for the next predict step.""" pass diff --git a/torchtnt/framework/callbacks/lambda_callback.py b/torchtnt/framework/callbacks/lambda_callback.py index 8e44677fc1..894d733594 100644 --- a/torchtnt/framework/callbacks/lambda_callback.py +++ b/torchtnt/framework/callbacks/lambda_callback.py @@ -82,6 +82,15 @@ def __init__( ] = None, on_train_start: Optional[Callable[[State, TTrainUnit], None]] = None, on_train_epoch_start: Optional[Callable[[State, TTrainUnit], None]] = None, + on_train_dataloader_iter_creation_start: Optional[ + Callable[[State, TTrainUnit], None] + ] = None, + on_train_dataloader_iter_creation_end: Optional[ + Callable[[State, TTrainUnit], None] + ] = None, + on_train_get_next_batch_start: Optional[ + Callable[[State, TTrainUnit], None] + ] = None, on_train_get_next_batch_end: Optional[ Callable[[State, TTrainUnit], None] ] = None, @@ -91,6 +100,15 @@ def __init__( on_train_end: Optional[Callable[[State, TTrainUnit], None]] = None, on_eval_start: Optional[Callable[[State, TEvalUnit], None]] = None, on_eval_epoch_start: Optional[Callable[[State, TEvalUnit], None]] = None, + on_eval_dataloader_iter_creation_start: Optional[ + Callable[[State, TTrainUnit], None] + ] = None, + on_eval_dataloader_iter_creation_end: Optional[ + Callable[[State, TTrainUnit], None] + ] = None, + on_eval_get_next_batch_start: Optional[ + Callable[[State, TTrainUnit], None] + ] = None, on_eval_get_next_batch_end: Optional[Callable[[State, TEvalUnit], None]] = None, on_eval_step_start: Optional[Callable[[State, TEvalUnit], None]] = None, on_eval_step_end: Optional[Callable[[State, TEvalUnit], None]] = None, @@ -98,6 +116,15 @@ def __init__( on_eval_end: Optional[Callable[[State, TEvalUnit], None]] = None, on_predict_start: Optional[Callable[[State, TPredictUnit], None]] = None, on_predict_epoch_start: Optional[Callable[[State, TPredictUnit], None]] = None, + on_predict_dataloader_iter_creation_start: Optional[ + Callable[[State, TTrainUnit], None] + ] = None, + on_predict_dataloader_iter_creation_end: Optional[ + Callable[[State, TTrainUnit], None] + ] = None, + on_predict_get_next_batch_start: Optional[ + Callable[[State, TTrainUnit], None] + ] = None, on_predict_get_next_batch_end: Optional[ Callable[[State, TPredictUnit], None] ] = None, diff --git a/torchtnt/framework/evaluate.py b/torchtnt/framework/evaluate.py index 8c61794219..e1f4cea041 100644 --- a/torchtnt/framework/evaluate.py +++ b/torchtnt/framework/evaluate.py @@ -132,9 +132,10 @@ def _evaluate_impl( eval_unit.on_eval_epoch_start(state) callback_handler.on_eval_epoch_start(state, eval_unit) + callback_handler.on_eval_dataloader_iter_creation_start(state, eval_unit) with get_timing_context(state, "evaluate.iter(dataloader)"): data_iter = iter(eval_state.dataloader) - step_input = data_iter + callback_handler.on_eval_dataloader_iter_creation_end(state, eval_unit) prev_steps_in_epoch = eval_unit.eval_progress.num_steps_completed_in_epoch @@ -151,6 +152,7 @@ def _evaluate_impl( with get_timing_context( state, "evaluate.next(data_iter)" ), eval_state.iteration_timer.time("data_wait_time"): + callback_handler.on_eval_get_next_batch_start(state, eval_unit) step_input = eval_unit.get_next_eval_batch(state, data_iter) callback_handler.on_eval_get_next_batch_end(state, eval_unit) diff --git a/torchtnt/framework/predict.py b/torchtnt/framework/predict.py index c4d362b3f6..33207309bc 100644 --- a/torchtnt/framework/predict.py +++ b/torchtnt/framework/predict.py @@ -147,9 +147,10 @@ def _predict_impl( predict_unit.on_predict_epoch_start(state) callback_handler.on_predict_epoch_start(state, predict_unit) + callback_handler.on_predict_dataloader_iter_creation_start(state, predict_unit) with get_timing_context(state, "predict.iter(dataloader)"): data_iter = iter(predict_state.dataloader) - step_input = data_iter + callback_handler.on_predict_dataloader_iter_creation_end(state, predict_unit) prev_steps_in_epoch = predict_unit.predict_progress.num_steps_completed_in_epoch @@ -166,6 +167,7 @@ def _predict_impl( with get_timing_context( state, "predict.next(data_iter)" ), predict_state.iteration_timer.time("data_wait_time"): + callback_handler.on_predict_get_next_batch_start(state, predict_unit) step_input = predict_unit.get_next_predict_batch(state, data_iter) callback_handler.on_predict_get_next_batch_end(state, predict_unit) diff --git a/torchtnt/framework/train.py b/torchtnt/framework/train.py index 2320082caa..3ed895a783 100644 --- a/torchtnt/framework/train.py +++ b/torchtnt/framework/train.py @@ -192,8 +192,10 @@ def _train_epoch_impl( train_state.dataloader, train_unit.train_progress.num_epochs_completed ) + callback_handler.on_train_dataloader_iter_creation_start(state, train_unit) with get_timing_context(state, "train.iter(dataloader)"): data_iter = iter(train_state.dataloader) + callback_handler.on_train_dataloader_iter_creation_end(state, train_unit) prev_steps_in_epoch = train_unit.train_progress.num_steps_completed_in_epoch @@ -210,6 +212,7 @@ def _train_epoch_impl( with get_timing_context( state, "train.next(data_iter)" ), train_state.iteration_timer.time("data_wait_time"): + callback_handler.on_train_get_next_batch_start(state, train_unit) step_input = train_unit.get_next_train_batch(state, data_iter) callback_handler.on_train_get_next_batch_end(state, train_unit)