-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy pathcvt.py
121 lines (92 loc) · 4.23 KB
/
cvt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
from torch import nn
from einops import repeat
from einops.layers.torch import Rearrange
from module import ConvAttention, PreNorm, FeedForward
import numpy as np
class Transformer(nn.Module):
    """Stack of residual attention / feed-forward layer pairs.

    Every one of the ``depth`` layers is a two-element ``nn.ModuleList``:
    a ``PreNorm``-wrapped ``ConvAttention`` followed by a ``PreNorm``-wrapped
    ``FeedForward``. In ``forward`` each sub-module's output is added back
    onto its own input (skip connection).

    Args:
        dim: Feature size given to ``PreNorm``, ``ConvAttention`` and
            ``FeedForward``.
        img_size: Forwarded unchanged to ``ConvAttention``.
        depth: Number of (attention, feed-forward) pairs to stack.
        heads: Forwarded to ``ConvAttention`` as ``heads``.
        dim_head: Forwarded to ``ConvAttention`` as ``dim_head``.
        mlp_dim: Second positional argument of ``FeedForward``.
        dropout: Forwarded to both ``ConvAttention`` and ``FeedForward``.
        last_stage: Forwarded to ``ConvAttention`` as ``last_stage``.
    """

    def __init__(self, dim, img_size, depth, heads, dim_head, mlp_dim, dropout=0., last_stage=False):
        super().__init__()
        blocks = []
        for _ in range(depth):
            # Attention is built before the feed-forward so that parameter
            # initialisation consumes the RNG in a fixed order.
            attention = PreNorm(
                dim,
                ConvAttention(dim, img_size, heads=heads, dim_head=dim_head,
                              dropout=dropout, last_stage=last_stage),
            )
            feed_forward = PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            blocks.append(nn.ModuleList([attention, feed_forward]))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Run ``x`` through every layer pair and return the result."""
        out = x
        for attention, feed_forward in self.layers:
            out = attention(out) + out        # residual around attention
            out = feed_forward(out) + out     # residual around feed-forward
        return out
class CvT(nn.Module):
    """Three-stage CvT-style image classifier.

    Each stage is a strided ``nn.Conv2d`` token embedding (feature map
    flattened to ``b (h w) c`` and layer-normalised) followed by a
    ``Transformer``. Stages 1 and 2 reshape their tokens back to
    ``b c h w`` so the next convolution can consume them. Before the third
    transformer a learnable class token is prepended to the sequence; the
    classifier head then reads either that token or the mean over tokens.

    Args:
        image_size: Side length of the (square) input image.
        in_channels: Number of channels of the input image.
        num_classes: Output size of the final linear layer.
        dim: Channel width of stage 1. Later stages multiply it by the ratio
            of consecutive ``heads`` entries.
        kernels: Kernel size of the embedding convolution of each stage.
        strides: Stride of the embedding convolution of each stage.
        heads: Attention heads per stage.
        depth: Number of transformer layers per stage.
        pool: ``'cls'`` to classify from the class token, ``'mean'`` to
            average over all tokens (class token included).
        dropout: Dropout probability forwarded to every ``Transformer``.
        emb_dropout: Dropout probability applied to the token sequence right
            after the class token is prepended.
        scale_dim: ``mlp_dim`` of each stage is ``dim * scale_dim``.

    Raises:
        AssertionError: If ``pool`` is neither ``'cls'`` nor ``'mean'``.
    """

    def __init__(self, image_size, in_channels, num_classes, dim=64, kernels=(7, 3, 3), strides=(4, 2, 2),
                 heads=(1, 3, 6), depth=(1, 2, 10), pool='cls', dropout=0., emb_dropout=0., scale_dim=4):
        super().__init__()
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        self.pool = pool
        self.dim = dim  # stage-1 width; reused as dim_head in every stage

        # Padding of the three embedding convolutions (fixed by this model).
        paddings = (2, 1, 1)

        ##### Stage 1 #######
        # Derive the feature-map side from the convolution itself so that
        # non-default kernels/strides stay consistent with the Rearrange ops.
        # For the defaults this equals image_size // 4.
        size = self._conv_out_size(image_size, kernels[0], strides[0], paddings[0])
        self.stage1_conv_embed = nn.Sequential(
            nn.Conv2d(in_channels, dim, kernels[0], strides[0], paddings[0]),
            Rearrange('b c h w -> b (h w) c', h=size, w=size),
            nn.LayerNorm(dim)
        )
        self.stage1_transformer = nn.Sequential(
            Transformer(dim=dim, img_size=size, depth=depth[0], heads=heads[0], dim_head=self.dim,
                        mlp_dim=dim * scale_dim, dropout=dropout),
            Rearrange('b (h w) c -> b c h w', h=size, w=size)
        )

        ##### Stage 2 #######
        in_channels = dim
        scale = heads[1] // heads[0]
        dim = scale * dim
        size = self._conv_out_size(size, kernels[1], strides[1], paddings[1])
        self.stage2_conv_embed = nn.Sequential(
            nn.Conv2d(in_channels, dim, kernels[1], strides[1], paddings[1]),
            Rearrange('b c h w -> b (h w) c', h=size, w=size),
            nn.LayerNorm(dim)
        )
        self.stage2_transformer = nn.Sequential(
            Transformer(dim=dim, img_size=size, depth=depth[1], heads=heads[1], dim_head=self.dim,
                        mlp_dim=dim * scale_dim, dropout=dropout),
            Rearrange('b (h w) c -> b c h w', h=size, w=size)
        )

        ##### Stage 3 #######
        in_channels = dim
        scale = heads[2] // heads[1]
        dim = scale * dim
        size = self._conv_out_size(size, kernels[2], strides[2], paddings[2])
        self.stage3_conv_embed = nn.Sequential(
            nn.Conv2d(in_channels, dim, kernels[2], strides[2], paddings[2]),
            Rearrange('b c h w -> b (h w) c', h=size, w=size),
            nn.LayerNorm(dim)
        )
        # NOTE(review): last_stage=True presumably tells ConvAttention that a
        # class token is part of the sequence -- confirm in module.py.
        self.stage3_transformer = nn.Sequential(
            Transformer(dim=dim, img_size=size, depth=depth[2], heads=heads[2], dim_head=self.dim,
                        mlp_dim=dim * scale_dim, dropout=dropout, last_stage=True),
        )

        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout_large = nn.Dropout(emb_dropout)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    @staticmethod
    def _conv_out_size(size, kernel, stride, padding):
        """Return the output side length of a square, dilation-1 ``nn.Conv2d``."""
        return (size + 2 * padding - kernel) // stride + 1

    def forward(self, img):
        """Classify a batch of images.

        Args:
            img: Input batch for the first ``nn.Conv2d``
                (``b, in_channels, image_size, image_size``).

        Returns:
            Tensor of shape ``(b, num_classes)``.
        """
        xs = self.stage1_conv_embed(img)
        xs = self.stage1_transformer(xs)

        xs = self.stage2_conv_embed(xs)
        xs = self.stage2_transformer(xs)

        xs = self.stage3_conv_embed(xs)
        batch = xs.shape[0]
        # Broadcast the single learnable class token over the batch; the
        # following cat materialises it, so no explicit copy is needed.
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        xs = torch.cat((cls_tokens, xs), dim=1)
        # Apply the embedding dropout requested via ``emb_dropout``
        # (identity when p == 0 or in eval mode).
        xs = self.dropout_large(xs)
        xs = self.stage3_transformer(xs)

        xs = xs.mean(dim=1) if self.pool == 'mean' else xs[:, 0]
        return self.mlp_head(xs)
if __name__ == "__main__":
    # Smoke test: build the default model for 224x224 RGB input and
    # 1000 classes, report its size, and push one dummy batch through it.
    model = CvT(224, 3, 1000)
    img = torch.ones([1, 3, 224, 224])

    # Count only parameters that receive gradients, expressed in millions.
    trainable = (p for p in model.parameters() if p.requires_grad)
    millions = sum(p.numel() for p in trainable) / 1_000_000
    print('Trainable Parameters: %.3fM' % millions)

    out = model(img)
    print("Shape of out :", out.shape)  # [B, num_classes]