-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathaux_unet.py
executable file
·387 lines (324 loc) · 12.4 KB
/
aux_unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch, math
from torch import nn
from torch.nn import functional as F
from typing import List, Tuple
class NormUnet(nn.Module):
    """
    Normalized U-Net model.

    This is the same as a regular U-Net, but with normalization applied to the
    input before the U-Net. This keeps the values more numerically stable
    during training. The statistics removed before the U-Net are re-applied
    to its output, so the result is returned on the scale of the input.
    """

    def __init__(
        self,
        chans: int,
        num_pools: int,
        in_chans: int = 2,
        out_chans: int = 2,
        drop_prob: float = 0.0,
    ):
        """
        Args:
            chans: Number of output channels of the first convolution layer.
            num_pools: Number of down-sampling and up-sampling layers.
            in_chans: Number of channels in the input to the U-Net model.
            out_chans: Number of channels in the output to the U-Net model.
            drop_prob: Dropout probability.
        """
        super().__init__()

        self.unet = Unet(
            in_chans=in_chans,
            out_chans=out_chans,
            chans=chans,
            num_pool_layers=num_pools,
            drop_prob=drop_prob,
        )

    def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor:
        """
        Fold the trailing size-2 axis into the channel axis.

        Args:
            x: Tensor of shape `(b, c, h, w, 2)`.

        Returns:
            Tensor of shape `(b, 2 * c, h, w)`. The first `c` channels hold
            index 0 of the trailing axis, the last `c` channels hold index 1.
        """
        b, c, h, w, two = x.shape
        assert two == 2
        return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w)

    def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor:
        """
        Inverse of `complex_to_chan_dim`.

        Args:
            x: Tensor of shape `(b, 2 * c, h, w)`.

        Returns:
            Contiguous tensor of shape `(b, c, h, w, 2)`.
        """
        b, c2, h, w = x.shape
        assert c2 % 2 == 0
        c = c2 // 2
        # reshape (rather than view) so a non-contiguous input is accepted too
        return x.reshape(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous()

    def norm(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Normalize `x` with two channel groups per sample (group norm).

        The channel axis is split into its first and second half (the layout
        produced by `complex_to_chan_dim`); each half is normalized with its
        own mean and standard deviation taken over all of its channels and
        pixels.

        Args:
            x: Tensor of shape `(b, c, h, w)` with an even `c`.

        Returns:
            Tuple `(normalized, mean, std)`. `mean` and `std` have shape
            `(b, c, 1, 1)`; every channel of a group carries that group's
            statistic, so they broadcast against `(b, c, h, w)` tensors.

        Raises:
            ValueError: If `c` is odd.
        """
        b, c, h, w = x.shape
        if c % 2 != 0:
            raise ValueError("Channel dimension must be even to form two groups.")
        per_group = c // 2

        grouped = x.reshape(b, 2, per_group * h * w)
        # One statistic per group, repeated over the channels of that group.
        # For c == 2 this is exactly the per-channel mean/std.
        mean = grouped.mean(dim=2).repeat_interleave(per_group, dim=1).reshape(b, c, 1, 1)
        std = grouped.std(dim=2).repeat_interleave(per_group, dim=1).reshape(b, c, 1, 1)

        # NOTE: no epsilon is added, so a constant group (std == 0) divides by zero.
        return (x - mean) / std, mean, std

    def unnorm(
        self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
    ) -> torch.Tensor:
        """
        Undo `norm`: scale by `std` and shift by `mean`.

        `mean` and `std` must broadcast against `x`.
        """
        return x * std + mean

    def pad(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]:
        """
        Zero-pad height and width up to the next multiple of 16.

        Args:
            x: Tensor of shape `(b, c, h, w)`.

        Returns:
            Tuple `(padded, (h_pad, w_pad, h_mult, w_mult))` where `h_pad` and
            `w_pad` are `[before, after]` pad amounts and `h_mult` / `w_mult`
            are the padded sizes. The second element is what `unpad` expects.
        """
        _, _, h, w = x.shape
        # ((n - 1) | 15) + 1 rounds n up to the next multiple of 16
        w_mult = ((w - 1) | 15) + 1
        h_mult = ((h - 1) | 15) + 1
        w_extra = w_mult - w
        h_extra = h_mult - h
        # split the padding as evenly as possible; the odd pixel goes last
        w_pad = [w_extra // 2, w_extra - w_extra // 2]
        h_pad = [h_extra // 2, h_extra - h_extra // 2]
        # F.pad takes the pad amounts for the last dimension first
        x = F.pad(x, w_pad + h_pad)

        return x, (h_pad, w_pad, h_mult, w_mult)

    def unpad(
        self,
        x: torch.Tensor,
        h_pad: List[int],
        w_pad: List[int],
        h_mult: int,
        w_mult: int,
    ) -> torch.Tensor:
        """
        Crop away the border added by `pad`.

        Args:
            x: Padded tensor whose last two dimensions are `(h_mult, w_mult)`.
            h_pad: `[before, after]` padding applied to the height.
            w_pad: `[before, after]` padding applied to the width.
            h_mult: Padded height.
            w_mult: Padded width.

        Returns:
            View of `x` restricted to the original, un-padded region.
        """
        return x[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor of shape `(b, c, h, w, 2)`.

        Returns:
            Tensor produced by the U-Net, mapped back to the input scale and
            to a trailing size-2 axis.

        Raises:
            ValueError: If the last dimension of `x` is not 2.
        """
        if x.shape[-1] != 2:
            raise ValueError("Last dimension must be 2 for complex.")

        # get shapes for unet and normalize
        x = self.complex_to_chan_dim(x)
        x, mean, std = self.norm(x)
        x, pad_sizes = self.pad(x)

        x = self.unet(x)

        # get shapes back and unnormalize
        x = self.unpad(x, *pad_sizes)
        x = self.unnorm(x, mean, std)
        x = self.chan_complex_to_last_dim(x)

        return x
class FlippedNormUnet(nn.Module):
    """
    Normalized U-Net model with a subtractive skip connection.

    The input is normalized, passed through a U-Net and un-normalized, and
    the result is subtracted from the original input: `forward` returns
    `x - unet_output` instead of the U-Net output itself.
    """

    def __init__(
        self,
        chans: int,
        num_pools: int,
        in_chans: int = 2,
        out_chans: int = 2,
        drop_prob: float = 0.0,
    ):
        """
        Args:
            chans: Number of output channels of the first convolution layer.
            num_pools: Number of down-sampling and up-sampling layers.
            in_chans: Number of channels in the input to the U-Net model.
            out_chans: Number of channels in the output to the U-Net model.
            drop_prob: Dropout probability.
        """
        super().__init__()

        self.unet = Unet(
            in_chans=in_chans,
            out_chans=out_chans,
            chans=chans,
            num_pool_layers=num_pools,
            drop_prob=drop_prob,
        )

    def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor:
        """
        Fold the trailing size-2 axis into the channel axis.

        Args:
            x: Tensor of shape `(b, c, h, w, 2)`.

        Returns:
            Tensor of shape `(b, 2 * c, h, w)`. The first `c` channels hold
            index 0 of the trailing axis, the last `c` channels hold index 1.
        """
        b, c, h, w, two = x.shape
        assert two == 2
        return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w)

    def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor:
        """
        Inverse of `complex_to_chan_dim`.

        Args:
            x: Tensor of shape `(b, 2 * c, h, w)`.

        Returns:
            Contiguous tensor of shape `(b, c, h, w, 2)`.
        """
        b, c2, h, w = x.shape
        assert c2 % 2 == 0
        c = c2 // 2
        # reshape (rather than view) so a non-contiguous input is accepted too
        return x.reshape(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous()

    def norm(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Normalize `x` with two channel groups per sample (group norm).

        The channel axis is split into its first and second half (the layout
        produced by `complex_to_chan_dim`); each half is normalized with its
        own mean and standard deviation taken over all of its channels and
        pixels.

        Args:
            x: Tensor of shape `(b, c, h, w)` with an even `c`.

        Returns:
            Tuple `(normalized, mean, std)`. `mean` and `std` have shape
            `(b, c, 1, 1)`; every channel of a group carries that group's
            statistic, so they broadcast against `(b, c, h, w)` tensors.

        Raises:
            ValueError: If `c` is odd.
        """
        b, c, h, w = x.shape
        if c % 2 != 0:
            raise ValueError("Channel dimension must be even to form two groups.")
        per_group = c // 2

        grouped = x.reshape(b, 2, per_group * h * w)
        # One statistic per group, repeated over the channels of that group.
        # For c == 2 this is exactly the per-channel mean/std.
        mean = grouped.mean(dim=2).repeat_interleave(per_group, dim=1).reshape(b, c, 1, 1)
        std = grouped.std(dim=2).repeat_interleave(per_group, dim=1).reshape(b, c, 1, 1)

        # NOTE: no epsilon is added, so a constant group (std == 0) divides by zero.
        return (x - mean) / std, mean, std

    def unnorm(
        self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
    ) -> torch.Tensor:
        """
        Undo `norm`: scale by `std` and shift by `mean`.

        `mean` and `std` must broadcast against `x`.
        """
        return x * std + mean

    def pad(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]:
        """
        Zero-pad height and width up to the next multiple of 16.

        Args:
            x: Tensor of shape `(b, c, h, w)`.

        Returns:
            Tuple `(padded, (h_pad, w_pad, h_mult, w_mult))` where `h_pad` and
            `w_pad` are `[before, after]` pad amounts and `h_mult` / `w_mult`
            are the padded sizes. The second element is what `unpad` expects.
        """
        _, _, h, w = x.shape
        # ((n - 1) | 15) + 1 rounds n up to the next multiple of 16
        w_mult = ((w - 1) | 15) + 1
        h_mult = ((h - 1) | 15) + 1
        w_extra = w_mult - w
        h_extra = h_mult - h
        # split the padding as evenly as possible; the odd pixel goes last
        w_pad = [w_extra // 2, w_extra - w_extra // 2]
        h_pad = [h_extra // 2, h_extra - h_extra // 2]
        # F.pad takes the pad amounts for the last dimension first
        x = F.pad(x, w_pad + h_pad)

        return x, (h_pad, w_pad, h_mult, w_mult)

    def unpad(
        self,
        x: torch.Tensor,
        h_pad: List[int],
        w_pad: List[int],
        h_mult: int,
        w_mult: int,
    ) -> torch.Tensor:
        """
        Crop away the border added by `pad`.

        Args:
            x: Padded tensor whose last two dimensions are `(h_mult, w_mult)`.
            h_pad: `[before, after]` padding applied to the height.
            w_pad: `[before, after]` padding applied to the width.
            h_mult: Padded height.
            w_mult: Padded width.

        Returns:
            View of `x` restricted to the original, un-padded region.
        """
        return x[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor of shape `(b, c, h, w, 2)`.

        Returns:
            `x` minus the un-normalized U-Net output; same shape as `x`
            (the U-Net output must therefore match the input shape).

        Raises:
            ValueError: If the last dimension of `x` is not 2.
        """
        if x.shape[-1] != 2:
            raise ValueError("Last dimension must be 2 for complex.")

        # get shapes for unet and normalize
        n = self.complex_to_chan_dim(x)
        n, mean, std = self.norm(n)
        n, pad_sizes = self.pad(n)

        n = self.unet(n)

        # get shapes back and unnormalize
        n = self.unpad(n, *pad_sizes)
        n = self.unnorm(n, mean, std)
        n = self.chan_complex_to_last_dim(n)

        # the network output is subtracted from the untouched input
        return x - n
class Unet(nn.Module):
    """
    PyTorch implementation of a U-Net model.

    O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks
    for biomedical image segmentation. In International Conference on Medical
    image computing and computer-assisted intervention, pages 234–241.
    Springer, 2015.
    """

    def __init__(
        self,
        in_chans: int,
        out_chans: int,
        chans: int = 32,
        num_pool_layers: int = 4,
        drop_prob: float = 0.0,
    ):
        """
        Args:
            in_chans: Number of channels in the input to the U-Net model.
            out_chans: Number of channels in the output to the U-Net model.
            chans: Number of output channels of the first convolution layer.
            num_pool_layers: Number of down-sampling and up-sampling layers.
            drop_prob: Dropout probability.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.chans = chans
        self.num_pool_layers = num_pool_layers
        self.drop_prob = drop_prob

        # encoder: the channel count doubles at every level after the first
        self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)])
        ch = chans
        for _ in range(num_pool_layers - 1):
            self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob))
            ch *= 2
        # bottleneck
        self.conv = ConvBlock(ch, ch * 2, drop_prob)

        # decoder: each level up-samples, then fuses the concatenated skip
        # connection (hence ch * 2 input channels) back down to ch channels
        self.up_conv = nn.ModuleList()
        self.up_transpose_conv = nn.ModuleList()
        for _ in range(num_pool_layers - 1):
            self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch))
            self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob))
            ch //= 2

        # the last decoder level also projects to out_chans with a 1x1 conv
        self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch))
        self.up_conv.append(
            nn.Sequential(
                ConvBlock(ch * 2, ch, drop_prob),
                nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1),
            )
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """
        Args:
            image: Input 4D tensor of shape `(N, in_chans, H, W)`.

        Returns:
            Output tensor of shape `(N, out_chans, H, W)`.
        """
        stack = []
        output = image

        # apply down-sampling layers
        for layer in self.down_sample_layers:
            output = layer(output)
            # keep the pre-pooling activation for the matching skip connection
            stack.append(output)
            output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0)

        output = self.conv(output)

        # apply up-sampling layers
        for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv):
            downsample_layer = stack.pop()
            output = transpose_conv(output)

            # reflect pad on the right/bottom if needed to handle odd input dimensions
            padding = [0, 0, 0, 0]
            if output.shape[-1] != downsample_layer.shape[-1]:
                padding[1] = 1  # padding right
            if output.shape[-2] != downsample_layer.shape[-2]:
                padding[3] = 1  # padding bottom
            # plain Python check; no need to build a tensor to test a list
            if any(padding):
                output = F.pad(output, padding, "reflect")

            output = torch.cat([output, downsample_layer], dim=1)
            output = conv(output)

        return output
class ConvBlock(nn.Module):
    """
    Two stacked 3x3 convolution stages.

    Every stage is: bias-free convolution (padding 1), instance
    normalization, LeakyReLU with slope 0.2, and 2D dropout.
    """

    def __init__(self, in_chans: int, out_chans: int, drop_prob: float):
        """
        Args:
            in_chans: Number of channels in the input.
            out_chans: Number of channels in the output.
            drop_prob: Dropout probability.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.drop_prob = drop_prob

        stages = []
        stage_in = in_chans
        for _ in range(2):
            stages.extend(
                [
                    nn.Conv2d(stage_in, out_chans, kernel_size=3, padding=1, bias=False),
                    nn.InstanceNorm2d(out_chans),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True),
                    nn.Dropout2d(drop_prob),
                ]
            )
            # the second stage maps out_chans -> out_chans
            stage_in = out_chans
        self.layers = nn.Sequential(*stages)

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """
        Args:
            image: Input 4D tensor of shape `(N, in_chans, H, W)`.

        Returns:
            Output tensor of shape `(N, out_chans, H, W)`.
        """
        result = self.layers(image)
        return result
class TransposeConvBlock(nn.Module):
    """
    Up-sampling block: a single stride-2 transpose convolution, then instance
    normalization and a LeakyReLU activation.
    """

    def __init__(self, in_chans: int, out_chans: int):
        """
        Args:
            in_chans: Number of channels in the input.
            out_chans: Number of channels in the output.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans

        # kernel 2 / stride 2 doubles both spatial dimensions
        upsample = nn.ConvTranspose2d(
            in_chans, out_chans, kernel_size=2, stride=2, bias=False
        )
        normalize = nn.InstanceNorm2d(out_chans)
        activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.layers = nn.Sequential(upsample, normalize, activate)

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """
        Args:
            image: Input 4D tensor of shape `(N, in_chans, H, W)`.

        Returns:
            Output tensor of shape `(N, out_chans, H*2, W*2)`.
        """
        upsampled = self.layers(image)
        return upsampled