
Commit

Merge branch 'pythonicworkflow' of https://github.com/KumoLiu/MONAI into pythonicworkflow
KumoLiu committed Nov 27, 2024
2 parents 42d5d0b + d40ec95 · commit 1f136f9
Showing 4 changed files with 64 additions and 18 deletions.
22 changes: 18 additions & 4 deletions monai/networks/blocks/selfattention.py
@@ -11,7 +11,7 @@
 
 from __future__ import annotations
 
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -154,10 +154,12 @@ def __init__(
         )
         self.input_size = input_size
 
-    def forward(self, x):
+    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         """
         Args:
             x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
+            attn_mask (torch.Tensor, optional): mask to apply to the attention matrix.
+                B x (s_dim_1 * ... * s_dim_n). Defaults to None.
 
         Return:
             torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
@@ -176,7 +178,13 @@ def forward(self, x):
 
         if self.use_flash_attention:
             x = F.scaled_dot_product_attention(
-                query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
+                query=q,
+                key=k,
+                value=v,
+                attn_mask=attn_mask,
+                scale=self.scale,
+                dropout_p=self.dropout_rate,
+                is_causal=self.causal,
             )
         else:
             att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
@@ -186,10 +194,16 @@ def forward(self, x):
                 att_mat = self.rel_positional_embedding(x, att_mat, q)
 
             if self.causal:
+                if attn_mask is not None:
+                    raise ValueError("Causal attention does not support attention masks.")
                 att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))
 
-            att_mat = att_mat.softmax(dim=-1)
+            if attn_mask is not None:
+                attn_mask = attn_mask.unsqueeze(1).unsqueeze(2)
+                attn_mask = attn_mask.expand(-1, self.num_heads, -1, -1)
+                att_mat = att_mat.masked_fill(attn_mask == 0, float("-inf"))
 
+            att_mat = att_mat.softmax(dim=-1)
             if self.save_attn:
                 # no gradients and new tensor;
                 # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
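For orientation, a minimal usage sketch of the new attn_mask argument (the sizes and the mask below are illustrative assumptions, not part of the commit): the mask is a boolean tensor of shape B x sequence_length, masked positions are excluded as attention targets (their columns are filled with -inf before the softmax), and internally the mask is broadcast to B x num_heads x 1 x sequence_length.

import torch
from monai.networks.blocks.selfattention import SABlock

batch, seq_len, hidden = 2, 16, 128
block = SABlock(hidden_size=hidden, num_heads=4)
x = torch.randn(batch, seq_len, hidden)

# boolean key mask: only the first half of the sequence is attended to
attn_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
attn_mask[:, : seq_len // 2] = True

out = block(x, attn_mask=attn_mask)  # output keeps the input shape: B x seq_len x hidden

Note that combining attn_mask with causal=True raises a ValueError, as exercised by the new test added further down.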
6 changes: 4 additions & 2 deletions monai/networks/blocks/transformerblock.py
@@ -90,8 +90,10 @@ def __init__(
             use_flash_attention=use_flash_attention,
         )
 
-    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
-        x = x + self.attn(self.norm1(x))
+    def forward(
+        self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        x = x + self.attn(self.norm1(x), attn_mask=attn_mask)
         if self.with_cross_attention:
             x = x + self.cross_attn(self.norm_cross_attn(x), context=context)
         x = x + self.mlp(self.norm2(x))
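A short sketch of the corresponding pass-through at the transformer-block level (the argument values are assumptions for illustration); the mask is simply forwarded to the block's self-attention:

import torch
from monai.networks.blocks.transformerblock import TransformerBlock

blk = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, dropout_rate=0.0)
x = torch.randn(2, 16, 128)
attn_mask = torch.ones(2, 16, dtype=torch.bool)  # all-True mask: nothing is masked out
y = blk(x, attn_mask=attn_mask)                  # B x seq_len x hidden_size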
36 changes: 24 additions & 12 deletions monai/networks/nets/swin_unetr.py
@@ -13,7 +13,6 @@
 
 import itertools
 from collections.abc import Sequence
-from typing import Final
 
 import numpy as np
 import torch
@@ -51,8 +50,6 @@ class SwinUNETR(nn.Module):
     <https://arxiv.org/abs/2201.01266>"
     """
 
-    patch_size: Final[int] = 2
-
     @deprecated_arg(
         name="img_size",
         since="1.3",
@@ -65,18 +62,24 @@ def __init__(
         img_size: Sequence[int] | int,
         in_channels: int,
         out_channels: int,
+        patch_size: int = 2,
         depths: Sequence[int] = (2, 2, 2, 2),
         num_heads: Sequence[int] = (3, 6, 12, 24),
+        window_size: Sequence[int] | int = 7,
+        qkv_bias: bool = True,
+        mlp_ratio: float = 4.0,
         feature_size: int = 24,
         norm_name: tuple | str = "instance",
         drop_rate: float = 0.0,
         attn_drop_rate: float = 0.0,
         dropout_path_rate: float = 0.0,
         normalize: bool = True,
+        norm_layer: type[LayerNorm] = nn.LayerNorm,
+        patch_norm: bool = True,
         use_checkpoint: bool = False,
         spatial_dims: int = 3,
-        downsample="merging",
-        use_v2=False,
+        downsample: str | nn.Module = "merging",
+        use_v2: bool = False,
     ) -> None:
         """
         Args:
@@ -86,14 +89,20 @@
             It will be removed in an upcoming version.
             in_channels: dimension of input channels.
             out_channels: dimension of output channels.
+            patch_size: size of the patch token.
             feature_size: dimension of network feature size.
             depths: number of layers in each stage.
             num_heads: number of attention heads.
+            window_size: local window size.
+            qkv_bias: add a learnable bias to query, key, value.
+            mlp_ratio: ratio of mlp hidden dim to embedding dim.
             norm_name: feature normalization type and arguments.
             drop_rate: dropout rate.
             attn_drop_rate: attention dropout rate.
             dropout_path_rate: drop path rate.
             normalize: normalize output intermediate features in each stage.
+            norm_layer: normalization layer.
+            patch_norm: whether to apply normalization to the patch embedding.
             use_checkpoint: use gradient checkpointing for reduced memory usage.
             spatial_dims: number of spatial dims.
             downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
@@ -116,13 +125,15 @@
 
         super().__init__()
 
-        img_size = ensure_tuple_rep(img_size, spatial_dims)
-        patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
-        window_size = ensure_tuple_rep(7, spatial_dims)
-
         if spatial_dims not in (2, 3):
             raise ValueError("spatial dimension should be 2 or 3.")
 
+        self.patch_size = patch_size
+
+        img_size = ensure_tuple_rep(img_size, spatial_dims)
+        patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
+        window_size = ensure_tuple_rep(window_size, spatial_dims)
+
         self._check_input_size(img_size)
 
         if not (0 <= drop_rate <= 1):
@@ -146,12 +157,13 @@
             patch_size=patch_sizes,
             depths=depths,
             num_heads=num_heads,
-            mlp_ratio=4.0,
-            qkv_bias=True,
+            mlp_ratio=mlp_ratio,
+            qkv_bias=qkv_bias,
             drop_rate=drop_rate,
             attn_drop_rate=attn_drop_rate,
             drop_path_rate=dropout_path_rate,
-            norm_layer=nn.LayerNorm,
+            norm_layer=norm_layer,
+            patch_norm=patch_norm,
             use_checkpoint=use_checkpoint,
             spatial_dims=spatial_dims,
             downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample,
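For context, a construction sketch exercising the constructor arguments exposed by this change (patch_size, window_size, qkv_bias, mlp_ratio, norm_layer, patch_norm); the concrete values are illustrative assumptions, and img_size is still the deprecated first argument at this point:

import torch
from torch import nn
from monai.networks.nets import SwinUNETR

net = SwinUNETR(
    img_size=(96, 96),    # deprecated since 1.3, still accepted here
    in_channels=1,
    out_channels=2,
    patch_size=2,         # previously a fixed class attribute
    window_size=7,        # previously hard-coded to 7
    qkv_bias=True,
    mlp_ratio=4.0,
    norm_layer=nn.LayerNorm,
    patch_norm=True,
    feature_size=24,
    spatial_dims=2,
)
out = net(torch.randn(1, 1, 96, 96))  # -> (1, 2, 96, 96)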
18 changes: 18 additions & 0 deletions tests/test_selfattention.py
@@ -122,6 +122,24 @@ def test_causal(self):
         # check upper triangular part of the attention matrix is zero
         assert torch.triu(block.att_mat, diagonal=1).sum() == 0
 
+    def test_masked_selfattention(self):
+        n = 64
+        block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, sequence_length=16, save_attn=True)
+        input_shape = (1, n, 128)
+        # generate a mask randomly with zeros and ones of shape (1, n)
+        mask = torch.randint(0, 2, (1, n)).bool()
+        block(torch.randn(input_shape), attn_mask=mask)
+        att_mat = block.att_mat.squeeze()
+        # ensure all masked columns are zeros
+        assert torch.allclose(att_mat[:, ~mask.squeeze(0)], torch.zeros_like(att_mat[:, ~mask.squeeze(0)]))
+
+    def test_causal_and_mask(self):
+        with self.assertRaises(ValueError):
+            block = SABlock(hidden_size=128, num_heads=1, causal=True, sequence_length=64)
+            inputs = torch.randn(2, 64, 128)
+            mask = torch.randint(0, 2, (2, 64)).bool()
+            block(inputs, attn_mask=mask)
+
     @skipUnless(has_einops, "Requires einops")
     def test_access_attn_matrix(self):
         # input format
