replace all pyre-ignore with pyre-fixme #689

Closed
wants to merge 1 commit into from
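For context (not part of the diff): both directives suppress the reported Pyre error on the line that follows them, so the rename does not change type-checking behavior; by convention, `# pyre-fixme` marks the suppression as technical debt to clean up later, while `# pyre-ignore` signals an intentional, permanent suppression. A minimal, hypothetical snippet (not taken from this PR) illustrating the two side by side:

```python
from typing import List


def total(xs: List[int]) -> int:
    # Runs fine on floats at runtime, but Pyre flags the callers below with
    # error [6] (incompatible parameter type) because they pass List[float].
    return int(sum(xs))


# pyre-ignore[6]: suppressed and treated as intentional / a known false positive.
print(total([1.5, 2.5]))

# pyre-fixme[6]: suppressed for now, but recorded as debt to fix later.
print(total([3.5, 4.5]))
```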
2 changes: 1 addition & 1 deletion tests/framework/callbacks/test_base_checkpointer.py
@@ -664,7 +664,7 @@ def test_best_checkpoint_no_top_k(self) -> None:
save_every_n_epochs=1,
best_checkpoint_config=BestCheckpointConfig(
monitored_metric="train_loss",
- # pyre-ignore: Incompatible parameter type [6]
+ # pyre-fixme: Incompatible parameter type [6]
mode=mode,
),
)
2 changes: 1 addition & 1 deletion tests/utils/test_swa.py
@@ -191,7 +191,7 @@ def test_input_checks(self) -> None:
with self.assertRaisesRegex(
ValueError, "Unknown averaging method: foo. Only ema and swa are supported."
):
- # pyre-ignore On purpose to test run time exception
+ # pyre-fixme On purpose to test run time exception
AveragedModel(model, averaging_method="foo")

def test_lit_ema(self) -> None:
2 changes: 1 addition & 1 deletion torchtnt/framework/callbacks/base_csv_writer.py
@@ -59,7 +59,7 @@ def get_step_output_rows(
self,
state: State,
unit: TPredictUnit,
- # pyre-ignore: Missing parameter annotation [2]
+ # pyre-fixme: Missing parameter annotation [2]
step_output: Any,
) -> Union[List[str], List[List[str]]]:
...
2 changes: 1 addition & 1 deletion torchtnt/framework/callbacks/module_summary.py
@@ -43,7 +43,7 @@ def __init__(
process_fn: Callable[
[List[ModuleSummaryObj]], None
] = _log_module_summary_tables,
- # pyre-ignore
+ # pyre-fixme
module_inputs: Optional[
MutableMapping[str, Tuple[Tuple[Any, ...], Dict[str, Any]]]
] = None,
4 changes: 2 additions & 2 deletions torchtnt/framework/train.py
@@ -223,7 +223,7 @@ def _train_epoch_impl(
):
_evaluate_impl(
state,
- # pyre-ignore: Incompatible parameter type [6]
+ # pyre-fixme: Incompatible parameter type [6]
train_unit,
callback_handler,
)
@@ -257,7 +257,7 @@ def _train_epoch_impl(
):
_evaluate_impl(
state,
- # pyre-ignore: Incompatible parameter type [6]
+ # pyre-fixme: Incompatible parameter type [6]
train_unit,
callback_handler,
)
2 changes: 1 addition & 1 deletion torchtnt/utils/data/iterators.py
@@ -343,7 +343,7 @@ def __init__(
name: torch.IntTensor([idx])
for idx, name in enumerate(self._iterator_names)
}
- # pyre-ignore[4]: missing attribute annotation
+ # pyre-fixme[4]: missing attribute annotation
self._process_group = dist.new_group(backend="gloo", ranks=None)

self._iterators_finished: List[str] = []
2 changes: 1 addition & 1 deletion torchtnt/utils/data/profile_dataloader.py
@@ -52,7 +52,7 @@ def profile_dataloader(
with timer.time("copy_data_to_device"), record_function(
"copy_data_to_device"
):
- # pyre-ignore [6]: device is checked as not None before calling this
+ # pyre-fixme [6]: device is checked as not None before calling this
data = copy_data_to_device(data, device)

steps_completed += 1
4 changes: 2 additions & 2 deletions torchtnt/utils/device.py
@@ -313,14 +313,14 @@ def collect_system_stats(device: torch.device) -> Dict[str, Any]:
system_stats: Dict[str, Any] = {}
cpu_stats = get_psutil_cpu_stats()

- # pyre-ignore
+ # pyre-fixme
system_stats.update(cpu_stats)

if torch.cuda.is_available():
try:
gpu_stats = get_nvidia_smi_gpu_stats(device)

- # pyre-ignore
+ # pyre-fixme
system_stats.update(gpu_stats)
system_stats.update(torch.cuda.memory_stats())
except FileNotFoundError:
40 changes: 20 additions & 20 deletions torchtnt/utils/flops.py
@@ -21,7 +21,7 @@
aten: torch._ops._OpNamespace = torch.ops.aten


- # pyre-ignore [2] we don't care the type in outputs
+ # pyre-fixme [2] we don't care the type in outputs
def _matmul_flop_jit(inputs: Tuple[torch.Tensor], outputs: Tuple[Any]) -> Number:
"""
Count flops for matmul.
@@ -35,7 +35,7 @@ def _matmul_flop_jit(inputs: Tuple[torch.Tensor], outputs: Tuple[Any]) -> Number
return flop


- # pyre-ignore [2] we don't care the type in outputs
+ # pyre-fixme [2] we don't care the type in outputs
def _addmm_flop_jit(inputs: Tuple[torch.Tensor], outputs: Tuple[Any]) -> Number:
"""
Count flops for fully connected layers.
@@ -53,7 +53,7 @@ def _addmm_flop_jit(inputs: Tuple[torch.Tensor], outputs: Tuple[Any]) -> Number:
return flops


- # pyre-ignore [2] we don't care the type in outputs
+ # pyre-fixme [2] we don't care the type in outputs
def _bmm_flop_jit(inputs: Tuple[torch.Tensor], outputs: Tuple[Any]) -> Number:
"""
Count flops for the bmm operation.
@@ -98,7 +98,7 @@ def _conv_flop_count(


def _conv_flop_jit(
- inputs: Tuple[Any], # pyre-ignore [2] the inputs can be union of Tensor/bool/Tuple
+ inputs: Tuple[Any], # pyre-fixme [2] the inputs can be union of Tensor/bool/Tuple
outputs: Tuple[torch.Tensor],
) -> Number:
"""
@@ -118,7 +118,7 @@ def _transpose_shape(shape: torch.Size) -> List[int]:
return [shape[1], shape[0]] + list(shape[2:])


- # pyre-ignore [2] the inputs can be union of Tensor/bool/Tuple & we don't care about outputs
+ # pyre-fixme [2] the inputs can be union of Tensor/bool/Tuple & we don't care about outputs
def _conv_backward_flop_jit(inputs: Tuple[Any], outputs: Tuple[Any]) -> Number:
grad_out_shape, x_shape, w_shape = [i.shape for i in inputs[:3]]
output_mask = inputs[-1]
@@ -127,7 +127,7 @@ def _conv_backward_flop_jit(inputs: Tuple[Any], outputs: Tuple[Any]) -> Number:

if output_mask[0]:
grad_input_shape = outputs[0].shape
- # pyre-ignore [58] this is actually sum of Number and Number
+ # pyre-fixme [58] this is actually sum of Number and Number
flop_count = flop_count + _conv_flop_count(
grad_out_shape, w_shape, grad_input_shape, not fwd_transposed
)
@@ -143,7 +143,7 @@ def _conv_backward_flop_jit(inputs: Tuple[Any], outputs: Tuple[Any]) -> Number:
return flop_count


- # pyre-ignore [5]
+ # pyre-fixme [5]
flop_mapping: Dict[Callable[..., Any], Callable[[Tuple[Any], Tuple[Any]], Number]] = {
aten.mm: _matmul_flop_jit,
aten.matmul: _matmul_flop_jit,
@@ -163,7 +163,7 @@ def _conv_backward_flop_jit(inputs: Tuple[Any], outputs: Tuple[Any]) -> Number:
}


- # pyre-ignore [2, 3] it can be Tuple of anything.
+ # pyre-fixme [2, 3] it can be Tuple of anything.
def _normalize_tuple(x: Any) -> Tuple[Any]:
if not isinstance(x, tuple):
return (x,)
@@ -213,33 +213,33 @@ def __init__(self, module: torch.nn.Module) -> None:
)
self._parents: List[str] = [""]

- # pyre-ignore
+ # pyre-fixme
def __exit__(self, exc_type, exc_val, exc_tb):
for hook_handle in self._all_hooks:
hook_handle.remove()
super().__exit__(exc_type, exc_val, exc_tb)

def __torch_dispatch__(
self,
- func: Callable[..., Any], # pyre-ignore [2] func can be any func
- types: Tuple[Any], # pyre-ignore [2]
- args=(), # pyre-ignore [2]
- kwargs=None, # pyre-ignore [2]
+ func: Callable[..., Any], # pyre-fixme [2] func can be any func
+ types: Tuple[Any], # pyre-fixme [2]
+ args=(), # pyre-fixme [2]
+ kwargs=None, # pyre-fixme [2]
) -> PyTree:
rs = func(*args, **kwargs)
outs = _normalize_tuple(rs)

if func in flop_mapping:
flop_count = flop_mapping[func](args, outs)
for par in self._parents:
- # pyre-ignore [58]
+ # pyre-fixme [58]
self.flop_counts[par][func.__name__] += flop_count
else:
logging.debug(f"{func} is not yet supported in FLOPs calculation.")

return rs

- # pyre-ignore [3]
+ # pyre-fixme [3]
def _create_backwards_push(self, name: str) -> Callable[..., Any]:
class PushState(torch.autograd.Function):
@staticmethod
@@ -262,7 +262,7 @@ def backward(ctx, *grad_outs):
# using a function parameter.
return PushState.apply

- # pyre-ignore [3]
+ # pyre-fixme [3]
def _create_backwards_pop(self, name: str) -> Callable[..., Any]:
class PopState(torch.autograd.Function):
@staticmethod
@@ -286,9 +286,9 @@ def backward(ctx, *grad_outs):
# using a function parameter.
return PopState.apply

- # pyre-ignore [3] Return a callable function
+ # pyre-fixme [3] Return a callable function
def _enter_module(self, name: str) -> Callable[..., Any]:
- # pyre-ignore [2, 3]
+ # pyre-fixme [2, 3]
def f(module: torch.nn.Module, inputs: Tuple[Any]):
parents = self._parents
parents.append(name)
@@ -298,9 +298,9 @@ def f(module: torch.nn.Module, inputs: Tuple[Any]):

return f

- # pyre-ignore [3] Return a callable function
+ # pyre-fixme [3] Return a callable function
def _exit_module(self, name: str) -> Callable[..., Any]:
- # pyre-ignore [2, 3]
+ # pyre-fixme [2, 3]
def f(module: torch.nn.Module, inputs: Tuple[Any], outputs: Tuple[Any]):
parents = self._parents
assert parents[-1] == name
4 changes: 2 additions & 2 deletions torchtnt/utils/memory.py
@@ -20,14 +20,14 @@


def _is_named_tuple(
- # pyre-ignore: Missing parameter annotation [2]: Parameter `x` must have a type other than `Any`.
+ # pyre-fixme: Missing parameter annotation [2]: Parameter `x` must have a type other than `Any`.
x: Any,
) -> bool:
return isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields")


def get_tensor_size_bytes_map(
- # pyre-ignore: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.
+ # pyre-fixme: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.
obj: Any,
) -> Dict[torch.Tensor, int]:
tensor_map = {}
24 changes: 12 additions & 12 deletions torchtnt/utils/module_summary.py
@@ -209,9 +209,9 @@ def _clean_flops(flop: DefaultDict[str, DefaultDict[str, int]], N: int) -> None:

def _get_module_flops_and_activation_sizes(
module: torch.nn.Module,
- # pyre-ignore
+ # pyre-fixme
module_args: Optional[Tuple[Any, ...]] = None,
- # pyre-ignore
+ # pyre-fixme
module_kwargs: Optional[MutableMapping[str, Any]] = None,
) -> _ModuleSummaryData:
# a mapping from module name to activation size tuple (in_size, out_size)
@@ -309,9 +309,9 @@ def _has_tensor(item: Optional[PyTree]) -> bool:

def get_module_summary(
module: torch.nn.Module,
- # pyre-ignore
+ # pyre-fixme
module_args: Optional[Tuple[Any, ...]] = None,
- # pyre-ignore
+ # pyre-fixme
module_kwargs: Optional[MutableMapping[str, Any]] = None,
) -> ModuleSummary:
"""
@@ -669,13 +669,13 @@ def _activation_size_hook(
activation_sizes: Dict[
str, Tuple[Union[TUnknown, List[int]], Union[TUnknown, List[int]]]
],
- # pyre-ignore: Invalid type parameters [24]
+ # pyre-fixme: Invalid type parameters [24]
) -> Callable[[str], Callable]:
- # pyre-ignore: Missing parameter annotation [2]
+ # pyre-fixme: Missing parameter annotation [2]
def intermediate_hook(
module_name: str,
) -> Callable[[torch.nn.Module, Any, Any], None]:
- # pyre-ignore
+ # pyre-fixme
def hook(_: torch.nn.Module, inp: Any, out: Any) -> None:
if len(inp) == 1:
inp = inp[0]
@@ -690,9 +690,9 @@ def hook(_: torch.nn.Module, inp: Any, out: Any) -> None:

def _forward_time_pre_hook(
timer_mapping: Dict[str, float]
- # pyre-ignore: Invalid type parameters [24]
+ # pyre-fixme: Invalid type parameters [24]
) -> Callable[[str], Callable]:
- # pyre-ignore: Missing parameter annotation [2]
+ # pyre-fixme: Missing parameter annotation [2]
def intermediate_hook(
module_name: str,
) -> Callable[[torch.nn.Module, Any], None]:
@@ -707,9 +707,9 @@ def hook(_module: torch.nn.Module, _inp: Any) -> None:
def _forward_time_hook(
timer_mapping: Dict[str, float],
elapsed_times: Dict[str, float],
- # pyre-ignore: Invalid type parameters [24]
+ # pyre-fixme: Invalid type parameters [24]
) -> Callable[[str], Callable]:
- # pyre-ignore: Missing parameter annotation [2]
+ # pyre-fixme: Missing parameter annotation [2]
def intermediate_hook(
module_name: str,
) -> Callable[[torch.nn.Module, Any, Any], None]:
@@ -725,7 +725,7 @@ def hook(_module: torch.nn.Module, _inp: Any, _out: Any) -> None:

def _register_hooks(
module: torch.nn.Module,
- # pyre-ignore: Invalid type parameters [24]
+ # pyre-fixme: Invalid type parameters [24]
hooks: List[Tuple[Callable, _HookType]],
) -> List[RemovableHandle]:
"""
2 changes: 1 addition & 1 deletion torchtnt/utils/prepare_module.py
@@ -112,7 +112,7 @@ class TorchCompileParams:

fullgraph: bool = False
dynamic: bool = False
- # pyre-ignore: Invalid type parameters [24]
+ # pyre-fixme: Invalid type parameters [24]
backend: Union[str, Callable] = "inductor"
mode: Union[str, None] = None
options: Optional[Dict[str, Union[str, int, bool]]] = None
10 changes: 5 additions & 5 deletions torchtnt/utils/swa.py
@@ -9,7 +9,7 @@
import torch

# TODO: torch/optim/swa_utils.pyi needs to be updated
- # pyre-ignore Undefined import [21]: Could not find a name `get_ema_multi_avg_fn` defined in module `torch.optim.swa_utils`.
+ # pyre-fixme Undefined import [21]: Could not find a name `get_ema_multi_avg_fn` defined in module `torch.optim.swa_utils`.
from torch.optim.swa_utils import (
AveragedModel as PyTorchAveragedModel,
get_ema_multi_avg_fn,
@@ -55,11 +55,11 @@ def __init__(
raise ValueError(f"Decay must be between 0 and 1, got {ema_decay}")

# TODO: torch/optim/swa_utils.pyi needs to be updated
- # pyre-ignore Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_ema_multi_avg_fn`.
+ # pyre-fixme Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_ema_multi_avg_fn`.
multi_avg_fn = get_ema_multi_avg_fn(ema_decay)
elif averaging_method == "swa":
# TODO: torch/optim/swa_utils.pyi needs to be updated
- # pyre-ignore Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_swa_multi_avg_fn`.
+ # pyre-fixme Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_swa_multi_avg_fn`.
multi_avg_fn = get_swa_multi_avg_fn()

if use_lit:
@@ -88,7 +88,7 @@ def __init__(
# use default init implementation

# TODO: torch/optim/swa_utils.pyi needs to be updated
- # pyre-ignore Unexpected keyword [28]
+ # pyre-fixme Unexpected keyword [28]
super().__init__(
model,
device=device,
@@ -104,6 +104,6 @@ def update_parameters(self, model: torch.nn.Module) -> None:
)

# TODO: torch/optim/swa_utils.pyi needs to be updated
- # pyre-ignore Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_ema_multi_avg_fn`.
+ # pyre-fixme Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_ema_multi_avg_fn`.
self.multi_avg_fn = get_ema_multi_avg_fn(decay)
super().update_parameters(model)
2 changes: 1 addition & 1 deletion torchtnt/utils/timer.py
@@ -364,7 +364,7 @@ def _sync_durations(
pg_wrapper.all_gather_object(outputs, recorded_durations)
ret = defaultdict(list)
for output in outputs:
- # pyre-ignore [16]: `Optional` has no attribute `__getitem__`.
+ # pyre-fixme [16]: `Optional` has no attribute `__getitem__`.
for k, v in output.items():
if k not in ret:
ret[k] = []
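Not part of this PR, but as a sketch of how a sweep like this could be automated (assuming the change is purely textual, as in the hunks above), a small hypothetical helper might look like:

```python
from pathlib import Path


def replace_pyre_ignores(root: str = ".") -> int:
    """Rewrite '# pyre-ignore' to '# pyre-fixme' in every .py file under root.

    Returns the number of files modified. Hypothetical helper, not code
    from the torchtnt repository.
    """
    changed = 0
    for path in Path(root).rglob("*.py"):
        text = path.read_text(encoding="utf-8")
        new_text = text.replace("# pyre-ignore", "# pyre-fixme")
        if new_text != text:
            path.write_text(new_text, encoding="utf-8")
            changed += 1
    return changed


if __name__ == "__main__":
    print(f"Updated {replace_pyre_ignores()} files")
```

One caveat with a plain textual replacement: it would also rewrite file-level `# pyre-ignore-all-errors` headers, so a real sweep would want a stricter pattern or a careful review of the resulting diff (this PR touches only per-line suppressions).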