feat: torch module output shape calculator
LutingWang committed Jan 9, 2024
1 parent 8f6802d commit f333b60
Showing 2 changed files with 73 additions and 0 deletions.
28 changes: 28 additions & 0 deletions tests/test_utils.py/test_torch.py
@@ -0,0 +1,28 @@
import torch
from torch import nn

from todd.utils.torch import Shape


class TestShapePytest:

    def test_conv(self):
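        # With kernel_size=3 and padding=1 (stride 1) the spatial sizes are
        # preserved, as they are for a 1x1 kernel with no padding; only the
        # channel dimension changes to out_channels.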
        module = nn.Conv1d(3, 6, 3, padding=1)
        x = torch.randn(2, 3, 4)
        assert Shape.conv(module, x) == (2, 6, 4)

        module = nn.Conv2d(3, 6, 3, padding=1)
        x = torch.randn(2, 3, 4, 4)
        assert Shape.conv(module, x) == (2, 6, 4, 4)

        module = nn.Conv2d(1, 1, 3, padding=1)
        x = torch.randn(1, 1, 5, 5)
        assert Shape.conv(module, x) == (1, 1, 5, 5)

        module = nn.Conv2d(3, 2, 3, padding=1)
        x = torch.randn(1, 3, 6, 6)
        assert Shape.conv(module, x) == (1, 2, 6, 6)

        module = nn.Conv2d(1, 1, 1)
        x = torch.randn(1, 1, 10, 10)
        assert Shape.conv(module, x) == (1, 1, 10, 10)
45 changes: 45 additions & 0 deletions todd/utils/torch.py
@@ -4,14 +4,17 @@
    'get_world_size',
    'all_gather',
    'all_gather_',
    'Shape',
]

import functools
import itertools
import operator
import os

import torch
import torch.distributed as dist
from torch import nn


def get_rank(*args, **kwargs) -> int:
@@ -82,3 +85,45 @@ def all_gather_(
        for shape, numel, container in zip(shapes, numel_list, containers)
    ]
    return tensors


class Shape:

    @classmethod
    def module(cls, module: nn.Module, x: torch.Tensor) -> tuple[int, ...]:
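        # Dispatch on the module type; only convolution layers are handled so far.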
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            return cls.conv(module, x)
        raise TypeError(f"Unknown type {type(module)}")

    @staticmethod
    def _conv(
        x: int,
        padding: int,
        dilation: int,
        kernel_size: int,
        stride: int,
    ) -> int:
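        # Standard convolution output-size formula:
        # floor((x + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1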
        x += 2 * padding - dilation * (kernel_size - 1) - 1
        return x // stride + 1

    @classmethod
    def conv(
        cls,
        module: nn.Conv1d | nn.Conv2d | nn.Conv3d,
        x: torch.Tensor,
    ) -> tuple[int, ...]:
        b, c, *shape = x.shape

        assert c == module.in_channels
        c = module.out_channels

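        # Apply the 1-D formula to each spatial dimension, pairing it with the
        # layer's per-dimension padding, dilation, kernel size, and stride.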
        return b, c, *itertools.starmap(
            cls._conv,
            zip(
                shape,
                module.padding,
                module.dilation,
                module.kernel_size,
                module.stride,
            ),
        )

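For reference, a minimal usage sketch of the new helper. The names Shape and todd.utils.torch come from the diff above; the layer configuration and input sizes are illustrative assumptions, not part of the commit. Unlike the tests, which all use stride 1, this sketch also exercises a strided case and checks the prediction against an actual forward pass.

import torch
from torch import nn

from todd.utils.torch import Shape

# Hypothetical example: a strided convolution on a 32x32 input.
module = nn.Conv2d(3, 8, kernel_size=5, stride=2, padding=2)
x = torch.randn(4, 3, 32, 32)

predicted = Shape.module(module, x)  # (4, 8, 16, 16)
actual = tuple(module(x).shape)
assert predicted == actual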