forked from q294881866/vtl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_g.py
133 lines (113 loc) · 4.55 KB
/
train_g.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import _thread
import argparse
import os
import numpy as np
import torch
import torch.nn as nn
import GlobalConfig
from dataset.dataset import TrainCache, load_cache, get_dataloader, TrainItem
from dataset.inpainting_dataset import get_inpainting_dataloader
from layer import fn
from layer.genesis import Genesis
from layer.helper import cb2b
from util import figureUtil
from util.logUtil import logger
# program init: module-level state shared by the helpers below.

# Class name -> 0 for every known class; filled by load_label_classes and
# copied by get_classes_label to build one-hot vectors.
label_set = {}

# Mean-reduced binary cross-entropy; not referenced elsewhere in this file.
bce_loss = nn.BCELoss(reduction='mean')

# Per-run metric histories (each is its own independent list).
# Within this file only g_losses is appended to (by train_g).
itr_times = []
g_losses = []
h_losses = []
d_losses = []
h_d_losses = []
accuracies = []
hash_dists = []
def load_label_classes(data_path):
    """Register every entry of ``data_path`` as a class in ``label_set``.

    Each entry name returned by ``os.listdir`` becomes a key of the
    module-level ``label_set`` with value 0. Its position in that dict
    determines its index in the one-hot vectors built by get_classes_label.

    Args:
        data_path: directory whose entries are the class names.

    Returns:
        int: number of classes registered (equal to ``len(label_set)``).
    """
    # os.listdir order is arbitrary and filesystem-dependent. Sort it so the
    # class -> one-hot index mapping is identical on every machine and run.
    classes = sorted(os.listdir(data_path))
    # Rebuild from scratch: otherwise a second call would keep stale classes,
    # making the one-hot width disagree with the count returned here.
    label_set.clear()
    label_set.update(dict.fromkeys(classes, 0))
    return len(classes)
def get_classes_label(label):
    """Return a one-hot list for ``label`` over the classes in ``label_set``.

    Slots follow ``label_set``'s insertion order. The slot for ``label`` is 1
    and every other slot keeps its value from ``label_set`` (0 as filled by
    load_label_classes). ``label_set`` itself is not modified.

    Raises:
        KeyError: if ``label`` is not a registered class. Previously an unknown
            label silently appended an extra slot, producing a vector one
            element wider than every other target.
    """
    if label not in label_set:
        raise KeyError('unknown class label: {!r}'.format(label))
    one_hot = label_set.copy()
    one_hot[label] = 1
    return list(one_hot.values())
def get_tensor_target(labels: []):
    """Build a float32 one-hot target tensor for a sequence of class labels.

    Each label is encoded with get_classes_label. Every resulting row is then
    repeated three times consecutively (``np.repeat`` along axis 0), so N
    labels yield 3*N rows.

    NOTE(review): the x3 repetition presumably mirrors a batch layout where
    each sample contributes three items; confirm against the caller.
    """
    one_hot_rows = [get_classes_label(lbl) for lbl in labels]
    stacked = np.repeat(np.asarray(one_hot_rows, dtype=np.float32), 3, axis=0)
    return torch.from_numpy(stacked)
def train(args_, dataloader_, test_loader_, num_classes, epochs=1000):
    """Run the generator training loop.

    Each epoch starts a background ``_thread`` running ``load_cache`` to fill a
    fresh ``TrainCache`` from ``dataloader_``. This thread consumes cached
    items until the cache reports ``finished``. Every consumed item runs one
    train_step followed by test_step, then advances the global step ``idx``,
    which starts at ``GlobalConfig.CHECKPOINT``.

    Args:
        args_: parsed CLI arguments; only ``local_rank`` is read here.
        dataloader_: training loader handed to ``load_cache``.
        test_loader_: evaluation loader; re-wrapped in ``enumerate`` whenever
            test_step exhausts it.
        num_classes: not used here; kept for call-site compatibility.
        epochs: number of epochs to run (default 1000, the previous constant).
    """
    import time

    # init
    genesis = Genesis(224, GlobalConfig.PATCH_SIZE, args_.local_rank, [args_.local_rank])
    device = genesis.device
    # running
    test_itr = enumerate(test_loader_)
    idx = GlobalConfig.CHECKPOINT
    for epoch in range(epochs):
        train_cache = TrainCache(size=16)
        _thread.start_new_thread(load_cache, (dataloader_, train_cache,))
        # NOTE(review): if TrainCache sets `finished` while items are still
        # buffered, those trailing items are skipped; confirm its semantics.
        while not train_cache.finished:
            if not train_cache.has_item():
                # Back off briefly instead of spinning: a tight loop burns a
                # core and competes with the loader thread for the GIL.
                time.sleep(0.01)
                continue
            try:
                _, item = train_cache.next_data()
                train_step(genesis, item, idx, epoch, device)
                test_step(genesis, idx, epoch, test_itr, device)
            except StopIteration:
                # The test iterator ran out: restart it from the beginning.
                test_itr = enumerate(test_loader_)
            except Exception as e:
                # Best-effort training: report the failure and move on.
                print(e)
            idx += 1
def train_step(genesis: Genesis, item: TrainItem, idx, epoch, device):
    """Run one generator update on a cached training item and log its loss.

    ``item.src``, ``item.fake`` and ``item.masks`` are each passed through
    ``cb2b(..., device)`` (in that order). ``[src, fake]`` and ``masks`` then
    go to train_g, and the resulting loss is logged with the epoch and step.
    """
    src, fake, masks = (cb2b(part, device) for part in (item.src, item.fake, item.masks))
    loss_g, _ = train_g(genesis, [src, fake], masks, idx)
    logger.info("Train Epoch:{}/{},G Loss:{:.5f}".format(epoch, idx, loss_g))
def test_step(genesis: Genesis, idx, epoch, test_itr, device):
    """Every 100 steps, preview the generator on one test batch and checkpoint.

    Does nothing unless ``idx % 100 == 0``. Otherwise it:
    - pulls the next batch from ``test_itr``;
    - runs ``genesis.g`` in eval mode without autograd;
    - writes a preview via ``figureUtil.merge_pic`` to
      ``images/{epoch}_{idx}_0_test.jpg``;
    - saves the model with prefix ``models/{epoch}_{idx}_``.

    Raises:
        StopIteration: when ``test_itr`` is exhausted. The caller (train)
            catches it and restarts the iterator. Training mode is restored
            on every exit path.
    """
    if idx % 100 != 0:
        return
    genesis.eval()
    try:
        _, (_label, _, _, _, sources, fakes, masks) = next(test_itr)
        fakes = cb2b(fakes, device)
        sources = cb2b(sources, device)
        masks = cb2b(masks, device)
        # Inference only: don't build an autograd graph for the preview pass.
        with torch.no_grad():
            g = genesis.g([sources, fakes])
        # save generated mask preview
        figureUtil.merge_pic(g, masks, 'images/{}_{}_{}_test.jpg'.format(epoch, idx, 0))
        # save model checkpoint
        genesis.save('models/{}_{}_'.format(epoch, idx))
    finally:
        # Always switch back to training mode. Previously an exhausted test
        # iterator (or any error above) left the model in eval mode for the
        # following training steps.
        genesis.train()
def train_g(genesis: Genesis, train_data, masks, idx):
    """Run one optimisation step of the generator ``genesis.g``.

    Args:
        genesis: model wrapper exposing ``g``, ``reset_grad()`` and ``opt_g``.
        train_data: input list (``[src, fake]`` from train_step) passed
            straight to ``genesis.g``.
        masks: targets passed to ``fn.mask_loss`` together with the output.
        idx: global step. When ``idx % 100 == 0`` the output/mask pair is
            written to ``images/{idx}_0_mask.jpg``.

    Returns:
        tuple: ``(g_loss, g)``, the loss tensor and the generator output.
        The rounded loss is also appended to the module-level ``g_losses``.

    Raises:
        Any exception raised during the step is printed and then re-raised.
    """
    try:
        g = genesis.g(train_data)
        g_loss = fn.mask_loss(g, masks)
        # backward
        genesis.reset_grad()
        g_loss.backward()
        genesis.opt_g.step()
        if idx % 100 == 0:
            figureUtil.merge_pic(g, masks, 'images/{}_{}_mask.jpg'.format(idx, 0))
        g_losses.append(round(g_loss.item(), 3))
        return g_loss, g
    except Exception as e:
        print(e)
        # Re-raise rather than fall through to an implicit `return None`.
        # The caller unpacks `(g_loss, g)` and would otherwise fail with an
        # unrelated "cannot unpack NoneType" error that hides the real cause.
        raise
# Command-line interface for the training script.
parser = argparse.ArgumentParser()
parser.add_argument('--path', type=str, default=r'Y:\vrf_',
                    help='dataset root; train/test data are read from its '
                         'GlobalConfig.TRAIN and GlobalConfig.TEST subdirectories')
parser.add_argument('--local_rank', type=int, default=0,
                    help='rank/device id passed to Genesis')
# Only the two dataset kinds handled in __main__ are accepted; anything else
# previously left the loaders as None and crashed later inside train().
parser.add_argument('--type', type=int, default=0, choices=(0, 1),
                    help='dataset kind: 0 = standard loader, 1 = inpainting loader')
if __name__ == '__main__':
    # Entry point: parse args, build the train/test loaders for the chosen
    # dataset kind, register the class labels, then start training.
    args = parser.parse_args()
    print('args:{}'.format(args))
    train_dir = os.path.join(args.path, GlobalConfig.TRAIN)
    test_dir = os.path.join(args.path, GlobalConfig.TEST)
    if args.type == 0:
        make_loader, label_dir = get_dataloader, train_dir
    elif args.type == 1:
        # Inpainting data keeps its class folders under <train>/src.
        make_loader, label_dir = get_inpainting_dataloader, os.path.join(train_dir, 'src')
    else:
        # Fail fast instead of passing None loaders into train().
        parser.error('unsupported --type {}; expected 0 or 1'.format(args.type))
    dataloader = make_loader(set_path=train_dir)
    test_loader = make_loader(mode=GlobalConfig.TEST, set_path=test_dir, num_workers=0)
    num_classes = load_label_classes(label_dir)
    train(args, dataloader, test_loader, num_classes)