-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcommons.py
123 lines (102 loc) · 4.53 KB
/
commons.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import shutil
import random
import numpy as np
from torchinfo import summary
from omegaconf import DictConfig
import torch as th
from model.VQVAE import VQVAEModule, VQVAECNN, VQVAELinear
from model.PixelLinear import PixelLinear
from model.PixelCNN import PixelCNN, PixObsNet, PixObsNetCNN
def get_vqvae(vqvae_cfg: DictConfig, mean, std, device="cpu") -> VQVAEModule:
    """Build a VQ-VAE model from config and apply optimizer overrides.

    Args:
        vqvae_cfg: config whose ``type`` selects the architecture (``"cnn"`` ->
            ``VQVAECNN``, ``"linear"`` -> ``VQVAELinear``); the remaining keys
            read depend on that type (nested ``encoder``/``codebook``/``decoder``/
            ``initializer`` sections for ``"cnn"``, flat keys for ``"linear"``).
            ``vqvae_cfg.optimizer`` entries are copied onto ``model.dOptimizer``.
        mean: forwarded to ``VQVAECNN`` only; ignored by the linear variant.
        std: forwarded to ``VQVAECNN`` only; ignored by the linear variant.
        device: device argument forwarded to the model constructor.

    Returns:
        The constructed model with ``dOptimizer`` attributes overridden.

    Raises:
        ValueError: if ``vqvae_cfg.type`` is not ``"cnn"`` or ``"linear"``, or if
            an ``optimizer`` key is not an existing attribute of ``dOptimizer``.
    """
    if vqvae_cfg.type not in ("cnn", "linear"):
        raise ValueError(
            f"Unknown VQ-VAE type {vqvae_cfg.type!r}; expected 'cnn' or 'linear'"
        )

    # Re-seed immediately before construction (presumably so model
    # initialisation is reproducible across runs).
    seed_mch()

    if vqvae_cfg.type == "cnn":
        vqvae = VQVAECNN(
            mean=mean, std=std,
            out_dim=vqvae_cfg.encoder.output_dim,
            hid_channels=vqvae_cfg.encoder.hid_channels,
            out_channels=vqvae_cfg.codebook.num_features,
            kernel_size=vqvae_cfg.encoder.kernel_size,
            padding=vqvae_cfg.encoder.padding,
            num_layers=vqvae_cfg.encoder.num_layers,
            num_embeddings=vqvae_cfg.codebook.num_embeddings,
            embedding_dim=vqvae_cfg.codebook.latent_dim,
            decoder_shapes=vqvae_cfg.decoder.shape,
            init_in_dim=vqvae_cfg.initializer.input_dim,
            init_out_dim=vqvae_cfg.initializer.output_dim,
            init_shape=vqvae_cfg.initializer.shape,
            device=device, num_batch=vqvae_cfg.batch_size,
        )
    else:
        vqvae = VQVAELinear(
            inp_dim=vqvae_cfg.input_dim,
            out_dim=vqvae_cfg.output_dim,
            num_embeddings=vqvae_cfg.num_embeddings,
            embedding_dim=vqvae_cfg.latent_dim,
            encoder_shapes=vqvae_cfg.encoder_shape,
            num_features=vqvae_cfg.num_features,
            decoder_shapes=vqvae_cfg.decoder_shape,
            device=device, num_batch=vqvae_cfg.batch_size,
        )

    # Only override attributes that already exist, so a typo in the config
    # fails loudly instead of silently adding an unused attribute.
    for key, value in vqvae_cfg.optimizer.items():
        if not hasattr(vqvae.dOptimizer, key):
            raise ValueError(
                f"Unknown optimizer option {key!r} for {type(vqvae.dOptimizer).__name__}"
            )
        setattr(vqvae.dOptimizer, key, value)
    return vqvae
def get_pixcnn(vqvae_cfg: DictConfig, pix_cfg: DictConfig, device="cpu"):
    """Construct a ``PixelCNN`` sized from the VQ-VAE codebook and pixel config.

    Args:
        vqvae_cfg: VQ-VAE config; ``codebook.num_embeddings`` sets ``num_embedding``.
        pix_cfg: pixel-model config supplying ``conv_kernel_size``,
            ``conv_padding``, ``n_channels`` and ``n_layers``.
        device: device argument forwarded to ``PixelCNN``.

    Returns:
        The newly created ``PixelCNN`` (single input channel).
    """
    params = dict(
        num_embedding=vqvae_cfg.codebook.num_embeddings,
        kernel_size=pix_cfg.conv_kernel_size,
        in_channels=1,
        padding=pix_cfg.conv_padding,
        n_channels=pix_cfg.n_channels,
        n_layers=pix_cfg.n_layers,
        device=device,
    )
    return PixelCNN(**params)
def get_pixlin(vqvae_cfg: DictConfig, pix_cfg: DictConfig, device="cpu"):
    """Construct a ``PixelLinear`` model from flat VQ-VAE keys and pixel config.

    Args:
        vqvae_cfg: VQ-VAE config read via flat keys ``num_embeddings`` and
            ``num_features``.
        pix_cfg: pixel-model config supplying ``n_channels`` (used as
            ``out_channels``) and ``n_layers``.
        device: device argument forwarded to ``PixelLinear``.

    Returns:
        The newly created ``PixelLinear`` (single input channel).
    """
    params = dict(
        num_embedding=vqvae_cfg.num_embeddings,
        in_feature=vqvae_cfg.num_features,
        in_channels=1,
        out_channels=pix_cfg.n_channels,
        n_layers=pix_cfg.n_layers,
        device=device,
    )
    return PixelLinear(**params)
def get_pix(vqvae_cfg: DictConfig, pix_cfg: DictConfig, mean, std, device="cpu"):
    """Build the pixel model and its observation network for ``pix_cfg.type``.

    Args:
        vqvae_cfg: VQ-VAE config; ``codebook.num_features`` sizes the obs-net
            output, and further keys are read by ``get_pixcnn``/``get_pixlin``.
        pix_cfg: pixel-model config; ``type`` is ``"cnn"`` or ``"linear"``.
            The ``"cnn"`` variant reads ``obs_cnn.*``; the ``"linear"`` variant
            reads ``obs_layers`` and ``cond_dim``.
        mean: forwarded to the observation network.
        std: forwarded to the observation network.
        device: device argument forwarded to every constructed module.

    Returns:
        Tuple ``(pixel_net, obs_net)``: ``PixelCNN`` + ``PixObsNetCNN`` for
        ``"cnn"``, or ``PixelLinear`` + ``PixObsNet`` for ``"linear"``.

    Raises:
        ValueError: if ``pix_cfg.type`` is not ``"cnn"`` or ``"linear"``.
    """
    # The observation net is built before the pixel net, as before, so any
    # RNG-dependent initialisation happens in the same order.
    if pix_cfg.type == "cnn":
        obs_net = PixObsNetCNN(
            mean=mean, std=std, hid_channels=pix_cfg.obs_cnn.hid_channels,
            kernel_size=pix_cfg.obs_cnn.kernel_size, padding=pix_cfg.obs_cnn.padding,
            num_layers=pix_cfg.obs_cnn.num_layers,
            output_shape=vqvae_cfg.codebook.num_features, device=device,
        )
        pixel_net = get_pixcnn(vqvae_cfg, pix_cfg, device)
    elif pix_cfg.type == "linear":
        # NOTE(review): this reads the nested ``vqvae_cfg.codebook.num_features``
        # while get_pixlin reads flat ``vqvae_cfg.num_features`` -- confirm the
        # linear VQ-VAE config really provides both.
        obs_net = PixObsNet(
            mean=mean, std=std, layers=pix_cfg.obs_layers,
            input_shape=pix_cfg.cond_dim,
            output_shape=vqvae_cfg.codebook.num_features, device=device,
        )
        pixel_net = get_pixlin(vqvae_cfg, pix_cfg, device)
    else:
        raise ValueError(
            f"Unknown pixel model type {pix_cfg.type!r}; expected 'cnn' or 'linear'"
        )
    return pixel_net, obs_net
def make_dir(path: str, clean=True):
    """Ensure directory ``path`` exists, optionally wiping it first.

    Args:
        path: directory to create; missing parent directories are created too.
        clean: if True and ``path`` already exists, delete it (with everything
            beneath it) and recreate it empty. If False, an existing directory
            is kept as-is.
    """
    if clean and os.path.exists(path):
        shutil.rmtree(path)
    # exist_ok=True: with clean=False an existing directory is the expected
    # case, not an error (previously this raised FileExistsError).
    os.makedirs(path, exist_ok=True)
def seed_mch(seed: int = 42):
    """Put ``random``/NumPy/PyTorch into a seeded, float32 configuration.

    Mutates process-global state:
    - sets torch's default dtype to float32;
    - seeds Python's ``random``, NumPy's global RNG, and torch (CPU and all CUDA
      devices) with ``seed``;
    - sets float32 matmul precision to ``'high'``;
    - makes cuDNN deterministic and disables cuDNN benchmarking.

    Args:
        seed: RNG seed. Defaults to 42, the value previously hard-coded.
    """
    th.set_default_dtype(th.float32)
    random.seed(seed)
    np.random.seed(seed)
    th.manual_seed(seed)
    # manual_seed_all seeds every visible GPU, a superset of th.cuda.manual_seed.
    th.cuda.manual_seed_all(seed)
    # Per the torch docs, 'high' may use TF32 internally where available.
    th.set_float32_matmul_precision("high")
    th.backends.cudnn.deterministic = True
    th.backends.cudnn.benchmark = False