-
Notifications
You must be signed in to change notification settings - Fork 9
/
ecbsr.py
202 lines (185 loc) · 8.88 KB
/
ecbsr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
'''
Add some features on top of ECBSR.
Supports the 1x1_3x3, 3x3_1x1, residual, and 1x1 variants.
'''
class SeqConv3x3(nn.Module):
    """Re-parameterizable sequence of convolutions (one ECBSR branch).

    In training form the branch runs as a short sequence of convolutions.
    ``rep_params`` collapses that sequence into a single 3x3 kernel/bias
    pair such that ``F.conv2d(x, RK, RB, stride=1, padding=1)`` equals
    ``self(x)`` up to floating-point rounding.

    Supported ``seq_type`` values:

    * ``'conv1x1-conv3x3'``: 1x1 conv to ``int(out_planes * depth_multiplier)``
      channels, then a 3x3 conv.
    * ``'conv3x3-conv1x1'``: 3x3 conv (padding=1) to
      ``int(out_planes * depth_multiplier)`` channels, then a 1x1 conv.
    * ``'conv1x1'``: a single 1x1 conv.
    * ``'conv1x1-sobelx'``, ``'conv1x1-sobely'``, ``'conv1x1-laplacian'``:
      1x1 conv to ``out_planes`` channels followed by a fixed depthwise 3x3
      filter multiplied by a learnable per-channel ``scale``, plus a
      learnable ``bias``.

    Args:
        seq_type: one of the strings above.
        inp_planes: number of input channels.
        out_planes: number of output channels.
        depth_multiplier: intermediate-width multiplier; only read by the
            ``'conv1x1-conv3x3'`` and ``'conv3x3-conv1x1'`` types.

    Raises:
        ValueError: if ``seq_type`` is not one of the supported values.
    """

    # Fixed 3x3 filter templates for the edge-detection types, as [row][col].
    _EDGE_KERNELS = {
        'conv1x1-sobelx': ((1.0, 0.0, -1.0),
                           (2.0, 0.0, -2.0),
                           (1.0, 0.0, -1.0)),
        'conv1x1-sobely': ((1.0, 2.0, 1.0),
                           (0.0, 0.0, 0.0),
                           (-1.0, -2.0, -1.0)),
        'conv1x1-laplacian': ((0.0, 1.0, 0.0),
                              (1.0, -4.0, 1.0),
                              (0.0, 1.0, 0.0)),
    }

    def __init__(self, seq_type, inp_planes, out_planes, depth_multiplier):
        super().__init__()
        self.type = seq_type
        self.inp_planes = inp_planes
        self.out_planes = out_planes
        # Each temporary nn.Conv2d is only used for its default
        # initialisation; just its weight/bias Parameters are kept (as
        # k0/b0/k1/b1), which registers them on this module.
        if self.type == 'conv1x1-conv3x3':
            self.mid_planes = int(out_planes * depth_multiplier)
            conv0 = nn.Conv2d(self.inp_planes, self.mid_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias
            # The 3x3 conv deliberately has no padding: forward() pads the
            # 1x1 output with b0 instead of zeros, which keeps the pair
            # equivalent to one 3x3 conv over a zero-padded input.
            conv1 = nn.Conv2d(self.mid_planes, self.out_planes, kernel_size=3)
            self.k1 = conv1.weight
            self.b1 = conv1.bias
        elif self.type == 'conv3x3-conv1x1':
            self.mid_planes = int(out_planes * depth_multiplier)
            conv0 = nn.Conv2d(self.inp_planes, self.mid_planes, kernel_size=3, padding=1)
            self.k0 = conv0.weight
            self.b0 = conv0.bias
            conv1 = nn.Conv2d(self.mid_planes, self.out_planes, kernel_size=1, padding=0)
            self.k1 = conv1.weight
            self.b1 = conv1.bias
        elif self.type == 'conv1x1':
            conv0 = nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias
        elif self.type in self._EDGE_KERNELS:
            conv0 = nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias
            # Small random init for the learnable per-channel scale and bias.
            self.scale = nn.Parameter(torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3)
            self.bias = nn.Parameter(torch.randn(self.out_planes) * 1e-3)
            # One copy of the fixed template per output channel: (out, 1, 3, 3).
            template = torch.tensor(self._EDGE_KERNELS[self.type], dtype=torch.float32)
            mask = template.view(1, 1, 3, 3).repeat(self.out_planes, 1, 1, 1)
            # Kept as a frozen Parameter so it appears in state_dict as "mask".
            self.mask = nn.Parameter(data=mask, requires_grad=False)
        else:
            raise ValueError('the type of seqconv is not supported!')

    @staticmethod
    def _pad_with_bias(y, b):
        """Pad ``y`` by one pixel on each side, filling the border with ``b``.

        ``b`` holds one value per channel of ``y``. Filling the border with
        the 1x1 conv's bias (instead of zeros) reproduces what the 1x1 conv
        would output on a zero-padded input.
        """
        y = F.pad(y, (1, 1, 1, 1), 'constant', 0)
        b_pad = b.view(1, -1, 1, 1)
        y[:, :, 0:1, :] = b_pad
        y[:, :, -1:, :] = b_pad
        y[:, :, :, 0:1] = b_pad
        y[:, :, :, -1:] = b_pad
        return y

    def forward(self, x):
        """Run the unfused branch.

        Every branch preserves the spatial size: the output has
        ``out_planes`` channels and the same height/width as ``x``.
        """
        if self.type == 'conv1x1-conv3x3':
            y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
            y0 = self._pad_with_bias(y0, self.b0)
            # 3x3 conv without padding; the bias-filled border supplies it.
            y1 = F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
        elif self.type == 'conv3x3-conv1x1':
            y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1, padding=1)
            y1 = F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
        elif self.type == 'conv1x1':
            y1 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1, padding=0)
        else:
            y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
            y0 = self._pad_with_bias(y0, self.b0)
            # Depthwise 3x3 with the scaled fixed filter.
            y1 = F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias,
                          stride=1, groups=self.out_planes)
        return y1

    def rep_params(self):
        """Collapse the branch into one 3x3 convolution.

        Returns:
            ``(RK, RB)``: a kernel of shape ``(out_planes, inp_planes, 3, 3)``
            and a bias of shape ``(out_planes,)``. Applying them with
            ``F.conv2d(x, RK, RB, stride=1, padding=1)`` matches ``forward``.
            Both are computed from this module's parameters, so they share
            the parameters' device and dtype.
        """
        if self.type == 'conv1x1-conv3x3':
            # RK[o, i] = sum_m k1[o, m] * k0[m, i]   (k0 is 1x1)
            RK = torch.einsum('omhw,mi->oihw', self.k1, self.k0[:, :, 0, 0])
            # The constant b0 map seen by the 3x3 conv contributes
            # sum_{m,h,w} k1[o, m, h, w] * b0[m].
            RB = self.k1.sum(dim=(2, 3)) @ self.b0 + self.b1
        elif self.type == 'conv3x3-conv1x1':
            # RK[o, i] = sum_m k1[o, m] * k0[m, i]   (k1 is 1x1)
            k1 = self.k1[:, :, 0, 0]
            RK = torch.einsum('om,mihw->oihw', k1, self.k0)
            RB = k1 @ self.b0 + self.b1
        elif self.type == 'conv1x1':
            # A 1x1 kernel is a 3x3 kernel with zeros around the centre.
            RK = F.pad(self.k0, (1, 1, 1, 1), 'constant', 0)
            # Copy so callers cannot mutate the live parameter through RB.
            RB = self.b0.clone()
        else:
            # Depthwise filter after a 1x1 conv: output channel o only sees
            # 1x1 channel o, so the fused kernel is an elementwise product.
            tmp = self.scale * self.mask                   # (out, 1, 3, 3)
            RK = tmp * self.k0                             # (out, inp, 3, 3)
            RB = tmp.sum(dim=(1, 2, 3)) * self.b0 + self.bias
        return RK, RB
def unittest():
    """Check that every SeqConv3x3 type re-parameterizes into one 3x3 conv.

    For each supported ``seq_type`` this builds a randomly initialised
    module, runs the unfused ``forward`` and the fused 3x3 convolution from
    ``rep_params`` (padding=1, as used at inference) on the same input, and
    prints the output shape with the maximum absolute difference. A max of
    absolute differences is used because a signed sum or mean lets positive
    and negative errors cancel out.

    Returns:
        dict mapping each ``seq_type`` to its maximum absolute difference
        (a Python float).
    """
    seq_types = (
        'conv1x1-conv3x3',
        'conv3x3-conv1x1',
        'conv1x1',
        'conv1x1-sobelx',
        'conv1x1-sobely',
        'conv1x1-laplacian',
    )
    x = torch.randn(1, 3, 100, 100) * 100
    errors = {}
    # Pure numerical check: no autograd graph is needed.
    with torch.no_grad():
        for seq_type in seq_types:
            conv = SeqConv3x3(seq_type, 3, 3, 2)
            y0 = conv(x)
            RK, RB = conv.rep_params()
            y1 = F.conv2d(input=x, weight=RK, bias=RB, stride=1, padding=1)
            err = (y0 - y1).abs().max().item()
            errors[seq_type] = err
            print(seq_type, tuple(y0.shape), err)
    return errors