From ce4e0a09b994f781d8114aa6bf59fd2b6aa315d8 Mon Sep 17 00:00:00 2001 From: Gal Rotem Date: Tue, 16 Jan 2024 11:55:18 -0800 Subject: [PATCH] fix pyre issue in timer Summary: Fix pyre issue and double logger definition Differential Revision: D52810354 --- torchtnt/utils/timer.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/torchtnt/utils/timer.py b/torchtnt/utils/timer.py index 6606bda196..f84e994cb4 100644 --- a/torchtnt/utils/timer.py +++ b/torchtnt/utils/timer.py @@ -21,7 +21,6 @@ runtime_checkable, Sequence, Tuple, - TypeVar, ) import numpy as np @@ -29,14 +28,11 @@ import torch import torch.distributed as dist from tabulate import tabulate +from torch.distributed.distributed_c10d import Work from torchtnt.utils.distributed import PGWrapper logger: logging.Logger = logging.getLogger(__name__) -AsyncOperator = TypeVar("AsyncOperator") - -logger: logging.Logger = logging.getLogger(__name__) - _TABLE_ROW = Tuple[str, float, int, float, float] _TABLE_DATA = List[_TABLE_ROW] @@ -451,15 +447,13 @@ def __init__(self, interval: datetime.timedelta, cpu_pg: dist.ProcessGroup) -> N self._cpu_pg = cpu_pg self._prev_time: float = perf_counter() self._timeout_tensor: torch.Tensor = torch.zeros(1, dtype=torch.int) - # pyre-fixme[34]: `Variable[AsyncOperator]` isn't present in the function's parameters. - self._prev_work: Optional[AsyncOperator] = None + self._prev_work: Optional[Work] = None def check(self) -> bool: ret = False curr_time = perf_counter() if self._prev_work is not None: - # pyre-fixme[16]: `Variable[AsyncOperator]` has no attribute wait. self._prev_work.wait() ret = self._timeout_tensor[0].item() == 1 if ret: