-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdataset.py
104 lines (90 loc) · 4.14 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# Copyright (c) 2018-2021, RangerUFO
#
# This file is part of cycle_gan.
#
# cycle_gan is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# cycle_gan is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with cycle_gan. If not, see <https://www.gnu.org/licenses/>.
import os
import cv2
import random
import zipfile
import numpy as np
import mxnet as mx
import gluoncv as gcv
from multiprocessing import cpu_count
from multiprocessing.dummy import Pool
def load_image(path):
    """Read the file at *path* in binary mode and decode it with ``mx.image.imdecode``.

    Returns whatever ``mx.image.imdecode`` produces for the file's bytes.
    """
    with open(path, "rb") as stream:
        raw_bytes = stream.read()
    return mx.image.imdecode(raw_bytes)
def load_dataset(name, category):
    """Return the paths of every file under ``data/<name>/<category>``.

    If ``data/<name>`` does not exist yet, the archive ``<name>.zip`` is
    fetched with ``mx.gluon.utils.download`` and extracted into ``data``.

    Args:
        name: Dataset name; used both in the download URL and as the
            directory name below ``data``.
        category: Sub-directory of the dataset to list (e.g. ``"trainA"``).

    Returns:
        A sorted list of file paths found by walking the category directory
        recursively. The list is empty when the directory does not exist.
    """
    url = "https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/%s.zip" % (name)
    data_path = "data"
    if not os.path.exists(os.path.join(data_path, name)):
        data_file = mx.gluon.utils.download(url)
        with zipfile.ZipFile(data_file) as archive:
            # exist_ok avoids the check-then-create race of a separate
            # os.path.exists() test.
            os.makedirs(data_path, exist_ok=True)
            archive.extractall(path=data_path)
    category_root = os.path.join(data_path, name, category)
    # os.walk yields entries in filesystem order; sort so that the index of a
    # given sample is the same on every machine and every run.
    return sorted(
        os.path.join(dirpath, filename)
        for dirpath, _, filenames in os.walk(category_root)
        for filename in filenames
    )
def get_batches(dataset_a, dataset_b, batch_size, fine_size=(256, 256), load_size=(286, 286), ctx=mx.cpu()):
    """Yield paired batches drawn from two datasets.

    The number of batches is ``max(len(dataset_a), len(dataset_b)) // batch_size``.
    Each sample is produced by a ``Sampler`` built for its dataset, and the
    samples of one batch are generated on a thread pool
    (``multiprocessing.dummy.Pool``) with ``cpu_count() * 2`` workers.

    Args:
        dataset_a: First dataset, handed to ``Sampler``.
        dataset_b: Second dataset, handed to ``Sampler``.
        batch_size: Number of samples per batch.
        fine_size: Forwarded to ``Sampler`` as its ``fine_size``.
        load_size: Forwarded to ``Sampler`` as its ``load_size``.
        ctx: Context each stacked batch is moved to via ``as_in_context``.

    Yields:
        ``(batch_a, batch_b)`` tuples, both already placed on ``ctx``.
    """
    batches = max(len(dataset_a), len(dataset_b)) // batch_size
    sampler_a = Sampler(dataset_a, fine_size, load_size)
    sampler_b = Sampler(dataset_b, fine_size, load_size)
    stack = gcv.data.batchify.Stack()
    with Pool(cpu_count() * 2) as workers:
        for batch_index in range(batches):
            first = batch_index * batch_size
            indices = range(first, first + batch_size)
            # Dataset A is sampled to completion before dataset B starts.
            stacked_a = stack(workers.map(sampler_a, indices))
            stacked_b = stack(workers.map(sampler_b, indices))
            yield stacked_a.as_in_context(ctx), stacked_b.as_in_context(ctx)
def rotate(image, angle):
    """Rotate *image* about its centre, keeping the original width and height.

    Args:
        image: Array-like with ``shape`` and ``asnumpy()``; the first two
            dimensions are read as height and width.
        angle: Rotation angle forwarded to ``cv2.getRotationMatrix2D``.

    Returns:
        The warped image wrapped in a new ``mx.nd.array``.
    """
    height, width = image.shape[:2]
    centre = (width / 2, height / 2)
    matrix = cv2.getRotationMatrix2D(centre, angle, 1)
    pixels = image.asnumpy()
    # The interpolation flag is drawn at random from 0..4 on every call.
    interpolation = random.randint(0, 4)
    warped = cv2.warpAffine(pixels, matrix, (width, height), flags=interpolation)
    return mx.nd.array(warped)
class Sampler:
    """Callable that loads one dataset entry and applies random augmentation."""

    def __init__(self, dataset, fine_size, load_size):
        """Store the dataset and the two size settings used by ``__call__``.

        Args:
            dataset: Indexable collection of image paths.
            fine_size: Size passed to ``mx.image.random_crop``.
            load_size: Sequence whose minimum is the target short side for
                ``mx.image.resize_short``.
        """
        self._dataset = dataset
        self._fine_size = fine_size
        self._load_size = load_size

    def __call__(self, idx):
        """Return the augmented, normalised sample for index *idx*.

        The index wraps around the dataset length, so any non-negative
        integer is accepted for a non-empty dataset.
        """
        path = self._dataset[idx % len(self._dataset)]
        image = load_image(path)

        # Random rotation within +/-20, then a resize with a random
        # interpolation mode.
        angle = random.uniform(-20, 20)
        image = rotate(image, angle)
        short_side = min(self._load_size)
        image = mx.image.resize_short(image, short_side, interp=random.randint(0, 4))

        # Random crop, horizontal flip (probability 0.5) and colour distortion.
        image, _ = mx.image.random_crop(image, self._fine_size)
        image, _ = gcv.data.transforms.image.random_flip(image, px=0.5)
        image = gcv.data.transforms.experimental.image.random_color_distort(image)

        tensor = mx.nd.image.to_tensor(image)
        return mx.nd.image.normalize(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
def reconstruct_color(img):
    """Undo the 0.5-mean / 0.5-std normalisation and convert to ``uint8``.

    Computes ``img * std + mean`` with three-element ``mean`` and ``std``
    arrays created on ``img.context``, clips the result to ``[0.0, 1.0]``,
    scales it by 255 and casts it to ``"uint8"``.
    """
    context = img.context
    std = mx.nd.array([0.5, 0.5, 0.5], ctx=context)
    mean = mx.nd.array([0.5, 0.5, 0.5], ctx=context)
    denormalised = img * std + mean
    scaled = denormalised.clip(0.0, 1.0) * 255
    return scaled.astype("uint8")
if __name__ == "__main__":
    # Visual smoke test: print each pair of batches and show their images,
    # dataset A in the first half of the grid and dataset B in the second.
    import matplotlib.pyplot as plt

    batch_size = 4
    columns = 4
    # Ceiling division so the grid always has at least 2 * batch_size cells;
    # the previous formula (batch_size * 2 // 8 + 1) was too small for most
    # batch sizes other than 4.
    rows = (batch_size * 2 + columns - 1) // columns

    dataset_a = load_dataset("vangogh2photo", "trainA")
    dataset_b = load_dataset("vangogh2photo", "trainB")
    for batch_a, batch_b in get_batches(dataset_a, dataset_b, batch_size):
        print("batch_a preview: ", batch_a)
        print("batch_b preview: ", batch_b)
        for i in range(batch_size):
            # Images are transposed to put the channel axis last before
            # being handed to imshow.
            plt.subplot(rows, columns, i + 1)
            plt.imshow(reconstruct_color(batch_a[i].transpose((1, 2, 0))).asnumpy())
            plt.axis("off")
            plt.subplot(rows, columns, i + batch_size + 1)
            plt.imshow(reconstruct_color(batch_b[i].transpose((1, 2, 0))).asnumpy())
            plt.axis("off")
        plt.show()