train.py
"""Train a GAN with an auxiliary denoiser network on CIFAR-10.

Example:
    python train.py --gpu 0 --epochs 100 --batch-size 128
"""
import argparse

from chainer import datasets, training, iterators, optimizers
from chainer.training import updater, extensions

from lib.chainer.iterators import RandomNoiseIterator, UniformNoiseGenerator
from lib.chainer.updater import GenerativeAdversarialUpdater
from lib.chainer.extensions import GeneratorSample
from lib.models import Generator, Discriminator, Denoiser

# Register the project-specific components under the corresponding Chainer
# namespaces so they can be referenced the same way as the built-ins below.
iterators.RandomNoiseIterator = RandomNoiseIterator
updater.GenerativeAdversarialUpdater = GenerativeAdversarialUpdater
extensions.GeneratorSample = GeneratorSample


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=1)
    parser.add_argument('--nz', type=int, default=100)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--lambda-denoise', type=float, default=1.0)
    parser.add_argument('--lambda-adv', type=float, default=0.03)
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()

    nz = args.nz
    batch_size = args.batch_size
    epochs = args.epochs
    gpu = args.gpu
    lambda_denoise = args.lambda_denoise
    lambda_adv = args.lambda_adv

    # CIFAR-10 training images without labels, shape (3, 32, 32).
    train, _ = datasets.get_cifar10(withlabel=False, ndim=3)
    train_iter = iterators.SerialIterator(train, batch_size)

    # Latent vectors z ~ U(-1, 1) of dimension nz, one batch per iteration.
    z_iter = iterators.RandomNoiseIterator(UniformNoiseGenerator(-1, 1, nz),
                                           batch_size)

    # One Adam optimizer per network.
    optimizer_generator = optimizers.Adam(alpha=1e-4, beta1=0.5)
    optimizer_discriminator = optimizers.Adam(alpha=1e-4, beta1=0.5)
    optimizer_denoiser = optimizers.Adam(alpha=1e-4, beta1=0.5)

    optimizer_generator.setup(Generator())
    optimizer_discriminator.setup(Discriminator())
    optimizer_denoiser.setup(Denoiser())

    gan_updater = updater.GenerativeAdversarialUpdater(
        iterator=train_iter,
        noise_iterator=z_iter,
        optimizer_generator=optimizer_generator,
        optimizer_discriminator=optimizer_discriminator,
        optimizer_denoiser=optimizer_denoiser,
        lambda_denoise=lambda_denoise,
        lambda_adv=lambda_adv,
        device=gpu)

    trainer = training.Trainer(gan_updater, stop_trigger=(epochs, 'epoch'))
    trainer.extend(extensions.LogReport(trigger=(1, 'iteration')))
    trainer.extend(extensions.PrintReport(
        ['epoch', 'iteration', 'gen/loss', 'dis/loss', 'denoiser/loss']))
    trainer.extend(extensions.ProgressBar())
    trainer.extend(extensions.GeneratorSample(), trigger=(1, 'epoch'))
    trainer.run()
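
The noise iterator used above is imported from lib/chainer/iterators.py, which is not shown on this page. Purely as an illustration, and not the repository's actual implementation, a minimal Chainer-compatible noise iterator could look like the sketch below; only the class names RandomNoiseIterator and UniformNoiseGenerator match the imports, everything else is an assumption.

# Hypothetical sketch only: the real classes live in lib/chainer/iterators.py.
import numpy as np
from chainer.dataset import iterator


class UniformNoiseGenerator(object):
    """Callable returning a (batch_size, size) float32 array drawn from U(low, high)."""

    def __init__(self, low, high, size):
        self.low = low
        self.high = high
        self.size = size

    def __call__(self, batch_size):
        return np.random.uniform(self.low, self.high,
                                 (batch_size, self.size)).astype(np.float32)


class RandomNoiseIterator(iterator.Iterator):
    """Iterator yielding a fresh batch of noise vectors on every __next__ call."""

    def __init__(self, noise_generator, batch_size):
        self.noise_generator = noise_generator
        self.batch_size = batch_size

    def __next__(self):
        return self.noise_generator(self.batch_size)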