Add first train batch to train unit for extracting example inputs
Summary:
- Extract example inputs for flop counting or quantization

Differential Revision: D69698015
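
For context, a minimal sketch of how the stored first batch could be used as an example input for flop counting. The unit/module wiring below is assumed for illustration and is not part of this commit; `FlopTensorDispatchMode` is TorchTNT's flop-counting utility from `torchtnt.utils.flops`.

```python
# Illustrative sketch only: assumes the unit exposes a `module` attribute whose
# forward pass accepts the raw batch stored in `first_train_batch`, and that the
# training framework has already populated `first_train_batch` (e.g. after the
# first train step).
from copy import deepcopy

from torchtnt.utils.flops import FlopTensorDispatchMode


def count_forward_flops(unit):
    """Return per-module flop counts using the unit's first train batch."""
    example_input = unit.first_train_batch
    with FlopTensorDispatchMode(unit.module) as ftdm:
        unit.module(example_input)
        # flop_counts maps module names to per-operator flop counts
        return deepcopy(ftdm.flop_counts)
```

In practice something like this could be called from `on_train_end`, once the first batch is guaranteed to have been recorded.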
clarkdykang authored and facebook-github-bot committed Feb 15, 2025
1 parent 7390c77 commit 8a04d86
Showing 1 changed file with 10 additions and 1 deletion.
torchtnt/framework/unit.py: 10 additions & 1 deletion
@@ -13,11 +13,11 @@
from typing import Any, cast, Dict, Generic, Iterator, TypeVar, Union

import torch
from pyre_extensions import none_throws
from torchtnt.framework._unit_utils import (
_find_optimizers_for_module,
_step_requires_iterator,
)

from torchtnt.framework.state import State
from torchtnt.utils.lr_scheduler import TLRScheduler
from torchtnt.utils.prepare_module import _is_fsdp_module, FSDPOptimizerWrapper
@@ -312,6 +312,7 @@ def on_train_epoch_end(self, state: State) -> None:
def __init__(self) -> None:
super().__init__()
self.train_progress = Progress()
self._first_train_batch: TTrainData | None = None

def on_train_start(self, state: State) -> None:
"""Hook called before training starts.
@@ -329,6 +330,14 @@ def on_train_epoch_start(self, state: State) -> None:
"""
pass

@property
def first_train_batch(self) -> TTrainData:
return none_throws(self._first_train_batch)

@first_train_batch.setter
def first_train_batch(self, data: TTrainData) -> None:
self._first_train_batch = data

@abstractmethod
# pyre-fixme[3]: Return annotation cannot be `Any`.
def train_step(self, state: State, data: TTrainData) -> Any:

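Only unit.py changes in this diff, so the wiring that actually populates the new field is assumed to land elsewhere in the framework. A hedged sketch of what that wiring might look like in the train loop (names and helper are hypothetical):

```python
from typing import TypeVar

TTrainData = TypeVar("TTrainData")


def _maybe_record_first_batch(train_unit, batch) -> None:
    # Hypothetical helper (not part of this diff): store the very first batch
    # on the unit so it can later serve as an example input.
    if train_unit.train_progress.num_steps_completed == 0:
        train_unit.first_train_batch = batch  # routed through the setter above
```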