-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathtransformers.py
787 lines (706 loc) · 31 KB
/
transformers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
"""
Implementations of transformers and variants
Implementation of encoder and decoder are adapted from DETR codebase by Facebook Research
https://github.com/facebookresearch/detr/blob/main/models/transformer.py
Implementation of Swin Transformers are adapted from microsoft/Swin-Transformer and torchvision
https://github.com/microsoft/Swin-Transformer/tree/main/models
https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py
Fred Zhang <frederic.zhang@anu.edu.au>
The Australian National University
Microsoft Research Asia
"""
import copy
import math
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from attention import MultiheadAttention
from typing import List, Optional, Callable
class TransformerEncoderLayer(nn.Module):
def __init__(self, dim, num_heads, ffn_interm_dim, dropout=0.1):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.dropout = dropout
self.ffn_interm_dim = ffn_interm_dim
# Linear projections on qkv have been removed in this custom layer.
self.attn = MultiheadAttention(dim, num_heads, dropout=dropout)
# Add the missing linear projections.
self.q_proj = nn.Linear(dim, dim)
self.k_proj = nn.Linear(dim, dim)
self.v_proj = nn.Linear(dim, dim)
# The positional embeddings include box centres, widths and heights,
# which will be twice the representation size.
self.qpos_proj = nn.Linear(2 * dim, dim)
self.kpos_proj = nn.Linear(2 * dim, dim)
self.ffn = nn.Sequential(
nn.Linear(dim, ffn_interm_dim), nn.ReLU(), nn.Dropout(dropout),
nn.Linear(ffn_interm_dim, dim)
)
self.ln1 = nn.LayerNorm(dim)
self.ln2 = nn.LayerNorm(dim)
self.dp1 = nn.Dropout(dropout)
self.dp2 = nn.Dropout(dropout)
def forward(self, x, pos,
attn_mask: Optional[Tensor] = None,
key_padding_mask: Optional[Tensor] = None,
):
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
q_pos = self.qpos_proj(pos)
k_pos = self.kpos_proj(pos)
q = q + q_pos
k = k + k_pos
attn, attn_weights = self.attn(
query=q, key=k, value=v,
attn_mask=attn_mask, key_padding_mask=key_padding_mask
)
x = self.ln1(x + self.dp1(attn))
x = self.ln2(x + self.dp2(self.ffn(x)))
return x, attn_weights
class TransformerEncoder(nn.Module):
def __init__(self, hidden_size=256, num_heads=8, num_layers=2, dropout=.1):
super().__init__()
self.num_layers = num_layers
self.layers = nn.ModuleList([TransformerEncoderLayer(
dim=hidden_size, num_heads=num_heads,
ffn_interm_dim=hidden_size * 4, dropout=dropout
) for _ in range(num_layers)])
def forward(self, x, pos):
attn_weights = []
for layer in self.layers:
x, attn = layer(x, pos)
attn_weights.append(attn)
return x, attn_weights
class TransformerDecoderLayer(nn.Module):
def __init__(self, q_dim, kv_dim, num_heads, ffn_interm_dim, dropout=0.1):
"""
Parameters:
-----------
q_dim: int
Dimension of the interaction queries.
kv_dim: int
Dimension of the image features.
num_heads: int
Number of heads used in multihead attention.
ffn_interm_dim: int
Dimension of the intermediate representation in the feedforward network.
dropout: float, default: 0.1
Dropout percentage used during training.
"""
super().__init__()
self.q_dim = q_dim
self.kv_dim = kv_dim
self.num_heads = num_heads
self.dropout = dropout
self.ffn_interm_dim = ffn_interm_dim
# Linear projections on qkv have been removed in this custom layer.
self.q_attn = MultiheadAttention(q_dim, num_heads, dropout=dropout)
# Add the missing linear projections.
self.q_attn_q_proj = nn.Linear(q_dim, q_dim)
self.q_attn_k_proj = nn.Linear(q_dim, q_dim)
self.q_attn_v_proj = nn.Linear(q_dim, q_dim)
# Each scalar is mapped to a vector of shape kv_dim // 2.
# For a box pair, the dimension is 8 * (kv_dim // 2).
self.q_attn_qpos_proj = nn.Linear(kv_dim * 4, q_dim)
self.q_attn_kpos_proj = nn.Linear(kv_dim * 4, q_dim)
self.qk_attn = MultiheadAttention(q_dim * 2, num_heads, dropout=dropout, vdim=q_dim)
self.qk_attn_q_proj = nn.Linear(q_dim, q_dim)
self.qk_attn_k_proj = nn.Linear(kv_dim, q_dim)
self.qk_attn_v_proj = nn.Linear(kv_dim, q_dim)
self.qk_attn_kpos_proj = nn.Linear(kv_dim, q_dim)
self.qk_attn_qpos_proj = nn.Linear(kv_dim * 2, q_dim)
self.ffn = nn.Sequential(
nn.Linear(q_dim, ffn_interm_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(ffn_interm_dim, q_dim)
)
self.ln1 = nn.LayerNorm(q_dim)
self.ln2 = nn.LayerNorm(q_dim)
self.ln3 = nn.LayerNorm(q_dim)
self.dp1 = nn.Dropout(dropout)
self.dp2 = nn.Dropout(dropout)
self.dp3 = nn.Dropout(dropout)
def forward(self,
queries: Tensor, features: Tensor,
q_pos: Tensor, k_pos: Tensor,
q_attn_mask: Optional[Tensor] = None,
qk_attn_mask: Optional[Tensor] = None,
q_padding_mask: Optional[Tensor] = None,
kv_padding_mask: Optional[Tensor] = None,
):
"""
Parameters:
-----------
queries: Tensor
Interaction queries of size (N, B, K).
features: Tensor
Image features of size (HW, B, C).
q_attn_mask: Tensor, default: None
Attention mask to be applied during the self attention of queries.
qk_attn_mask: Tensor, default: None
Attention mask to be applied during the cross attention from image
features to interaction queries.
q_padding_mask: Tensor, default: None
Padding mask for interaction queries of size (B, N). Values of `True`
indicate the corresponding query was padded and to be ignored.
kv_padding_mask: Tensor, default: None
Padding mask for image features of size (B, HW).
q_pos: Tensor, default: None
Positional encodings for the interaction queries.
k_pos: Tensor, default: None
Positional encodings for the image features.
Returns:
--------
queries: Tensor
"""
# Perform self attention amongst queries
q = self.q_attn_q_proj(queries)
k = self.q_attn_k_proj(queries)
v = self.q_attn_v_proj(queries)
q_p = self.q_attn_qpos_proj(q_pos["box"])
k_p = self.q_attn_kpos_proj(q_pos["box"])
q = q + q_p
k = k + k_p
q_attn = self.q_attn(
q, k, value=v, attn_mask=q_attn_mask,
key_padding_mask=q_padding_mask
)[0]
queries = self.ln1(queries + self.dp1(q_attn))
# Perform cross attention from memory features to queries
q = self.qk_attn_q_proj(queries)
k = self.qk_attn_k_proj(features)
v = self.qk_attn_v_proj(features)
q_p = self.qk_attn_qpos_proj(q_pos["centre"])
k_p = self.qk_attn_kpos_proj(k_pos)
n_q, bs, _ = q.shape
q = q.view(n_q, bs, self.num_heads, self.q_dim // self.num_heads)
q_p = q_p.view(n_q, bs, self.num_heads, self.q_dim // self.num_heads)
q = torch.cat([q, q_p], dim=3).view(n_q, bs, self.q_dim * 2)
hw, _, _ = k.shape
k = k.view(hw, bs, self.num_heads, self.q_dim // self.num_heads)
k_p = k_p.view(hw, bs, self.num_heads, self.q_dim // self.num_heads)
k = torch.cat([k, k_p], dim=3).view(hw, bs, self.q_dim * 2)
qk_attn = self.qk_attn(
query=q, key=k, value=v, attn_mask=qk_attn_mask,
key_padding_mask=kv_padding_mask
)[0]
queries = self.ln2(queries + self.dp2(qk_attn))
queries = self.ln3(queries + self.dp3(self.ffn(queries)))
return queries
class TransformerDecoder(nn.Module):
def __init__(self, decoder_layer, num_layers, return_intermediate=True):
super().__init__()
self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for i in range(num_layers)])
self.num_layers = num_layers
self.norm = nn.LayerNorm(decoder_layer.q_dim)
self.return_intermediate = return_intermediate
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, queries, features,
q_attn_mask: Optional[Tensor] = None,
qk_attn_mask: Optional[Tensor] = None,
q_padding_mask: Optional[Tensor] = None,
kv_padding_mask: Optional[Tensor] = None,
q_pos: Optional[Tensor] = None,
k_pos: Optional[Tensor] = None,
):
# Add support for zero layers
if self.num_layers == 0:
return queries.unsqueeze(0)
# Explicitly handle zero-size queries
if queries.numel() == 0:
rp = self.num_layers if self.return_intermediate else 1
return queries.unsqueeze(0).repeat(rp, 1, 1, 1)
output = queries
intermediate = []
for layer in self.layers:
output = layer(
output, features,
q_attn_mask=q_attn_mask,
qk_attn_mask=qk_attn_mask,
q_padding_mask=q_padding_mask,
kv_padding_mask=kv_padding_mask,
q_pos=q_pos, k_pos=k_pos,
)
if self.return_intermediate:
intermediate.append(self.norm(output))
if self.return_intermediate:
output = torch.stack(intermediate)
else:
output = self.norm(output).unsqueeze(0)
return output
def _get_relative_position_bias(
relative_position_bias_table: torch.Tensor, relative_position_index: torch.Tensor, window_size: List[int]
) -> torch.Tensor:
N = window_size[0] * window_size[1]
relative_position_bias = relative_position_bias_table[relative_position_index] # type: ignore[index]
relative_position_bias = relative_position_bias.view(N, N, -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
return relative_position_bias
def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True) -> Tensor:
"""
Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
<https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
branches of residual architectures.
Args:
input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one
being its batch i.e. a batch with ``N`` rows.
p (float): probability of the input to be zeroed.
mode (str): ``"batch"`` or ``"row"``.
``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
randomly selected rows from the batch.
training: apply stochastic depth if is ``True``. Default: ``True``
Returns:
Tensor[N, ...]: The randomly zeroed tensor.
"""
if p < 0.0 or p > 1.0:
raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
if mode not in ["batch", "row"]:
raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
if not training or p == 0.0:
return input
survival_rate = 1.0 - p
if mode == "row":
size = [input.shape[0]] + [1] * (input.ndim - 1)
else:
size = [1] * input.ndim
noise = torch.empty(size, dtype=input.dtype, device=input.device)
noise = noise.bernoulli_(survival_rate)
if survival_rate > 0.0:
noise.div_(survival_rate)
return input * noise
class StochasticDepth(nn.Module):
def __init__(self, p: float, mode: str) -> None:
super().__init__()
self.p = p
self.mode = mode
def forward(self, input: Tensor) -> Tensor:
return stochastic_depth(input, self.p, self.mode, self.training)
def __repr__(self) -> str:
s = f"{self.__class__.__name__}(p={self.p}, mode={self.mode})"
return s
def shifted_window_attention(
input: Tensor,
qkv_weight: Tensor,
proj_weight: Tensor,
relative_position_bias: Tensor,
window_size: List[int],
num_heads: int,
shift_size: List[int],
attention_dropout: float = 0.0,
dropout: float = 0.0,
qkv_bias: Optional[Tensor] = None,
proj_bias: Optional[Tensor] = None,
logit_scale: Optional[torch.Tensor] = None,
):
"""
Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
input (Tensor[N, H, W, C]): The input tensor or 4-dimensions.
qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value.
proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection.
relative_position_bias (Tensor): The learned relative position bias added to attention.
window_size (List[int]): Window size.
num_heads (int): Number of attention heads.
shift_size (List[int]): Shift size for shifted window attention.
attention_dropout (float): Dropout ratio of attention weight. Default: 0.0.
dropout (float): Dropout ratio of output. Default: 0.0.
qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2. Default: None.
Returns:
Tensor[N, H, W, C]: The output tensor after shifted window attention.
"""
B, H, W, C = input.shape
# pad feature maps to multiples of window size
pad_r = (window_size[1] - W % window_size[1]) % window_size[1]
pad_b = (window_size[0] - H % window_size[0]) % window_size[0]
x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b))
_, pad_H, pad_W, _ = x.shape
shift_size = shift_size.copy()
# If window size is larger than feature size, there is no need to shift window
if window_size[0] >= pad_H:
shift_size[0] = 0
if window_size[1] >= pad_W:
shift_size[1] = 0
# cyclic shift
if sum(shift_size) > 0:
x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
# partition windows
num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1])
x = x.view(B, pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C) # B*nW, Ws*Ws, C
# multi-head attention
if logit_scale is not None and qkv_bias is not None:
qkv_bias = qkv_bias.clone()
length = qkv_bias.numel() // 3
qkv_bias[length : 2 * length].zero_()
qkv = F.linear(x, qkv_weight, qkv_bias)
qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
if logit_scale is not None:
# cosine attention
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
logit_scale = torch.clamp(logit_scale, max=math.log(100.0)).exp()
attn = attn * logit_scale
else:
q = q * (C // num_heads) ** -0.5
attn = q.matmul(k.transpose(-2, -1))
# add relative position bias
attn = attn + relative_position_bias
if sum(shift_size) > 0:
# generate attention mask
attn_mask = x.new_zeros((pad_H, pad_W))
h_slices = ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None))
w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None))
count = 0
for h in h_slices:
for w in w_slices:
attn_mask[h[0] : h[1], w[0] : w[1]] = count
count += 1
attn_mask = attn_mask.view(pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1])
attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size[0] * window_size[1])
attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1))
attn = attn + attn_mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, num_heads, x.size(1), x.size(1))
attn = F.softmax(attn, dim=-1)
attn = F.dropout(attn, p=attention_dropout)
x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
x = F.linear(x, proj_weight, proj_bias)
x = F.dropout(x, p=dropout)
# reverse windows
x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C)
# reverse cyclic shift
if sum(shift_size) > 0:
x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))
# unpad features
x = x[:, :H, :W, :].contiguous()
return x
class ShiftedWindowAttention(nn.Module):
def __init__(
self,
dim: int,
window_size: List[int],
shift_size: List[int],
num_heads: int,
qkv_bias: bool = True,
proj_bias: bool = True,
attention_dropout: float = 0.0,
dropout: float = 0.0,
):
super().__init__()
if len(window_size) != 2 or len(shift_size) != 2:
raise ValueError("window_size and shift_size must be of length 2")
self.window_size = window_size
self.shift_size = shift_size
self.num_heads = num_heads
self.attention_dropout = attention_dropout
self.dropout = dropout
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.define_relative_position_bias_table()
self.define_relative_position_index()
def define_relative_position_bias_table(self):
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), self.num_heads)
) # 2*Wh-1 * 2*Ww-1, nH
nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
def define_relative_position_index(self):
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid(coords_h, coords_w)) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1).flatten() # Wh*Ww*Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
def get_relative_position_bias(self) -> torch.Tensor:
return _get_relative_position_bias(
self.relative_position_bias_table, self.relative_position_index, self.window_size # type: ignore[arg-type]
)
def forward(self, x: Tensor):
"""
Args:
x (Tensor): Tensor with layout of [B, H, W, C]
Returns:
Tensor with same layout as input, i.e. [B, H, W, C]
"""
relative_position_bias = self.get_relative_position_bias()
return shifted_window_attention(
x,
self.qkv.weight,
self.proj.weight,
relative_position_bias,
self.window_size,
self.num_heads,
shift_size=self.shift_size,
attention_dropout=self.attention_dropout,
dropout=self.dropout,
qkv_bias=self.qkv.bias,
proj_bias=self.proj.bias,
)
class ShiftedWindowAttentionV2(ShiftedWindowAttention):
def __init__(
self,
dim: int,
window_size: List[int],
shift_size: List[int],
num_heads: int,
qkv_bias: bool = True,
proj_bias: bool = True,
attention_dropout: float = 0.0,
dropout: float = 0.0,
):
super().__init__(
dim,
window_size,
shift_size,
num_heads,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
attention_dropout=attention_dropout,
dropout=dropout,
)
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
# mlp to generate continuous relative position bias
self.cpb_mlp = nn.Sequential(
nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)
)
if qkv_bias:
length = self.qkv.bias.numel() // 3
self.qkv.bias[length : 2 * length].data.zero_()
def define_relative_position_bias_table(self):
# get relative_coords_table
relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
relative_coords_table = torch.stack(torch.meshgrid(relative_coords_h, relative_coords_w))
relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
relative_coords_table *= 8 # normalize to -8, 8
relative_coords_table = (
torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / 3.0
)
self.register_buffer("relative_coords_table", relative_coords_table)
def get_relative_position_bias(self) -> torch.Tensor:
relative_position_bias = _get_relative_position_bias(
self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads),
self.relative_position_index, # type: ignore[arg-type]
self.window_size,
)
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
return relative_position_bias
def forward(self, x: Tensor):
"""
Args:
x (Tensor): Tensor with layout of [B, H, W, C]
Returns:
Tensor with same layout as input, i.e. [B, H, W, C]
"""
relative_position_bias = self.get_relative_position_bias()
return shifted_window_attention(
x,
self.qkv.weight,
self.proj.weight,
relative_position_bias,
self.window_size,
self.num_heads,
shift_size=self.shift_size,
attention_dropout=self.attention_dropout,
dropout=self.dropout,
qkv_bias=self.qkv.bias,
proj_bias=self.proj.bias,
logit_scale=self.logit_scale,
)
class MLP(torch.nn.Sequential):
"""This block implements the multi-layer perceptron (MLP) module.
Args:
in_channels (int): Number of channels of the input
hidden_channels (List[int]): List of the hidden channel dimensions
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the linear layer. If ``None`` this layer won't be used. Default: ``None``
activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation in-place.
Default is ``None``, which uses the respective default values of the ``activation_layer`` and Dropout layer.
bias (bool): Whether to use bias in the linear layer. Default ``True``
dropout (float): The probability for the dropout layer. Default: 0.0
"""
def __init__(
self,
in_channels: int,
hidden_channels: List[int],
norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
inplace: Optional[bool] = None,
bias: bool = True,
dropout: float = 0.0,
):
# The addition of `norm_layer` is inspired from the implementation of TorchMultimodal:
# https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
params = {} if inplace is None else {"inplace": inplace}
layers = []
in_dim = in_channels
for hidden_dim in hidden_channels[:-1]:
layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
if norm_layer is not None:
layers.append(norm_layer(hidden_dim))
layers.append(activation_layer(**params))
layers.append(torch.nn.Dropout(dropout, **params))
in_dim = hidden_dim
layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
layers.append(torch.nn.Dropout(dropout, **params))
super().__init__(*layers)
class SwinTransformerBlock(nn.Module):
"""
Swin Transformer Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (List[int]): Window size.
shift_size (List[int]): Shift size for shifted window attention.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
dropout (float): Dropout rate. Default: 0.0.
attention_dropout (float): Attention dropout rate. Default: 0.0.
stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttention
"""
def __init__(
self,
dim: int,
num_heads: int,
window_size: List[int],
shift_size: List[int],
mlp_ratio: float = 4.0,
dropout: float = 0.0,
attention_dropout: float = 0.0,
stochastic_depth_prob: float = 0.0,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = attn_layer(
dim,
window_size,
shift_size,
num_heads,
attention_dropout=attention_dropout,
dropout=dropout,
)
self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
self.norm2 = norm_layer(dim)
self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)
for m in self.mlp.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.normal_(m.bias, std=1e-6)
def forward(self, x: Tensor):
x = x + self.stochastic_depth(self.attn(self.norm1(x)))
x = x + self.stochastic_depth(self.mlp(self.norm2(x)))
return x
class SwinTransformerBlockV2(SwinTransformerBlock):
"""
Swin Transformer V2 Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (List[int]): Window size.
shift_size (List[int]): Shift size for shifted window attention.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
dropout (float): Dropout rate. Default: 0.0.
attention_dropout (float): Attention dropout rate. Default: 0.0.
stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttentionV2.
"""
def __init__(
self,
dim: int,
num_heads: int,
window_size: List[int],
shift_size: List[int],
mlp_ratio: float = 4.0,
dropout: float = 0.0,
attention_dropout: float = 0.0,
stochastic_depth_prob: float = 0.0,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_layer: Callable[..., nn.Module] = ShiftedWindowAttentionV2,
):
super().__init__(
dim,
num_heads,
window_size,
shift_size,
mlp_ratio=mlp_ratio,
dropout=dropout,
attention_dropout=attention_dropout,
stochastic_depth_prob=stochastic_depth_prob,
norm_layer=norm_layer,
attn_layer=attn_layer,
)
def forward(self, x: Tensor):
# Here is the difference, we apply norm after the attention in V2.
# In V1 we applied norm before the attention.
x = x + self.stochastic_depth(self.norm1(self.attn(x)))
x = x + self.stochastic_depth(self.norm2(self.mlp(x)))
return x
class SwinTransformer(nn.Module):
def __init__(self, dim, num_layers):
"""
A feature stage consisting of a series of Swin Transformer V2 blocks.
Parameters:
-----------
dim: int
Dimension of the input features.
"""
super().__init__()
self.dim = dim
self.depth = num_layers
self.num_heads = dim // 32
self.window_size = 8
self.base_sd_prob = 0.2
shift_size = [
[self.window_size // 2] * 2 if i % 2
else [0, 0] for i in range(self.depth)
]
# TODO fix this hack
# Use stochastic depth parameters from the third stage of Swin-T variant.
# In practice, varying this value does not make a significant difference.
sd_prob = (torch.linspace(0, 1, 12)[10-num_layers:10] * self.base_sd_prob).tolist()
blocks: List[nn.Module] = []
for i in range(self.depth):
blocks.append(SwinTransformerBlockV2(
dim=dim, num_heads=self.num_heads,
window_size=[self.window_size] * 2,
shift_size=shift_size[i],
stochastic_depth_prob=sd_prob[i]
))
self.blocks = nn.Sequential(*blocks)
def forward(self, x):
"""
Parameters:
-----------
x: Tensor
Input features maps of shape (B, H, W, C).
Returns:
--------
Tensor
Output feature maps of shape (B, H, W, C).
"""
return self.blocks(x)