-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcustom_model.py
74 lines (62 loc) · 2.72 KB
/
custom_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from compressai.models import MeanScaleHyperprior
from torch import nn
from compressai.models.utils import conv, deconv
class CustomMeanScaleHyperprior(MeanScaleHyperprior):
    r"""Mean-scale hyperprior codec with locally defined hyper-transforms.

    Follows the non zero-mean Gaussian conditional model of D. Minnen,
    J. Balle, G.D. Toderici: `"Joint Autoregressive and Hierarchical
    Priors for Learned Image Compression" <https://arxiv.org/abs/1809.02736>`_,
    Adv. in Neural Information Processing Systems 31 (NeurIPS 2018).

    After the parent constructor runs, ``h_a`` and ``h_s`` are (re)assigned
    here. ``g_a``, ``g_s``, ``entropy_bottleneck`` and ``gaussian_conditional``
    are not defined in this class and are expected to come from the parent.

    Args:
        N (int): Number of channels
        M (int): Number of channels in the expansion layers (last layer of the
            encoder and last layer of the hyperprior decoder)
    """

    def __init__(self, N, M, **kwargs):
        """Build the model and install the hyper-analysis/synthesis stacks.

        Args:
            N (int): Channel count used inside the hyper-analysis transform.
            M (int): Channel count of the latent fed to the hyper-analysis
                transform; the hyper-synthesis transform emits ``2 * M``
                channels.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(N, M, **kwargs)

        # Width of the middle hyper-synthesis stage: 3M/2, floored.
        expanded = M * 3 // 2

        # Hyper-analysis: latent (M channels) -> hyper-latent (N channels).
        analysis_layers = [
            conv(M, N, stride=1, kernel_size=3),
            nn.LeakyReLU(inplace=True),
            conv(N, N),
            nn.LeakyReLU(inplace=True),
            conv(N, N),
        ]
        self.h_a = nn.Sequential(*analysis_layers)

        # Hyper-synthesis: hyper-latent (N channels) -> 2M channels, which
        # the other methods split in half along dim 1 into scales and means.
        synthesis_layers = [
            deconv(N, M),
            nn.LeakyReLU(inplace=True),
            deconv(M, expanded),
            nn.LeakyReLU(inplace=True),
            conv(expanded, M * 2, stride=1, kernel_size=3),
        ]
        self.h_s = nn.Sequential(*synthesis_layers)

    def forward(self, x):
        """Run the full analysis/synthesis pass used for training.

        Args:
            x: Input passed to ``self.g_a``.

        Returns:
            dict: ``"x_hat"`` holds the output of ``self.g_s``;
            ``"likelihoods"`` maps ``"y"`` and ``"z"`` to the likelihoods
            returned by the Gaussian conditional and the entropy bottleneck.
        """
        latent = self.g_a(x)
        hyper_latent = self.h_a(latent)
        hyper_latent_hat, hyper_likelihoods = self.entropy_bottleneck(hyper_latent)

        # First half of the channels carries scales, second half the means.
        scales, means = self.h_s(hyper_latent_hat).chunk(2, dim=1)
        latent_hat, latent_likelihoods = self.gaussian_conditional(
            latent, scales, means=means
        )

        likelihoods = {"y": latent_likelihoods, "z": hyper_likelihoods}
        return {
            "x_hat": self.g_s(latent_hat),
            "likelihoods": likelihoods,
        }

    def compress(self, x):
        """Encode ``x`` into entropy-coded strings.

        Args:
            x: Input passed to ``self.g_a``.

        Returns:
            dict: ``"strings"`` is ``[y_strings, z_strings]`` and ``"shape"``
            is the last two dimensions of the hyper-latent, which
            ``decompress`` takes as its ``shape`` argument.
        """
        latent = self.g_a(x)
        hyper_latent = self.h_a(latent)
        hyper_shape = hyper_latent.size()[-2:]

        # Decode the hyper-latent straight away so the Gaussian parameters
        # are derived from the same z_hat that decompress() reconstructs.
        hyper_strings = self.entropy_bottleneck.compress(hyper_latent)
        hyper_latent_hat = self.entropy_bottleneck.decompress(
            hyper_strings, hyper_shape
        )

        scales, means = self.h_s(hyper_latent_hat).chunk(2, dim=1)
        cdf_indexes = self.gaussian_conditional.build_indexes(scales)
        latent_strings = self.gaussian_conditional.compress(
            latent, cdf_indexes, means=means
        )
        return {"strings": [latent_strings, hyper_strings], "shape": hyper_shape}

    def decompress(self, strings, shape):
        """Decode the output of ``compress`` back into a reconstruction.

        Args:
            strings: Two-element list ``[y_strings, z_strings]``.
            shape: Shape handed to the entropy bottleneck when decoding the
                hyper-latent (the ``"shape"`` entry produced by ``compress``).

        Returns:
            dict: ``"x_hat"`` holds the output of ``self.g_s`` clamped
            in place to ``[0, 1]``.

        Raises:
            AssertionError: If ``strings`` is not a list of length 2
                (only when assertions are enabled).
        """
        assert isinstance(strings, list) and len(strings) == 2

        hyper_latent_hat = self.entropy_bottleneck.decompress(strings[1], shape)
        scales, means = self.h_s(hyper_latent_hat).chunk(2, dim=1)
        cdf_indexes = self.gaussian_conditional.build_indexes(scales)
        latent_hat = self.gaussian_conditional.decompress(
            strings[0], cdf_indexes, means=means
        )

        reconstruction = self.g_s(latent_hat).clamp_(0, 1)
        return {"x_hat": reconstruction}