model.py
import torch.nn as nn
import torch.nn.functional as F

from transformer import MHAEncoderXY

class BaseFeatureExtractor(nn.Module):
    """
    Base class for feature extractors.
    """
    def __init__(self):
        super(BaseFeatureExtractor, self).__init__()

    def forward(self, *inputs):
        raise NotImplementedError

    def output_num(self):
        raise NotImplementedError

    def train(self, mode=True):
        # Freeze BatchNorm statistics: keep any BatchNorm2d child in eval mode
        # so its running mean/std are not updated. Note this walks only direct
        # children, not nested submodules.
        self.training = mode
        for module in self.children():
            if isinstance(module, nn.BatchNorm2d):
                module.train(False)
            else:
                module.train(mode)
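
# --- Usage sketch (not part of the original file): shows how the overridden
# train() keeps BatchNorm2d children frozen while the rest of the model
# switches to training mode. _ConvExtractor is a hypothetical subclass defined
# only for this demonstration.
def _demo_frozen_bn():
    class _ConvExtractor(BaseFeatureExtractor):
        def __init__(self):
            super(_ConvExtractor, self).__init__()
            self.conv = nn.Conv2d(3, 8, kernel_size=3)
            self.bn = nn.BatchNorm2d(8)

        def forward(self, x):
            return self.bn(self.conv(x))

    model = _ConvExtractor()
    model.train()  # overridden train(): conv trains, BN stays in eval mode
    assert model.conv.training and not model.bn.training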

class MLPTransXY(BaseFeatureExtractor):
    """
    MLP followed by a position-aware multi-head attention encoder,
    with a residual connection around the attention block.
    """
    def __init__(self, input_dim, hidden_dim):
        super(MLPTransXY, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.MHA = MHAEncoderXY(in_dim=hidden_dim, gap=4, n_heads=4, emb_dim=hidden_dim)
        self.output_dim = self.MHA.output_dim

    def output_num(self):
        return self.output_dim

    def forward(self, x, pos_x, pos_y):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        residual = x
        x = self.MHA(x, pos_x, pos_y)
        # Out-of-place add: `x += residual` would modify the attention output
        # in place, which can break autograd if that tensor is saved for backward.
        x = x + residual
        return x
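
# --- Shape sketch (not part of the original file): MHAEncoderXY lives in
# transformer.py and its expected tensor shapes are not documented here, so
# this demo swaps in a hypothetical stand-in with the same interface
# (forward(x, pos_x, pos_y) plus an output_dim attribute) purely to show the
# MLP -> attention -> residual data flow. Input shapes are assumptions.
def _demo_mlp_trans_xy():
    import torch

    class _IdentityMHA(nn.Module):
        def __init__(self, dim):
            super(_IdentityMHA, self).__init__()
            self.output_dim = dim

        def forward(self, x, pos_x, pos_y):
            return x  # stand-in; the real encoder attends over (pos_x, pos_y)

    model = MLPTransXY(input_dim=32, hidden_dim=64)
    model.MHA = _IdentityMHA(64)        # replace the real encoder for this sketch
    x = torch.randn(8, 32)              # assumed: batch of 8 feature vectors
    pos_x = torch.randint(0, 10, (8,))  # assumed: integer grid coordinates
    pos_y = torch.randint(0, 10, (8,))
    print(model(x, pos_x, pos_y).shape)  # torch.Size([8, 64])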

class ProtoCLS(nn.Module):
    """
    Prototype-based classifier:
    L2-normalization + a fc layer (without bias), scaled by a temperature.
    """
    def __init__(self, in_dim, out_dim, temp=0.05):
        super(ProtoCLS, self).__init__()
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        self.temp = temp
        self.weight_norm()

    def forward(self, x):
        x = F.normalize(x)           # project features onto the unit sphere
        x = self.fc(x) / self.temp   # cosine similarity to prototypes, sharpened by temperature
        return x

    def weight_norm(self):
        # Renormalize each prototype (row of the weight matrix) to unit L2 norm.
        w = self.fc.weight.data
        norm = w.norm(p=2, dim=1, keepdim=True)
        self.fc.weight.data = w.div(norm.expand_as(w))
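
# --- Usage sketch (not part of the original file): classifiers of this kind
# typically re-apply weight_norm() after every optimizer step so the prototypes
# stay on the unit sphere; the single training step below is illustrative.
def _demo_proto_cls_step():
    import torch

    cls_head = ProtoCLS(in_dim=256, out_dim=10)
    optimizer = torch.optim.SGD(cls_head.parameters(), lr=0.01)
    feat = torch.randn(4, 256)
    labels = torch.randint(0, 10, (4,))

    loss = F.cross_entropy(cls_head(feat), labels)  # cosine logits scaled by 1/temp
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    cls_head.weight_norm()  # re-project prototypes to unit norm after the update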

class CLS(nn.Module):
    """
    Classifier made up of a projection head and a prototype-based classifier.
    """
    def __init__(self, in_dim, out_dim, hidden_mlp=2048, feat_dim=256, temp=0.05):
        super(CLS, self).__init__()
        self.projection_head = nn.Sequential(
            nn.Linear(in_dim, hidden_mlp),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_mlp, feat_dim))
        self.ProtoCLS = ProtoCLS(feat_dim, out_dim, temp)

    def forward(self, x):
        before_lincls_feat = self.projection_head(x)      # projected features
        after_lincls = self.ProtoCLS(before_lincls_feat)  # prototype logits
        return before_lincls_feat, after_lincls
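
# --- Usage sketch (not part of the original file): one forward pass through
# CLS, showing the two returned tensors (projected features, prototype logits).
# The input dimension here is an assumption.
if __name__ == "__main__":
    import torch

    head = CLS(in_dim=512, out_dim=10)  # hidden_mlp=2048, feat_dim=256 by default
    x = torch.randn(4, 512)
    feat, logits = head(x)
    print(feat.shape, logits.shape)  # torch.Size([4, 256]) torch.Size([4, 10])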