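# train_pic.py
# Trains a probabilistic integral circuit (PIC) whose neural nets, evaluated at a
# fixed set of quadrature points, materialise the parameters of a tree-shaped QPC.
# The QPC log-likelihood of each mini-batch is then maximised by backpropagating
# through this materialisation.
#
# Example invocation (all flags are defined by the parser below):
#   python train_pic.py -ds mnist -nip 64 -int trapezoidal -bs 256 -n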
from pic import PIC, zw_quadrature
from data import datasets, trees
from clt import learn_clt
from dltm import DLTM
from sklearn.model_selection import train_test_split
import numpy as np
import functools
import argparse
import torch
import json
import time
import os
print = functools.partial(print, flush=True)
parser = argparse.ArgumentParser()
parser.add_argument('-device', type=str, default='cuda', help='cpu | cuda')
parser.add_argument('-ds', '--dataset', type=str, default='mnist', help='dataset name')
parser.add_argument('-split', type=str, default=None, help='dataset split for EMNIST')
parser.add_argument('-vsp', type=float, default=0.05, help='MNIST valid split percentage')
parser.add_argument('-nip', type=int, default=64, help='number of integration points')
parser.add_argument('-int', '--integration', type=str, default='trapezoidal', help='integration mode')
parser.add_argument('-lt', '--leaf_type', type=str, default=None, help='leaf distribution type')
parser.add_argument('-nu', '--n_units', type=int, default=64, help='number of units in the PIC neural nets')
parser.add_argument('-sigma', type=float, default=1.0, help='sigma for the Fourier features (ff)')
parser.add_argument('-bs', '--batch_size', type=int, default=256, help='batch size during training')
parser.add_argument('-as', '--accum_steps', type=int, default=1, help='number of accumulation steps')
parser.add_argument('-nc', '--n_chunks', type=int, default=1, help='num. of chunks to avoid OOM')
parser.add_argument('-ts', '--train_steps', type=int, default=30_000, help='num. of training steps')
parser.add_argument('-vf', '--valid_freq', type=int, default=250, help='validation every n steps')
parser.add_argument('-pf', '--print_freq', type=int, default=250, help='print every n steps')
parser.add_argument('-pat', '--patience', type=int, default=5, help='early-stopping patience on valid LL, in validation rounds')
parser.add_argument('-lr', type=float, default=0.01, help='initial learning rate')
parser.add_argument('-t0', type=int, default=500, help='CAWR t0, 1 for fixed lr')
parser.add_argument('-eta_min', type=float, default=1e-4, help='CAWR eta min')
parser.set_defaults(normalize=False)
parser.add_argument('-n', dest='normalize', action='store_true', help='normalize QPC')
parser.add_argument('-nn', dest='normalize', action='store_false', help='do not normalize QPC')
args = parser.parse_args()
dev = args.device
print(args)
#########################################################
################# create logging folder #################
#########################################################
dataset = args.dataset + ('' if args.split is None else ('_' + args.split))
idx = [args.dataset in x for x in [datasets.DEBD_DATASETS, datasets.MNIST_DATASETS, datasets.UCI_DATASETS, ['ptb288']]]
log_dir = 'log/pic/' + ['debd', 'mnist', 'uci', ''][np.argmax(idx)] + '/' + dataset + '/' + str(int(time.time())) + '/'
os.makedirs(log_dir, exist_ok=True)
json.dump(vars(args), open(log_dir + 'args.json', 'w'), sort_keys=True, indent=4)
#########################################################
############ load data & instantiate QPC-PIC ############
#########################################################
if args.dataset == 'ptb288':
    train, valid, test = datasets.load_ptb288()
    qpc = DLTM(trees.TREE_DICT[dataset], 'categorical', n_categories=50, norm_weight=False, learnable=False)
elif args.dataset in datasets.MNIST_DATASETS:
    train, test = datasets.load_mnist(ds_name=args.dataset, split=args.split)
    train_idx, valid_idx = train_test_split(np.arange(len(train)), train_size=1 - args.vsp)
    train, valid = train[train_idx], train[valid_idx]
    leaf_type = 'categorical' if args.leaf_type is None else args.leaf_type
    qpc = DLTM(trees.TREE_DICT[dataset], leaf_type, n_categories=256, norm_weight=False, learnable=False)
elif args.dataset in datasets.DEBD_DATASETS:
    train, valid, test = datasets.load_debd(args.dataset)
    qpc = DLTM(learn_clt(train, 'bernoulli', n_categories=2), 'bernoulli', norm_weight=False, learnable=False)
else:
    train, valid, test = datasets.load_uci(args.dataset)
    qpc = DLTM(trees.TREE_DICT[dataset], 'gaussian', norm_weight=False, learnable=False)
pic = PIC(qpc.tree, qpc.leaf_type, args.n_units, sigma=args.sigma, n_categories=qpc.n_categories).to(device=dev)
print('PIC num. param: %d' % sum(param.numel() for param in pic.parameters() if param.requires_grad))
#########################################################
##################### training loop #####################
#########################################################
optimizer = torch.optim.Adam(pic.parameters(), lr=args.lr, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=args.t0, T_mult=1, eta_min=args.eta_min)
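# fixed integration points z in [-1, 1] and their log-weights; these are computed
# once and reused at every step, so only the PIC parameters change between steps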
z, log_w = zw_quadrature(mode=args.integration, nip=args.nip, a=-1, b=1, log_weight=True, device=dev)
train_lls_log, valid_lls_log, batch_time_log, best_valid_ll = [-np.inf], [-np.inf], [], -np.inf
tik_train = time.time()
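# each step: (1) materialise the QPC from the PIC, (2) accumulate gradients of the
# negative batch log-likelihood over accum_steps chunks, (3) take an Adam step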
for train_step in range(1, args.train_steps + 1):
    tik_batch = time.time()
    # materialise the QPC: evaluate the PIC at the quadrature points to get sum & leaf logits
    qpc.sum_logits, qpc.leaf_logits = pic(z, log_w=log_w, n_chunks=args.n_chunks)
    # evaluate the QPC on a random batch, accumulating gradients over accum_steps chunks
    ll, batch_idx = 0, np.random.choice(len(train), args.batch_size * args.accum_steps, replace=False)
    for idx in np.array_split(batch_idx, args.accum_steps):
        ll_accum = qpc(train[idx].to(dev), has_nan=False, normalize=args.normalize).mean()
        # the materialisation graph is shared across chunks, so retain it when accumulating
        (-ll_accum).backward(retain_graph=args.accum_steps > 1)
        ll += float(ll_accum / args.accum_steps)
    # adam step
    lr = optimizer.param_groups[0]['lr']
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()
    # if args.accum_steps > 1: torch.cuda.empty_cache()
    batch_time_log.append(time.time() - tik_batch)
    # validation & logging
    train_lls_log.append(float(ll))
    # stop if none of the last `patience` validations improved on the best valid LL
    if max(valid_lls_log[-args.patience:]) < best_valid_ll:
        print('Early stopping: valid LL did not improve over the last %d steps' % (args.patience * args.valid_freq))
        break
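    # periodic validation; the best model (by valid LL) is checkpointed to log_dir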
    if train_step % args.valid_freq == 0:
        with torch.no_grad():
            log_norm_const = qpc.log_norm_constant.cpu() if args.normalize else 0
            valid_lls_log.append(float(torch.cat(
                [qpc(x.to(device=dev), has_nan=False).cpu() - log_norm_const for x in valid.split(args.batch_size)]).mean()))
        if valid_lls_log[-1] > best_valid_ll:
            best_valid_ll = valid_lls_log[-1]
            torch.save(pic, log_dir + 'pic.pt')
    if train_step % args.print_freq == 0:
        print(train_step, dataset, 'LL: %.2f, lr: %.5f (best valid LL: %.2f, bt: %.2fs, %.2f GiB)' %
              (ll, lr, best_valid_ll, np.mean(batch_time_log), (torch.cuda.max_memory_allocated() / 1024 ** 3)))
tok_train = time.time()
##########################################################
####### compute train-valid-test LLs of best model #######
##########################################################
with torch.no_grad():
    pic = torch.load(log_dir + 'pic.pt').to(dev)
    qpc.sum_logits, qpc.leaf_logits = pic(z, log_w=log_w)
    log_norm_const = qpc.log_norm_constant.cpu() if args.normalize else 0
    train_lls = torch.cat([qpc(x.to(dev), has_nan=False).cpu() for x in train.split(args.batch_size)]) - log_norm_const
    valid_lls = torch.cat([qpc(x.to(dev), has_nan=False).cpu() for x in valid.split(args.batch_size)]) - log_norm_const
    test_lls = torch.cat([qpc(x.to(dev), has_nan=False).cpu() for x in test.split(args.batch_size)]) - log_norm_const
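# bits-per-dimension conversion used below: bpd = -(mean LL in nats) / (ln 2 * num. variables)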
##########################################################
################### printing & logging ###################
##########################################################
print('\ndataset: %s' % dataset)
print('train (nats: %.2f, bpd: %.2f)' % (train_lls.mean(), (-train_lls.mean()) / (np.log(2) * train.size(1))))
print('valid (nats: %.2f, bpd: %.2f)' % (valid_lls.mean(), (-valid_lls.mean()) / (np.log(2) * train.size(1))))
print('test (nats: %.2f, bpd: %.2f)' % (test_lls.mean(), (-test_lls.mean()) / (np.log(2) * train.size(1))))
print('train time: %.2fs' % (tok_train - tik_train))
print('batch time: %.2fs' % np.mean(batch_time_log))
print('PIC param num: %d' % sum(param.numel() for param in pic.parameters() if param.requires_grad))
print('QPC param number: %d' % qpc.n_param)
print('max reserved GPU: %.2f GiB' % ((torch.cuda.max_memory_reserved() / 1024 ** 3) if dev == 'cuda' else 0))
print('max allocated GPU: %.2f GiB' % ((torch.cuda.max_memory_allocated() / 1024 ** 3) if dev == 'cuda' else 0))
results = {
    'train_time': tok_train - tik_train,
    'batch_time': np.mean(batch_time_log),
    'max_reserved_gpu': torch.cuda.max_memory_reserved() if dev == 'cuda' else 0,
    'max_allocated_gpu': torch.cuda.max_memory_allocated() if dev == 'cuda' else 0,
    'train_lls_log': np.array(train_lls_log[1:]),  # [1:] drops the -np.inf sentinel
    'valid_lls_log': np.array(valid_lls_log[1:]),  # [1:] drops the -np.inf sentinel
    'train_lls': np.array(train_lls),
    'valid_lls': np.array(valid_lls),
    'test_lls': np.array(test_lls)
}
if args.dataset in datasets.MNIST_DATASETS:
    results['train_valid_idx'] = np.array((train_idx, valid_idx), dtype=object)  # ragged pair of index arrays
np.save(log_dir + 'results.npy', results)
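# the saved dict can be reloaded with:
#   results = np.load(log_dir + 'results.npy', allow_pickle=True).item()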