-
Notifications
You must be signed in to change notification settings - Fork 33
/
Copy pathmain.py
187 lines (136 loc) · 5.82 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
# import libraries
import os
import matplotlib.pyplot as plt
import torch
import torchvision
from torch.nn import functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from multiprocessing import Process
from utils.start_tensorboard import run_tensorboard
from models.seq2seq_ConvLSTM import EncoderDecoderConvLSTM
from data.MovingMNIST import MovingMNIST
import argparse
def _str2bool(value):
    """Parse a command-line boolean flag value.

    BUGFIX: the original used ``type=bool`` for ``--use_amp``; ``bool`` on a
    non-empty string is always True (so ``--use_amp False`` *enabled* AMP).
    This converter accepts the usual spellings and stays backward compatible
    with passing an explicit value on the command line.
    """
    if isinstance(value, bool):
        return value
    if value.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if value.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % value)


# Command-line configuration; ``opt`` is read by the model and trainer below.
parser = argparse.ArgumentParser()
parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
parser.add_argument('--beta_1', type=float, default=0.9, help='decay rate 1')
parser.add_argument('--beta_2', type=float, default=0.98, help='decay rate 2')
parser.add_argument('--batch_size', default=12, type=int, help='batch size')
parser.add_argument('--epochs', type=int, default=300, help='number of epochs to train for')
parser.add_argument('--use_amp', default=False, type=_str2bool, help='mixed-precision training')
parser.add_argument('--n_gpus', type=int, default=1, help='number of GPUs')
parser.add_argument('--n_hidden_dim', type=int, default=64, help='number of hidden dim for ConvLSTM layers')
opt = parser.parse_args()
##########################
######### MODEL ##########
##########################
##########################
######### MODEL ##########
##########################

class MovingMNISTLightning(pl.LightningModule):
    """LightningModule wrapping an ``EncoderDecoderConvLSTM`` for next-frame
    prediction on Moving MNIST.

    Each sequence's first ``n_steps_past`` frames are fed to the model, which
    is trained with MSE loss to predict the next ``n_steps_ahead`` frames.
    Periodically logs a grid image (predictions / ground truth / squared
    error) to TensorBoard.
    """

    def __init__(self, hparams=None, model=None):
        # ``hparams`` is accepted for Lightning compatibility but unused here.
        super(MovingMNISTLightning, self).__init__()

        # default config
        self.path = os.path.join(os.getcwd(), 'data')
        self.model = model

        # logging config
        self.log_images = True

        # Training config
        self.criterion = torch.nn.MSELoss()
        self.batch_size = opt.batch_size  # NOTE: read from module-level CLI args
        self.n_steps_past = 10
        self.n_steps_ahead = 10  # 4

    def _device(self):
        """Device the wrapped model lives on (CPU or a specific GPU)."""
        return next(self.model.parameters()).device

    def create_video(self, x, y_hat, y):
        """Build a single grid image of predictions, ground truth and error.

        Uses only the first sequence of the batch.  Rows are: input +
        prediction, input + ground truth, and the per-frame squared error
        (preceded by zero frames for the input steps).
        """
        # predictions with input for illustration purposes
        preds = torch.cat([x.cpu(), y_hat.unsqueeze(2).cpu()], dim=1)[0]

        # entire input and ground truth
        y_plot = torch.cat([x.cpu(), y.unsqueeze(2).cpu()], dim=1)[0]

        # error (l2 norm) plot between pred and ground truth
        difference = (torch.pow(y_hat[0] - y[0], 2)).detach().cpu()
        zeros = torch.zeros(difference.shape)
        difference_plot = torch.cat([zeros.cpu().unsqueeze(0), difference.unsqueeze(0).cpu()], dim=1)[
            0].unsqueeze(1)

        # concat all images
        final_image = torch.cat([preds, y_plot, difference_plot], dim=0)

        # make them into a single grid image file
        grid = torchvision.utils.make_grid(final_image, nrow=self.n_steps_past + self.n_steps_ahead)

        return grid

    def forward(self, x):
        # BUGFIX: was hard-coded ``x.to(device='cuda')``, which crashed on
        # CPU-only machines; follow the model's own device instead.
        x = x.to(self._device())

        output = self.model(x, future_seq=self.n_steps_ahead)

        return output

    def training_step(self, batch, batch_idx):
        # Split the sequence into conditioning frames (x) and targets (y).
        x, y = batch[:, 0:self.n_steps_past, :, :, :], batch[:, self.n_steps_past:, :, :, :]
        x = x.permute(0, 1, 4, 2, 3)  # (B, T, H, W, C) -> (B, T, C, H, W)
        y = y.squeeze()

        y_hat = self.forward(x).squeeze()  # is squeeze neccessary?

        loss = self.criterion(y_hat, y)

        # save learning_rate
        lr_saved = self.trainer.optimizers[0].param_groups[-1]['lr']
        # BUGFIX: was unconditionally ``.cuda()`` — keep the scalar on the
        # same device as the loss so CPU runs work too.
        lr_saved = torch.scalar_tensor(lr_saved, device=loss.device)

        # save predicted images every 250 global_step
        if self.log_images:
            if self.global_step % 250 == 0:
                final_image = self.create_video(x, y_hat, y)

                self.logger.experiment.add_image(
                    'epoch_' + str(self.current_epoch) + '_step' + str(self.global_step) + '_generated_images',
                    final_image, 0)
                plt.close()

        tensorboard_logs = {'train_mse_loss': loss,
                            'learning_rate': lr_saved}

        return {'loss': loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_idx):
        # OPTIONAL
        # NOTE(review): unlike training_step, the input is not permuted or
        # split here — confirm against the test DataLoader's output layout.
        x, y = batch
        y_hat = self.forward(x)
        return {'test_loss': self.criterion(y_hat, y)}

    def test_end(self, outputs):
        # OPTIONAL — aggregates per-batch test losses (old Lightning API hook).
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        tensorboard_logs = {'test_loss': avg_loss}
        return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        """Adam with CLI-configured learning rate and betas."""
        return torch.optim.Adam(self.parameters(), lr=opt.lr, betas=(opt.beta_1, opt.beta_2))

    @pl.data_loader
    def train_dataloader(self):
        """DataLoader over the Moving MNIST training split."""
        train_data = MovingMNIST(
            train=True,
            data_root=self.path,
            seq_len=self.n_steps_past + self.n_steps_ahead,
            image_size=64,
            deterministic=True,
            num_digits=2)

        train_loader = torch.utils.data.DataLoader(
            dataset=train_data,
            batch_size=self.batch_size,
            shuffle=True)

        return train_loader

    @pl.data_loader
    def test_dataloader(self):
        """DataLoader over the Moving MNIST test split."""
        test_data = MovingMNIST(
            train=False,
            data_root=self.path,
            seq_len=self.n_steps_past + self.n_steps_ahead,
            image_size=64,
            deterministic=True,
            num_digits=2)

        test_loader = torch.utils.data.DataLoader(
            dataset=test_data,
            batch_size=self.batch_size,
            shuffle=True)

        return test_loader
def run_trainer():
    """Build the seq2seq ConvLSTM, wrap it in a LightningModule, and train."""
    net = EncoderDecoderConvLSTM(nf=opt.n_hidden_dim, in_chan=1)
    lightning_module = MovingMNISTLightning(model=net)

    trainer = Trainer(
        max_epochs=opt.epochs,
        gpus=opt.n_gpus,
        distributed_backend='dp',
        early_stop_callback=False,
        use_amp=opt.use_amp,
    )

    trainer.fit(lightning_module)
if __name__ == '__main__':
    p1 = Process(target=run_trainer)  # start trainer
    p1.start()
    # BUGFIX: the original passed ``target=run_tensorboard(new_run=True)``,
    # which *called* run_tensorboard in the parent process and handed its
    # return value to Process.  Pass the callable plus kwargs so TensorBoard
    # actually launches in the child process.
    p2 = Process(target=run_tensorboard, kwargs={'new_run': True})  # start tensorboard
    p2.start()
    p1.join()
    p2.join()