Skip to content

Commit

Permalink
add generic types to state (#740)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #740

Add generics to state for data type and output type

Reviewed By: JKSenthil

Differential Revision: D54699102

fbshipit-source-id: bf9c2446202d919ddff6a940a09752faaa6bbb10
  • Loading branch information
galrotem authored and facebook-github-bot committed Mar 16, 2024
1 parent dbf80fd commit 4991725
Showing 1 changed file with 19 additions and 18 deletions.
37 changes: 19 additions & 18 deletions torchtnt/framework/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@

import logging
from enum import auto, Enum
from typing import Any, Iterable, Optional
from typing import Generic, Iterable, Optional, TypeVar

from torchtnt.utils.timer import BoundedTimer, TimerProtocol

_logger: logging.Logger = logging.getLogger(__name__)

TStepOutput = TypeVar("TStepOutput")
TData = TypeVar("TData")


def _check_loop_condition(name: str, val: Optional[int]) -> None:
if val is not None and val < 0:
Expand Down Expand Up @@ -58,16 +61,15 @@ class ActivePhase(Enum):
PREDICT = auto()


class PhaseState:
class PhaseState(Generic[TData, TStepOutput]):
"""State for each phase (train, eval, predict).
Modified by the framework, read-only for the user.
"""

def __init__(
self,
*,
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
dataloader: Iterable[Any],
dataloader: Iterable[TData],
max_epochs: Optional[int] = None, # used only for train
max_steps: Optional[int] = None, # used only for train
max_steps_per_epoch: Optional[int] = None,
Expand All @@ -80,23 +82,20 @@ def __init__(
_check_loop_condition("evaluate_every_n_steps", evaluate_every_n_steps)
_check_loop_condition("evaluate_every_n_epochs", evaluate_every_n_epochs)

# pyre-fixme[4]: Attribute annotation cannot contain `Any`.
self._dataloader: Iterable[Any] = dataloader
self._dataloader: Iterable[TData] = dataloader
self._max_epochs = max_epochs
self._max_steps = max_steps
self._max_steps_per_epoch = max_steps_per_epoch
self._evaluate_every_n_steps = evaluate_every_n_steps
self._evaluate_every_n_epochs = evaluate_every_n_epochs

# pyre-fixme[4]: Attribute annotation cannot be `Any`.
self._step_output: Any = None
self._step_output: Optional[TStepOutput] = None
self._iteration_timer = BoundedTimer(
cuda_sync=False, lower_bound=1_000, upper_bound=5_000
)

@property
# pyre-fixme[3]: Return annotation cannot contain `Any`.
def dataloader(self) -> Iterable[Any]:
def dataloader(self) -> Iterable[TData]:
"""Dataloader defined by the user."""
return self._dataloader

Expand Down Expand Up @@ -126,8 +125,7 @@ def evaluate_every_n_epochs(self) -> Optional[int]:
return self._evaluate_every_n_epochs

@property
# pyre-fixme[3]: Return annotation cannot be `Any`.
def step_output(self) -> Any:
def step_output(self) -> Optional[TStepOutput]:
"""Output of the last step."""
return self._step_output

Expand All @@ -137,6 +135,9 @@ def iteration_timer(self) -> TimerProtocol:
return self._iteration_timer


TPhaseState = PhaseState[TData, TStepOutput]


class State:
"""Parent State class which can contain up to 3 instances of PhaseState, for the 3 phases.
Modified by the framework, read-only for the user.
Expand All @@ -147,9 +148,9 @@ def __init__(
*,
entry_point: EntryPoint,
timer: Optional[TimerProtocol] = None,
train_state: Optional[PhaseState] = None,
eval_state: Optional[PhaseState] = None,
predict_state: Optional[PhaseState] = None,
train_state: Optional[TPhaseState] = None,
eval_state: Optional[TPhaseState] = None,
predict_state: Optional[TPhaseState] = None,
) -> None:
self._entry_point = entry_point
self._timer = timer
Expand All @@ -175,17 +176,17 @@ def timer(self) -> Optional[TimerProtocol]:
return self._timer

@property
def train_state(self) -> Optional[PhaseState]:
def train_state(self) -> Optional[TPhaseState]:
"""A :class:`~torchtnt.framework.state.PhaseState` object which contains meta information about the train phase."""
return self._train_state

@property
def eval_state(self) -> Optional[PhaseState]:
def eval_state(self) -> Optional[TPhaseState]:
"""A :class:`~torchtnt.framework.state.PhaseState` object which contains meta information about the eval phase."""
return self._eval_state

@property
def predict_state(self) -> Optional[PhaseState]:
def predict_state(self) -> Optional[TPhaseState]:
"""A :class:`~torchtnt.framework.state.PhaseState` object which contains meta information about the predict phase."""
return self._predict_state

Expand Down

0 comments on commit 4991725

Please sign in to comment.