From 8a04d86aad008fc9bd7785ca2103a003a4ac9a94 Mon Sep 17 00:00:00 2001
From: Clark Kang
Date: Sat, 15 Feb 2025 00:23:12 -0800
Subject: [PATCH] Add first train batch to train unit for extracting example
 inputs

Summary:
- Extract example inputs for flop counting or quantization

Differential Revision: D69698015
---
 torchtnt/framework/unit.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/torchtnt/framework/unit.py b/torchtnt/framework/unit.py
index 7e470f100d..41074d85d5 100644
--- a/torchtnt/framework/unit.py
+++ b/torchtnt/framework/unit.py
@@ -13,11 +13,11 @@
 from typing import Any, cast, Dict, Generic, Iterator, TypeVar, Union
 
 import torch
+from pyre_extensions import none_throws
 from torchtnt.framework._unit_utils import (
     _find_optimizers_for_module,
     _step_requires_iterator,
 )
-
 from torchtnt.framework.state import State
 from torchtnt.utils.lr_scheduler import TLRScheduler
 from torchtnt.utils.prepare_module import _is_fsdp_module, FSDPOptimizerWrapper
@@ -312,6 +312,7 @@ def on_train_epoch_end(self, state: State) -> None:
     def __init__(self) -> None:
         super().__init__()
         self.train_progress = Progress()
+        # Backing storage for the `first_train_batch` property; the property
+        # itself must not be assigned here or the setter would recurse.
+        self._first_train_batch: Union[TTrainData, None] = None
 
     def on_train_start(self, state: State) -> None:
         """Hook called before training starts.
@@ -329,6 +330,15 @@ def on_train_epoch_start(self, state: State) -> None:
         """
         pass
 
+    @property
+    def first_train_batch(self) -> TTrainData:
+        """The first batch seen by training; raises if no batch was recorded yet."""
+        return none_throws(self._first_train_batch)
+
+    @first_train_batch.setter
+    def first_train_batch(self, data: TTrainData) -> None:
+        self._first_train_batch = data
+
     @abstractmethod
     # pyre-fixme[3]: Return annotation cannot be `Any`.
     def train_step(self, state: State, data: TTrainData) -> Any: