-
Notifications
You must be signed in to change notification settings - Fork 0
/
Train_Conversion.py
58 lines (44 loc) · 1.9 KB
/
Train_Conversion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import zipfile

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from Conversion_util import conversion, zipdir
from DFLoss import MSE_DISSIM_Loss
from network import DF, DF_Src, DF_Dst
from Train_util import DFDataset, train_partial, save_checkpoint_DF
# Train the source and destination branches of the DF model, plot their loss
# curves to Loss.png, save a checkpoint, run `conversion` over the
# destination images using the source branch, and zip the converted output
# into conversion.zip.

EPOCHS = 30          # number of training epochs; also used for the checkpoint
BATCH = 32           # mini-batch size for both dataloaders
LEARNING_RATE = 1e-4

# Shared model plus the two branch wrappers built from it. Whether DF_Src /
# DF_Dst share DF_sample's submodules (e.g. a common encoder) is defined in
# `network` — NOTE(review): confirm if optimizer param overlap matters.
DF_sample = DF(3, 64, 128, 128, 256, 64, 16)
DF_sample_src = DF_Src(DF_sample)
DF_sample_dst = DF_Dst(DF_sample)

# Image / segmentation directory pairs for the source and destination sets.
dfdata_src = DFDataset('./DFDataset/img_i_2', './DFDataset/seg_i_2')
dfdata_dst = DFDataset('./DFDataset/img_e', './DFDataset/seg_e')

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

train_src_dataloader = DataLoader(dfdata_src, batch_size=BATCH, shuffle=True)
train_dst_dataloader = DataLoader(dfdata_dst, batch_size=BATCH, shuffle=True)

mse_dissim = MSE_DISSIM_Loss().to(device)

# Move the models to the target device BEFORE constructing the optimizers,
# as recommended by the torch.optim documentation, so the optimizers are
# built over the parameters in their final placement.
DF_sample = DF_sample.to(device)
DF_sample_src = DF_sample_src.to(device)
DF_sample_dst = DF_sample_dst.to(device)

optimizer_src = torch.optim.Adam(DF_sample_src.parameters(), lr=LEARNING_RATE)
optimizer_dst = torch.optim.Adam(DF_sample_dst.parameters(), lr=LEARNING_RATE)

# Loss histories filled by train_partial (presumably one entry per step,
# matching the 'Step' x-axis label below — TODO confirm in Train_util).
src_history = []
dst_history = []

for i in range(EPOCHS):
    print('epoch ', i)
    print('SRC partial training')
    train_partial(train_src_dataloader, DF_sample_src, mse_dissim, optimizer_src, src_history)
    print('\nDST partial training')
    train_partial(train_dst_dataloader, DF_sample_dst, mse_dissim, optimizer_dst, dst_history)

plt.plot(src_history, color='red', label='Src')
plt.plot(dst_history, color='blue', label='Dst')
plt.title('Change of Loss while Training')
plt.xlabel('Step')
plt.ylabel('MSE + DISSIM')
plt.legend()
plt.savefig('Loss.png')
plt.show()

save_checkpoint_DF(EPOCHS, DF_sample, DF_sample_src, DF_sample_dst,
                   optimizer_src, optimizer_dst, f'DFcheckpoint_fin_ver_{EPOCHS}')

# Convert the destination images with the source branch, then archive them.
conversion('./DFDataset/img_e', './Conversion_data', DF_sample_src)

# The context manager guarantees the archive is finalized and closed even if
# zipdir raises part-way through.
with zipfile.ZipFile('conversion.zip', 'w', zipfile.ZIP_DEFLATED) as zipf:
    zipdir('Conversion_data/', zipf)