# models.py
import glob

import torch
import torch.nn as nn
import pretrainedmodels

# Preprocessing settings for each supported backbone: input resolution and
# normalisation mean/std (single-element lists are shared across channels).
model_configs = {
    'polynet': {
        'input_size': 512,
        'input_mean': [0.485, 0.456, 0.406, 0.406],
        'input_std': [0.229, 0.224, 0.225, 0.225]
    },
    'resnet101': {
        'input_size': 512,
        'input_mean': [0.485, 0.456, 0.406, 0.406],
        'input_std': [0.229, 0.224, 0.225, 0.225]
    },
    'resnet50': {
        'input_size': 512,
        'input_mean': [0.485, 0.456, 0.406, 0.406],
        'input_std': [0.229, 0.224, 0.225, 0.225]
    },
    'resnext101_32x4d': {
        'input_size': 512,
        'input_mean': [0.485, 0.456, 0.406, 0.406],
        'input_std': [0.229, 0.224, 0.225, 0.225]
    },
    'se_resnext50_32x4d': {
        'input_size': 512,
        'input_mean': [0.485, 0.456, 0.406, 0.406],
        'input_std': [0.229, 0.224, 0.225, 0.225]
    },
    'inceptionresnetv2': {
        'input_size': 512,
        'input_mean': [0.5],
        'input_std': [0.5]
    },
    'xception': {
        'input_size': 512,
        'input_mean': [0.5],
        'input_std': [0.5]
    },
    'dpn68': {
        'input_size': 512,
        'input_mean': [0.5],
        'input_std': [0.5]
    },
    'dpn98': {
        'input_size': 512,
        'input_mean': [0.5],
        'input_std': [0.5]
    },
    'inceptionv4': {
        'input_size': 512,
        'input_mean': [0.5],
        'input_std': [0.5]
    },
}
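

# Illustrative sketch, not part of the original training code: one way the
# configs above could be applied to a 4-channel RGBY image tensor. The helper
# name `normalize_rgby` is a hypothetical addition for demonstration only.
def normalize_rgby(image, model_name):
    """Normalise a CxHxW float tensor with the stats registered for `model_name`."""
    cfg = model_configs[model_name]
    mean = torch.tensor(cfg['input_mean']).view(-1, 1, 1)  # broadcast over H, W
    std = torch.tensor(cfg['input_std']).view(-1, 1, 1)
    return (image - mean) / std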


def construct_rgby_model(model_name, split):
    """
    Build a pretrained `model_name` backbone adapted for 4-channel (RGBY)
    input and a 28-class output, optionally loading weights from an
    existing checkpoint for the given `split`.
    """
    model = pretrainedmodels.__dict__[model_name](num_classes=1000)

    # Handle 4-dimensional input: replace the first convolution with a
    # 4-channel one whose kernels are initialised from the mean of the
    # pretrained RGB kernels.
    modules = list(model.modules())
    first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d),
                                 list(range(len(modules)))))[0]
    conv_layer = modules[first_conv_idx]
    container = modules[first_conv_idx - 1]
    params = [x.clone() for x in conv_layer.parameters()]
    kernel_size = params[0].size()
    new_kernel_size = kernel_size[:1] + (4,) + kernel_size[2:]
    new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()
    new_conv = nn.Conv2d(4, conv_layer.out_channels,
                         conv_layer.kernel_size, conv_layer.stride, conv_layer.padding,
                         bias=len(params) == 2)
    new_conv.weight.data = new_kernels
    if len(params) == 2:
        new_conv.bias.data = params[1].data
    # Re-attach the new conv under the original attribute name (strip '.weight').
    layer_name = list(container.state_dict().keys())[0][:-7]
    setattr(container, layer_name, new_conv)

    # Handle the 512x512 input size by swapping the fixed-size average
    # pooling layer for an adaptive one.
    if 'resnet' in model_name:
        model.avgpool = nn.AdaptiveAvgPool2d(1)
    elif 'resnext' in model_name:
        model.avg_pool = nn.AdaptiveAvgPool2d(1)

    # Change the last layer to output 28 classes
    # (DPN models use a 1x1 convolution as their classifier).
    if 'dpn' in model_name:
        in_channels = model.last_linear.in_channels
        kernel_size = model.last_linear.kernel_size
        model.last_linear = nn.Conv2d(in_channels, 28, kernel_size)
    else:
        num_features = model.last_linear.in_features
        model.last_linear = nn.Linear(num_features, 28)

    # Load an existing checkpoint for this model/split if one is present.
    if glob.glob('{}_rgby_focal_{}*'.format(model_name, split)):
        pth_file = torch.load('{}_rgby_focal_{}.pth.tar'.format(model_name, split))
        state_dict = pth_file['state_dict']
        model.load_state_dict(state_dict)
        start_epoch = pth_file['epoch']  # recorded in the checkpoint; not returned here
    return model
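

# Illustrative usage sketch, not part of the original training code: build an
# adapted backbone and confirm it accepts a 4-channel input at the configured
# resolution. The backbone choice, batch size, and split value below are
# arbitrary examples; constructing the model downloads ImageNet weights on
# first use.
if __name__ == '__main__':
    name = 'se_resnext50_32x4d'            # any key from model_configs
    net = construct_rgby_model(name, split=0)
    net.eval()
    size = model_configs[name]['input_size']
    dummy = torch.randn(2, 4, size, size)  # batch of two RGBY images
    with torch.no_grad():
        out = net(dummy)
    print(out.shape)                       # expected: torch.Size([2, 28])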