Adding some missing annotations in torchtnt/utils/flops.py (#742)
Summary:
Pull Request resolved: #742

Fixed some of the annotation issues (see the sketch after this summary):
1. Renamed unused inputs to _xxx
2. Used Union and cast for variable types
3. Added a generic TypeVar for _normalize_tuple
4. Fixed a test file that did not need annotations

Reviewed By: galrotem

Differential Revision: D54950533

fbshipit-source-id: 550235bf8543e9aea58e00fc3d871079ffa9d39d
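
For reference, a minimal, standalone sketch of the three typing patterns named in the summary: underscore-prefixed unused parameters, Union plus cast for heterogeneous inputs, and a TypeVar-based generic. The names count_something and first are hypothetical and are not part of torchtnt.

from typing import Any, Tuple, TypeVar, Union, cast

T = TypeVar("T")


def count_something(inputs: Tuple[Union[int, bool], ...], _outputs: Tuple[Any, ...]) -> int:
    # _outputs is accepted for a uniform signature but never read, so it gets a
    # leading underscore instead of a pyre-fixme suppression.
    size = cast(int, inputs[0])      # narrow Union[int, bool] down to int
    doubled = cast(bool, inputs[1])  # narrow Union[int, bool] down to bool
    return size * (2 if doubled else 1)


def first(items: Tuple[T, ...]) -> T:
    # The TypeVar makes the function generic: the return type tracks the
    # element type of whatever tuple is passed in.
    return items[0]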
andywag authored and facebook-github-bot committed Mar 18, 2024
1 parent 4991725 commit 0bcbd55
Showing 1 changed file with 22 additions and 22 deletions.
torchtnt/utils/flops.py
@@ -14,17 +14,18 @@
from collections import defaultdict
from functools import reduce
from numbers import Number
-from typing import Any, Callable, DefaultDict, Dict, List, Tuple
+from typing import Any, Callable, cast, DefaultDict, Dict, List, Tuple, TypeVar, Union

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import PyTree, tree_map

aten: torch._ops._OpNamespace = torch.ops.aten
+T = TypeVar("T")
+InputType = Union[torch.Tensor, bool, tuple[bool]]


-# pyre-fixme [2] we don't care the type in outputs
-def _matmul_flop_jit(inputs: Tuple[torch.Tensor], outputs: Tuple[Any]) -> Number:
+def _matmul_flop_jit(inputs: Tuple[torch.Tensor], _outputs: Tuple[Any]) -> Number:
"""
Count flops for matmul.
"""
@@ -37,8 +38,7 @@ def _matmul_flop_jit(inputs: Tuple[torch.Tensor], outputs: Tuple[Any]) -> Number
return flop


-# pyre-fixme [2] we don't care the type in outputs
-def _addmm_flop_jit(inputs: Tuple[torch.Tensor], outputs: Tuple[Any]) -> Number:
+def _addmm_flop_jit(inputs: Tuple[torch.Tensor], _outputs: Tuple[Any]) -> Number:
"""
Count flops for fully connected layers.
"""
@@ -55,8 +55,7 @@ def _addmm_flop_jit(inputs: Tuple[torch.Tensor], outputs: Tuple[Any]) -> Number:
return flops


-# pyre-fixme [2] we don't care the type in outputs
-def _bmm_flop_jit(inputs: Tuple[torch.Tensor], outputs: Tuple[Any]) -> Number:
+def _bmm_flop_jit(inputs: Tuple[torch.Tensor], _outputs: Tuple[Any]) -> Number:
"""
Count flops for the bmm operation.
"""
@@ -71,9 +70,9 @@ def _bmm_flop_jit(inputs: Tuple[torch.Tensor], outputs: Tuple[Any]) -> Number:


def _conv_flop_count(
-x_shape: List[int],
-w_shape: List[int],
-out_shape: List[int],
+x_shape: Union[torch.Size, List[int]],
+w_shape: Union[torch.Size, List[int]],
+out_shape: Union[torch.Size, List[int]],
transposed: bool = False,
) -> Number:
"""
@@ -100,16 +99,16 @@


def _conv_flop_jit(
-inputs: Tuple[Any], # pyre-fixme [2] the inputs can be union of Tensor/bool/Tuple
+inputs: List[Union[torch.Tensor, bool, Tuple[bool]]],
outputs: Tuple[torch.Tensor],
) -> Number:
"""
Count flops for convolution.
"""
-x: torch.Tensor = inputs[0]
-w: torch.Tensor = inputs[1]
+x: torch.Tensor = cast(torch.Tensor, inputs[0])
+w: torch.Tensor = cast(torch.Tensor, inputs[1])
x_shape, w_shape, out_shape = (x.shape, w.shape, outputs[0].shape)
-transposed: bool = inputs[6]
+transposed: bool = cast(bool, inputs[6])

return _conv_flop_count(
list(x_shape), list(w_shape), list(out_shape), transposed=transposed
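
_conv_flop_count's body is collapsed in this view; fvcore-style counters typically charge one multiply-accumulate per weight element, per output spatial position (input spatial position for transposed convolutions), per batch item. A hedged sketch of that convention, not necessarily the exact torchtnt body:

from functools import reduce
from operator import mul
from typing import List


def conv_flops(x_shape: List[int], w_shape: List[int], out_shape: List[int], transposed: bool = False) -> int:
    # One MAC per weight element, per (input if transposed else output)
    # spatial position, per batch item.
    batch_size = x_shape[0]
    spatial = (x_shape if transposed else out_shape)[2:]
    return batch_size * reduce(mul, w_shape, 1) * reduce(mul, spatial, 1)


# 3x3 conv, 16 -> 32 channels, 8x8 output, batch of 1:
assert conv_flops([1, 16, 8, 8], [32, 16, 3, 3], [1, 32, 8, 8]) == 32 * 16 * 3 * 3 * 8 * 8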
@@ -120,26 +119,28 @@ def _transpose_shape(shape: torch.Size) -> List[int]:
return [shape[1], shape[0]] + list(shape[2:])


-# pyre-fixme [2] the inputs can be union of Tensor/bool/Tuple & we don't care about outputs
-def _conv_backward_flop_jit(inputs: Tuple[Any], outputs: Tuple[Any]) -> Number:
-grad_out_shape, x_shape, w_shape = [i.shape for i in inputs[:3]]
+def _conv_backward_flop_jit(
+inputs: Tuple[Union[torch.Tensor, bool, Tuple[bool]]], outputs: Tuple[torch.Tensor]
+) -> Number:
+
+grad_out_shape, x_shape, w_shape = [cast(torch.Tensor, i).shape for i in inputs[:3]]
output_mask = inputs[-1]
fwd_transposed = inputs[7]
flop_count: Number = 0

-if output_mask[0]:
+if cast(Tuple[bool], output_mask)[0]:
grad_input_shape = outputs[0].shape
# pyre-fixme [58] this is actually sum of Number and Number
flop_count = flop_count + _conv_flop_count(
grad_out_shape, w_shape, grad_input_shape, not fwd_transposed
)
-if output_mask[1]:
+if cast(Tuple[bool], output_mask)[1]:
grad_weight_shape = outputs[1].shape
flop_count += _conv_flop_count(
list(_transpose_shape(x_shape)),
list(grad_out_shape),
list(grad_weight_shape),
-fwd_transposed,
+cast(bool, fwd_transposed),
)

return flop_count
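
For context on the mask handling above: ATen's convolution_backward takes an output_mask of three booleans saying whether the input, weight, and bias gradients are actually needed, so the counter only adds conv flops for the branches whose entry is set; the cast calls just tell the type checker what sits at those positions of the heterogeneous inputs tuple. A standalone illustration of that narrowing pattern (not torchtnt code):

from typing import Tuple, Union, cast

inputs: Tuple[Union[int, bool, Tuple[bool, bool, bool]], ...] = (7, True, (True, False, True))
output_mask = cast(Tuple[bool, bool, bool], inputs[-1])  # narrow the element at index -1 to a bool triple
need_grad_input, need_grad_weight, need_grad_bias = output_mask
assert (need_grad_input, need_grad_weight, need_grad_bias) == (True, False, True)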
@@ -165,8 +166,7 @@ def _conv_backward_flop_jit(inputs: Tuple[Any], outputs: Tuple[Any]) -> Number:
}


-# pyre-fixme [2, 3] it can be Tuple of anything.
-def _normalize_tuple(x: Any) -> Tuple[Any]:
+def _normalize_tuple(x: T) -> Tuple[T]:
if not isinstance(x, tuple):
return (x,)
return x

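Beyond the private helpers touched here, torchtnt/utils/flops.py exposes a TorchDispatchMode-based counter; a hedged usage sketch follows. The class name FlopTensorDispatchMode and its flop_counts/reset() surface match current torchtnt, but double-check them against your installed version.

import copy

import torch
from torchtnt.utils.flops import FlopTensorDispatchMode

model = torch.nn.Linear(16, 4)
inputs = torch.randn(8, 16)

with FlopTensorDispatchMode(model) as ftdm:
    out = model(inputs).mean()                       # forward ops are intercepted by the dispatch mode
    flops_forward = copy.deepcopy(ftdm.flop_counts)  # per-module, per-op counts (e.g. addmm)
    ftdm.reset()
    out.backward()                                   # backward flops are tallied separately
    flops_backward = copy.deepcopy(ftdm.flop_counts)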