Skip to content

Commit

Permalink
feat(models): vit uses patched sequential
Browse files Browse the repository at this point in the history
  • Loading branch information
LutingWang committed Aug 29, 2024
1 parent db4d971 commit 9affc5c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
3 changes: 2 additions & 1 deletion todd/models/modules/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch import nn

from ...bases.configs import Config
from ...patches.torch import Sequential
from ...registries import InitWeightsMixin
from ..utils import interpolate_position_embedding

Expand Down Expand Up @@ -88,7 +89,7 @@ def __init__(
self._position_embedding = nn.Parameter(
torch.empty(self.num_patches + 1, width),
)
self._blocks = nn.Sequential(
self._blocks = Sequential(
*[
self.BLOCK_TYPE(width=width, num_heads=num_heads)
for _ in range(depth)
Expand Down
10 changes: 9 additions & 1 deletion todd/patches/torch/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,17 @@ def forward(self, *args, **kwargs) -> dict[str, nn.Module]:

class Sequential(nn.Sequential):
    """``nn.Sequential`` variant that threads ``**kwargs`` to every module.

    Unlike ``nn.Sequential``, each contained module is called with the
    keyword arguments given to ``forward``.

    Args:
        unpack_args: when ``True``, ``forward`` accepts multiple positional
            arguments and each module is expected to return a tuple, which
            is unpacked into the next module's positional arguments.  When
            ``False`` (the default), ``forward`` takes exactly one
            positional argument, passed from module to module as-is.
    """

    def __init__(self, *args, unpack_args: bool = False, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self._unpack_args = unpack_args

    def forward(self, *args, **kwargs) -> tuple[Any, ...]:
        # Non-unpack mode: exactly one positional argument is expected;
        # extract it so modules receive the value directly.
        if not self._unpack_args:
            args, = args
        # Fix: the rendered diff duplicated the old `args = m(*args,
        # **kwargs)` line next to its replacement, which would have applied
        # every module twice per iteration; only the replacement remains.
        for m in self:
            # NOTE(review): in unpack mode every module must return a tuple
            # for the next iteration's `m(*args, ...)` to unpack — assumed
            # contract of the contained modules; confirm against callers.
            args = (
                m(*args, **kwargs) if self._unpack_args else m(args, **kwargs)
            )
        return args


Expand Down

0 comments on commit 9affc5c

Please sign in to comment.