-
Notifications
You must be signed in to change notification settings - Fork 3
/
data_split.py
178 lines (147 loc) · 6.11 KB
/
data_split.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import argparse
import math
import os
import random
import shutil
import numpy as np
# Lowercase file-name suffixes that are accepted as images. This tuple is the
# default ``extensions`` argument of ``has_file_allowed_extension``.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
'.tiff', '.webp')
# NOTE(review): SPLIT is not read anywhere in the visible part of this file;
# confirm whether other code imports it before removing it.
SPLIT = True
def parse_args():
    """Parse the command-line options of the data split tool.

    Returns:
        argparse.Namespace: the parsed options, exposing ``data_dir``,
        ``dataset`` and ``sampling_strategy``.
    """
    cli = argparse.ArgumentParser(
        description='Train a model for different kaggle competitions.')

    # (flag, keyword arguments) pairs, registered in declaration order.
    option_table = (
        ('--data-dir', dict(
            type=str,
            default='',
            help='training and validation pictures to use.')),
        ('--dataset', dict(
            type=str,
            default='train',
            help='the kaggle competition')),
        ('--sampling_strategy', dict(
            type=str,
            default='random',
            choices=['balanced', 'random'],
            help='Sampling strategy, balanced or random')),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli.parse_args()
def has_file_allowed_extension(filename, extensions=IMG_EXTENSIONS):
    """Tell whether a file name ends with one of the accepted extensions.

    The comparison is done on the lower-cased file name, so the entries of
    ``extensions`` are expected to be lowercase.

    Args:
        filename (string): path to a file
        extensions (tuple of strings): suffixes to accept (lowercase)

    Returns:
        bool: True if ``filename`` ends with one of the given extensions
    """
    lowered_name = filename.lower()
    return lowered_name.endswith(extensions)
def balanced_split(all_data_list, val_ratio=0.1, test_ratio=0.1):
    """Shuffle the items and cut them into train/val/test by fixed counts.

    The validation and test sizes are ``ceil(len * ratio)``, so both are
    rounded up; for very small inputs this can leave the train split empty.
    The input sequence is not modified: the shuffle is done on a copy.

    Args:
        all_data_list (sequence): items to split.
        val_ratio (float): fraction of items for validation, in [0, 1).
        test_ratio (float): fraction of items for test, in [0, 1).

    Returns:
        tuple: ``(train, val, test)`` lists.

    Raises:
        AssertionError: if a ratio is outside [0, 1) or the sum of both
            ratios is not strictly between 0 and 1.
    """
    assert 0 <= val_ratio < 1.0
    assert 0 <= test_ratio < 1.0
    assert 0 < val_ratio + test_ratio < 1.0

    # Shuffle a private copy so the caller's list keeps its original order.
    shuffled = list(all_data_list)
    random.shuffle(shuffled)

    total = len(shuffled)
    val_nums = math.ceil(total * val_ratio)
    test_nums = math.ceil(total * test_ratio)
    held_out = val_nums + test_nums

    val = shuffled[:val_nums]
    test = shuffled[val_nums:held_out]
    train = shuffled[held_out:]
    return train, val, test
def random_split(all_data_list, val_ratio=0.1, test_ratio=0.1):
    """Assign every item independently to train/val/test at random.

    One uniform number is drawn per item with ``np.random.rand`` and the
    item goes to:

    * test  if ``draw < test_ratio``
    * val   if ``test_ratio <= draw < test_ratio + val_ratio``
    * train if ``draw >= test_ratio + val_ratio``

    The three intervals are half-open and adjacent, so every item lands in
    exactly one split. The split sizes are only approximately proportional
    to the ratios because each item is assigned independently.

    Args:
        all_data_list (sequence): items to split.
        val_ratio (float): probability of an item going to validation.
        test_ratio (float): probability of an item going to test.

    Returns:
        tuple: ``(train, val, test)`` as numpy arrays.

    Raises:
        AssertionError: if a ratio is outside [0, 1) or the sum of both
            ratios is not strictly between 0 and 1.
    """
    assert 0 <= val_ratio < 1.0
    assert 0 <= test_ratio < 1.0
    assert 0 < val_ratio + test_ratio < 1.0

    draws = np.random.rand(len(all_data_list))
    held_out_ratio = test_ratio + val_ratio

    test_mask = draws < test_ratio
    # Lower bounds are inclusive so a draw equal to a threshold is not
    # left out of all three splits.
    val_mask = (draws >= test_ratio) & (draws < held_out_ratio)
    train_mask = draws >= held_out_ratio

    items = np.array(all_data_list)
    train = items[train_mask]
    test = items[test_mask]
    val = items[val_mask]
    return train, val, test
def copy_images(root_dir, dest_dir, classes, img_list):
    """Copy the listed images of one class from ``root_dir`` to ``dest_dir``.

    Each file ``<root_dir>/<classes>/<name>`` is copied to
    ``<dest_dir>/<classes>/<name>``. Names whose source file does not exist
    are reported on stdout and skipped. The destination class directory is
    created when the first file is about to be copied, if it is missing.

    Args:
        root_dir (str): directory holding the per-class source folders.
        dest_dir (str): directory holding the per-class destination folders.
        classes (str): name of the class sub-directory.
        img_list (iterable of str): file names to copy.
    """
    src_cls_dir = os.path.join(root_dir, classes)
    dst_cls_dir = os.path.join(dest_dir, classes)
    for img_name in img_list:
        img_path = os.path.join(src_cls_dir, img_name)
        if not os.path.exists(img_path):
            print(str(img_path) + ' does not exist.')
            continue
        # Created lazily so that nothing is created when no file is copied.
        os.makedirs(dst_cls_dir, exist_ok=True)
        shutil.copyfile(img_path, os.path.join(dst_cls_dir, img_name))
    # The files are copied, not moved: the sources stay in place.
    print('%s has been copied to %s' % (classes, dest_dir))
def move_images(root_dir, dest_dir, classes, img_list):
    """Move the listed images of one class from ``root_dir`` to ``dest_dir``.

    Each file ``<root_dir>/<classes>/<name>`` is moved to
    ``<dest_dir>/<classes>/<name>``. Names whose source file does not exist
    are reported on stdout and skipped. The destination class directory is
    created when the first file is about to be moved, if it is missing.

    Args:
        root_dir (str): directory holding the per-class source folders.
        dest_dir (str): directory holding the per-class destination folders.
        classes (str): name of the class sub-directory.
        img_list (iterable of str): file names to move.
    """
    src_cls_dir = os.path.join(root_dir, classes)
    dst_cls_dir = os.path.join(dest_dir, classes)
    for img_name in img_list:
        img_path = os.path.join(src_cls_dir, img_name)
        if not os.path.exists(img_path):
            print(str(img_path) + ' does not exist.')
            continue
        # Created lazily so that nothing is created when no file is moved.
        os.makedirs(dst_cls_dir, exist_ok=True)
        shutil.move(img_path, os.path.join(dst_cls_dir, img_name))
    print('%s has been moved to %s' % (classes, dest_dir))
def mkdir(dir, rmdir=False):
    """Make sure ``dir`` exists and report whether it was (re)built.

    Args:
        dir (str): directory to create.
        rmdir (bool): when True, an existing directory is deleted together
            with its content and created again.

    Returns:
        bool: True if the directory was created or rebuilt by this call,
        False if it already existed and was left untouched.
    """
    if not os.path.exists(dir):
        os.makedirs(dir)
        print('%s does not exist, will be created.' % dir)
        return True

    if not rmdir:
        # Existing directory and no permission to wipe it: leave it alone.
        print('%s has exists, create new dir will be ignored ' % dir)
        return False

    print('%s exists, will be deleted and rebuilt.' % dir)
    shutil.rmtree(dir)
    os.makedirs(dir)
    return True
class DataSplit(object):
    """Split a folder-per-class image dataset into train/val/test folders.

    Source images are read from ``<data_path>/data/<class>/`` and copied to
    ``<data_path>/<split>/<train|val|test>/<class>/``.

    Args:
        cfg (mapping): configuration; ``cfg.get('data_path')`` is the root
            directory and ``cfg.get('data_name')`` is stored as ``dataset``.
        split (str): name of the output directory under the root.
        train (str): name of the training sub-directory.
        val (str): name of the validation sub-directory.
        test (str): name of the test sub-directory.
    """

    # Values accepted for ``sampling_strategy`` in ``run_data_split``.
    _STRATEGIES = ('random', 'balanced')

    def __init__(self,
                 cfg,
                 split='split',
                 train='train',
                 val='val',
                 test='test'):
        self.cfg = cfg
        self.root_dir = cfg.get('data_path')
        self.dataset = cfg.get('data_name')
        self.data_dir = os.path.join(self.root_dir, 'data')
        self.train_path = os.path.join(self.root_dir, split, train)
        self.val_path = os.path.join(self.root_dir, split, val)
        self.test_path = os.path.join(self.root_dir, split, test)

    def run_data_split(self,
                       sampling_strategy='balanced',
                       val_ratio=0.1,
                       test_ratio=0.1):
        """Copy every class folder of ``data_dir`` into train/val/test.

        For each sub-directory of ``data_dir`` that contains at least one
        file with an allowed image extension, the three destination class
        directories are created, the file names are split with the chosen
        strategy and the files are copied.

        If any destination class directory already exists, processing
        stops at that class and the remaining classes are not handled.

        Args:
            sampling_strategy (str): ``'balanced'`` (shuffle, then cut by
                fixed counts) or ``'random'`` (independent random
                assignment of each file).
            val_ratio (float): share of the files used for validation.
            test_ratio (float): share of the files used for test.

        Returns:
            tuple: ``(train_path, val_path, test_path)``.

        Raises:
            ValueError: if ``sampling_strategy`` is not a known strategy.
        """
        # Fail early: an unknown strategy would otherwise surface only
        # after directories had been created, as an unbound local variable.
        if sampling_strategy not in self._STRATEGIES:
            raise ValueError(
                'Unknown sampling_strategy %r, expected one of %s' %
                (sampling_strategy, ', '.join(self._STRATEGIES)))

        split_roots = (self.train_path, self.val_path, self.test_path)
        for img_cls in os.listdir(self.data_dir):
            img_cls_dir = os.path.join(self.data_dir, img_cls)
            if not os.path.isdir(img_cls_dir):
                continue
            img_cls_list = [
                name for name in os.listdir(img_cls_dir)
                if has_file_allowed_extension(name)
            ]
            if not img_cls_list:
                continue

            # mkdir() returns False for a directory that already exists.
            # all() stops at the first False, so later directories are not
            # created once one of them is found to exist.
            if not all(
                    mkdir(os.path.join(root, img_cls))
                    for root in split_roots):
                break

            if sampling_strategy == 'random':
                train_list, val_list, test_list = random_split(
                    img_cls_list,
                    val_ratio=val_ratio,
                    test_ratio=test_ratio)
            else:
                train_list, val_list, test_list = balanced_split(
                    img_cls_list,
                    val_ratio=val_ratio,
                    test_ratio=test_ratio)

            copy_images(self.data_dir, self.train_path, img_cls,
                        train_list)
            copy_images(self.data_dir, self.val_path, img_cls, val_list)
            copy_images(self.data_dir, self.test_path, img_cls, test_list)
        print('All images have been processed.')
        return self.train_path, self.val_path, self.test_path