-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnet.py
89 lines (82 loc) · 3.32 KB
/
net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
from compressai.models import ScaleHyperprior
from compressai.entropy_models import EntropyBottleneck, GaussianConditional, EntropyModel
from sga import Quantizator_SGA
import numpy as np
class EntropyBottleneckNoQuant(EntropyBottleneck):
    """EntropyBottleneck whose ``forward`` only evaluates likelihoods.

    This ``forward`` does not quantize its input: it returns the likelihood of
    the values it is given, which the caller is expected to have quantized
    already (e.g. by rounding, additive noise or SGA). The quantization scheme
    is therefore chosen outside the entropy model.
    """

    def __init__(self, channels):
        super().__init__(channels)
        # NOTE(review): not used by this class; kept so that whatever state
        # Quantizator_SGA registers (and hence the state_dict layout) is unchanged.
        self.sga = Quantizator_SGA()

    def forward(self, x_quant):
        """Return the per-element likelihoods of ``x_quant``.

        Args:
            x_quant: Already-quantized latent tensor with at least 2 dims; dim 1
                is assumed to be the channel axis (``B x C x ...``) — it is swapped
                to dim 0 and everything else is flattened per channel.

        Returns:
            Tensor of likelihoods with the same shape as ``x_quant``, clamped by
            ``likelihood_lower_bound`` when ``use_likelihood_bound`` is set.
        """
        # B x C x ... -> C x B x ...; swapping two axes is its own inverse, so the
        # same transpose restores the original layout at the end.
        values = x_quant.transpose(0, 1).contiguous()
        shape = values.size()
        values = values.reshape(values.size(0), 1, -1)

        likelihood = self._likelihood(values)
        # Depending on the compressai version, _likelihood returns either the
        # likelihood tensor or a (likelihood, lower, upper) tuple.
        if isinstance(likelihood, tuple):
            likelihood = likelihood[0]
        if self.use_likelihood_bound:
            likelihood = self.likelihood_lower_bound(likelihood)

        # Convert back to the input tensor shape.
        likelihood = likelihood.reshape(shape)
        return likelihood.transpose(0, 1).contiguous()
class GaussianConditionalNoQuant(GaussianConditional):
    """GaussianConditional whose ``forward`` only evaluates likelihoods.

    The input is not quantized here; the caller passes already-quantized values
    and receives their likelihoods under the conditional Gaussian.
    """

    def __init__(self, scale_table):
        super().__init__(scale_table=scale_table)

    def forward(self, x_quant, scales, means=None):
        """Return the per-element likelihoods of ``x_quant``.

        Args:
            x_quant: Already-quantized values.
            scales: Scale parameters, passed to ``_likelihood``.
            means: Optional means, passed to ``_likelihood``; ``None`` (the
                default) mirrors the parent API's optional means.

        Returns:
            Likelihood tensor, clamped by ``likelihood_lower_bound`` when
            ``use_likelihood_bound`` is set.
        """
        likelihood = self._likelihood(x_quant, scales, means)
        if self.use_likelihood_bound:
            likelihood = self.likelihood_lower_bound(likelihood)
        return likelihood
class ScaleHyperpriorSGA(ScaleHyperprior):
    """ScaleHyperprior whose latents are quantized explicitly in ``forward``.

    The parent's entropy models are replaced by likelihood-only variants, so the
    latents ``y`` / ``z`` can be quantized by rounding, additive uniform noise, or
    the SGA quantizer (``Quantizator_SGA``) before their likelihoods are computed.
    """

    # forward() mode -> quantize() mode.
    _FORWARD_TO_QUANT_MODE = {
        "init": "round",
        "round": "round",
        "noise": "noise",
        "sga": "sga",
    }

    def __init__(self, N, M, **kwargs):
        super().__init__(N, M, **kwargs)
        # Likelihood-only entropy models: quantization happens in forward().
        self.entropy_bottleneck = EntropyBottleneckNoQuant(N)
        self.gaussian_conditional = GaussianConditionalNoQuant(None)
        self.sga = Quantizator_SGA()

    def quantize(self, inputs, mode, means=None, it=None, tot_it=None):
        """Quantize ``inputs`` with the requested scheme.

        Args:
            inputs: Tensor to quantize.
            mode: ``"noise"`` (add uniform noise in [-0.5, 0.5)), ``"round"``
                (``torch.round``) or ``"sga"`` (``self.sga`` in ``"training"`` mode).
            means: Optional offset; quantization is applied to ``inputs - means``
                and ``means`` is added back afterwards.
            it, tot_it: Current / total iteration, forwarded to ``self.sga``;
                only used when ``mode == "sga"``.

        Returns:
            The quantized tensor.

        Raises:
            ValueError: If ``mode`` is not one of the supported schemes.
        """
        if mode not in ("noise", "round", "sga"):
            # Explicit raise instead of assert: asserts vanish under ``python -O``.
            raise ValueError(f"Unknown quantization mode: {mode!r}")
        if means is not None:
            inputs = inputs - means
        if mode == "noise":
            half = float(0.5)
            outputs = inputs + torch.empty_like(inputs).uniform_(-half, half)
        elif mode == "round":
            outputs = torch.round(inputs)
        else:
            outputs = self.sga(inputs, it, "training", tot_it)
        if means is not None:
            outputs = outputs + means
        return outputs

    def forward(self, x, mode, y_in=None, z_in=None, it=None, tot_it=None):
        """Quantize the latents, decode, and compute likelihoods.

        Args:
            x: Model input; only read when ``mode == "init"``.
            mode: ``"init"`` computes ``y = g_a(x)``, ``z = h_a(|y|)`` and rounds
                them; ``"round"``, ``"noise"`` and ``"sga"`` quantize the supplied
                ``y_in`` / ``z_in`` with the corresponding scheme.
            y_in, z_in: Latents to quantize when ``mode != "init"``.
            it, tot_it: Current / total iteration, forwarded to the SGA quantizer.

        Returns:
            dict with detached copies of the unquantized ``"y"`` and ``"z"``, the
            reconstruction ``"x_hat"``, and ``"likelihoods"`` for ``"y"`` and
            ``"z"`` — both evaluated on the quantized latents.

        Raises:
            ValueError: On an unknown ``mode``, or when ``y_in`` / ``z_in`` is
                missing for a non-``"init"`` mode.
        """
        quant_mode = self._FORWARD_TO_QUANT_MODE.get(mode)
        if quant_mode is None:
            raise ValueError(f"Unknown forward mode: {mode!r}")

        if mode == "init":
            y = self.g_a(x)
            z = self.h_a(torch.abs(y))
        else:
            if y_in is None or z_in is None:
                raise ValueError(f"mode {mode!r} requires both y_in and z_in")
            y, z = y_in, z_in

        # it / tot_it are ignored by quantize() for the non-SGA modes.
        y_hat = self.quantize(y, quant_mode, None, it, tot_it)
        z_hat = self.quantize(z, quant_mode, None, it, tot_it)

        # The hyper-latent rate must be evaluated on the quantized z_hat (the
        # values actually coded), consistent with y_likelihoods using y_hat.
        z_likelihoods = self.entropy_bottleneck(z_hat)
        scales_hat = self.h_s(z_hat)
        y_likelihoods = self.gaussian_conditional(y_hat, scales_hat, None)
        x_hat = self.g_s(y_hat)
        return {
            "y": y.detach().clone(),
            "z": z.detach().clone(),
            "x_hat": x_hat,
            "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
        }