-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtinyimagenet.py
156 lines (127 loc) · 7.34 KB
/
tinyimagenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
TF: https://github.com/ksachdeva/tiny-imagenet-tfds/blob/master/tiny_imagenet/_imagenet.py
PyTorch: https://gist.github.com/lromor/bcfc69dcf31b2f3244358aea10b7a11b
"""
import os
import tensorflow as tf
import tensorflow_datasets.public_api as tfds
# Download location of the Tiny ImageNet archive fetched by the builder.
_URL = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"

# Name of the top-level directory that the archive extracts into.
_EXTRACTED_FOLDER_NAME = "tiny-imagenet-200"

# Lower-case file-name suffixes that `_list_imgs` treats as images.
SUPPORTED_IMAGE_FORMAT = (
    ".jpg",
    ".jpeg",
    ".png",
)
def _list_folders(root_dir):
    """Return the names of the immediate sub-directories of ``root_dir``.

    Only bare entry names are returned (not joined paths), in whatever order
    ``tf.io.gfile.listdir`` yields them.
    """
    folders = []
    for entry in tf.io.gfile.listdir(root_dir):
        # listdir gives bare names; join with the root to test the real path.
        if tf.io.gfile.isdir(os.path.join(root_dir, entry)):
            folders.append(entry)
    return folders
def _list_imgs(root_dir):
    """Return full paths of the image files directly inside ``root_dir``.

    An entry is kept when its lower-cased name ends with one of the suffixes in
    ``SUPPORTED_IMAGE_FORMAT``. The directory is not searched recursively, and
    the order follows ``tf.io.gfile.listdir``.
    """
    image_paths = []
    for name in tf.io.gfile.listdir(root_dir):
        # str.endswith accepts a tuple of suffixes and matches if any one does.
        if name.lower().endswith(SUPPORTED_IMAGE_FORMAT):
            image_paths.append(os.path.join(root_dir, name))
    return image_paths
# Human-readable names of the 200 classes. They are passed as the
# `tfds.features.ClassLabel` vocabulary in `TinyImagenetV2._info`.
# The integer label of an example is the index of its WordNet ID in wnids.txt
# (`identities.index(...)` in the builder), so entry i of this list must name
# the class whose wnid is on line i of wnids.txt.
# NOTE(review): this ordering is assumed to match wnids.txt; nothing in this
# file checks it.
class_names = ['Egyptian cat', 'reel', 'volleyball', 'rocking chair', 'lemon', 'bullfrog', 'basketball', 'cliff',
               'espresso', 'plunger', 'parking meter', 'German shepherd', 'dining table', 'monarch', 'brown bear',
               'school bus', 'pizza', 'guinea pig', 'umbrella', 'organ', 'oboe', 'maypole', 'goldfish', 'potpie',
               'hourglass', 'seashore', 'computer keyboard', 'Arabian camel', 'ice cream', 'nail', 'space heater',
               'cardigan', 'baboon', 'snail', 'coral reef', 'albatross', 'spider web', 'sea cucumber', 'backpack',
               'Labrador retriever', 'pretzel', 'king penguin', 'sulphur butterfly', 'tarantula', 'lesser panda',
               'pop bottle', 'banana', 'sock', 'cockroach', 'projectile', 'beer bottle', 'mantis', 'freight car',
               'guacamole', 'remote control', 'European fire salamander', 'lakeside', 'chimpanzee', 'pay-phone',
               'fur coat', 'alp', 'lampshade', 'torch', 'abacus', 'moving van', 'barrel', 'tabby', 'goose', 'koala',
               'bullet train', 'CD player', 'teapot', 'birdhouse', 'gazelle', 'academic gown', 'tractor', 'ladybug',
               'miniskirt', 'golden retriever', 'triumphal arch', 'cannon', 'neck brace', 'sombrero', 'gasmask',
               'candle', 'desk', 'frying pan', 'bee', 'dam', 'spiny lobster', 'police van', 'iPod', 'punching bag',
               'beacon', 'jellyfish', 'wok', "potter's wheel", 'sandal', 'pill bottle', 'butcher shop', 'slug', 'hog',
               'cougar', 'crane', 'vestment', 'dragonfly', 'cash machine', 'mushroom', 'jinrikisha', 'water tower',
               'chest', 'snorkel', 'sunglasses', 'fly', 'limousine', 'black stork', 'dugong', 'sports car', 'water jug',
               'suspension bridge', 'ox', 'ice lolly', 'turnstile', 'Christmas stocking', 'broom', 'scorpion',
               'wooden spoon', 'picket fence', 'rugby ball', 'sewing machine', 'steel arch bridge', 'Persian cat',
               'refrigerator', 'barn', 'apron', 'Yorkshire terrier', 'swimming trunks', 'stopwatch', 'lawn mower',
               'thatch', 'fountain', 'black widow', 'bikini', 'plate', 'teddy', 'barbershop', 'confectionery',
               'beach wagon', 'scoreboard', 'orange', 'flagpole', 'American lobster', 'trolleybus', 'drumstick',
               'dumbbell', 'brass', 'bow tie', 'convertible', 'bighorn', 'orangutan', 'American alligator', 'centipede',
               'syringe', 'go-kart', 'brain coral', 'sea slug', 'cliff dwelling', 'mashed potato', 'viaduct',
               'military uniform', 'pomegranate', 'chain', 'kimono', 'comic book', 'trilobite', 'bison', 'pole',
               'boa constrictor', 'poncho', 'bathtub', 'grasshopper', 'walking stick', 'Chihuahua', 'tailed frog',
               'lion', 'altar', 'obelisk', 'beaker', 'bell pepper', 'bannister', 'bucket', 'magnetic compass',
               'meat loaf', 'gondola', 'standard poodle', 'acorn', 'lifeboat', 'binoculars', 'cauliflower',
               'African elephant']
# Use V2 to avoid name collision with tfds
class TinyImagenetV2(tfds.core.GeneratorBasedBuilder):
    """Tiny ImageNet builder for TensorFlow Datasets.

    Two splits are produced:

    * ``TRAIN``: every image under ``train/<wnid>/images/`` in the archive.
    * ``TEST``: the images of the archive's ``val/`` folder, labelled through
      ``val/val_annotations.txt``. The archive's ``test/`` folder is not read.

    A class's integer label is the position of its WordNet ID (wnid) in
    ``wnids.txt``. The module-level ``class_names`` list is the ``ClassLabel``
    vocabulary, so it must use the same order.
    """

    VERSION = tfds.core.Version('1.0.0')

    def _info(self):
        """Return the dataset metadata: features, supervised keys and citation."""
        return tfds.core.DatasetInfo(
            builder=self,
            description=(
                "Tiny ImageNet Challenge is a challenge similar to ImageNet, "
                "with a smaller dataset and fewer image classes. It contains "
                "200 image classes, a training dataset of 100,000 images, a "
                "validation dataset of 10,000 images, and a test dataset of "
                "10,000 images. All images are of size 64×64."
            ),
            features=tfds.features.FeaturesDict({
                "image": tfds.features.Image(shape=(64, 64, 3), encoding_format="jpeg"),
                # The wnid string of the example's class (the train/ folder name).
                "id": tfds.features.Text(),
                "label": tfds.features.ClassLabel(names=class_names),
            }),
            supervised_keys=("image", "label"),
            homepage="https://tiny-imagenet.herokuapp.com/",
            # Well-formed BibTeX. The previous literal had a stray quote,
            # mismatched author braces and no closing brace for the entry.
            citation=r"""@article{tiny-imagenet,
  author = {Li, Fei-Fei and Karpathy, Andrej and Johnson, Justin},
  title = {Tiny ImageNet Challenge},
}""",
        )

    def _process_train_ds(self, ds_folder, identities):
        """Collect the training images, grouped by wnid.

        Args:
          ds_folder: root of the extracted archive (contains ``train/``).
          identities: wnids read from ``wnids.txt``. A class's label is its
            index in this list.

        Returns:
          A dict mapping each sub-folder name of ``train/`` (a wnid) to
          ``{'images': [image paths], 'id': label index}``.

        Raises:
          ValueError: if a ``train/`` sub-folder is not listed in ``identities``
            (raised by ``list.index``).
        """
        path_to_ds = os.path.join(ds_folder, 'train')
        label_images = {}
        for wnid in _list_folders(path_to_ds):
            label_images[wnid] = {
                'images': _list_imgs(os.path.join(path_to_ds, wnid, 'images')),
                'id': identities.index(wnid),
            }
        return label_images

    def _process_test_ds(self, ds_folder, identities):
        """Collect the validation images (used as the TEST split), grouped by wnid.

        Each non-blank line of ``val/val_annotations.txt`` is split on
        whitespace. The first field is an image file name under
        ``val/images/``, the second is its wnid, and any further fields are
        ignored.

        Args:
          ds_folder: root of the extracted archive (contains ``val/``).
          identities: wnids read from ``wnids.txt``. A class's label is its
            index in this list.

        Returns:
          A dict mapping wnid to ``{'images': [image paths], 'id': label index}``.

        Raises:
          ValueError: if an annotated wnid is not listed in ``identities``.
        """
        path_to_ds = os.path.join(ds_folder, 'val')
        with tf.io.gfile.GFile(os.path.join(path_to_ds, 'val_annotations.txt')) as f:
            data_raw = f.read()

        label_images = {}
        for line in data_raw.splitlines():
            row_values = line.split()
            # Skip blank or whitespace-only lines, such as a trailing newline or
            # a lone '\r' left over from CRLF line endings.
            if not row_values:
                continue
            image_name, label_name = row_values[0], row_values[1]
            entry = label_images.get(label_name)
            if entry is None:
                # Look up the label index only once per class.
                entry = {'images': [], 'id': identities.index(label_name)}
                label_images[label_name] = entry
            entry['images'].append(os.path.join(path_to_ds, 'images', image_name))
        return label_images

    def _split_generators(self, dl_manager):
        """Download and extract the archive, then define the TRAIN and TEST splits.

        TEST is built from the archive's ``val/`` folder (see
        ``_process_test_ds``).
        """
        extracted_path = dl_manager.extract(dl_manager.download(_URL))
        ds_folder = os.path.join(extracted_path, _EXTRACTED_FOLDER_NAME)
        with tf.io.gfile.GFile(os.path.join(ds_folder, 'wnids.txt')) as f:
            # One wnid per line. Strip surrounding whitespace (including '\r'
            # from CRLF files) and drop blank lines, so that every entry is a
            # clean wnid and position i lines up with class_names[i].
            wnids = [line.strip() for line in f.read().splitlines() if line.strip()]
        train_label_images = self._process_train_ds(ds_folder, wnids)
        test_label_images = self._process_test_ds(ds_folder, wnids)
        return [
            tfds.core.SplitGenerator(
                name=tfds.Split.TRAIN,
                gen_kwargs=dict(label_images=train_label_images, )),
            tfds.core.SplitGenerator(
                name=tfds.Split.TEST,
                gen_kwargs=dict(label_images=test_label_images, )),
        ]

    def _generate_examples(self, label_images):
        """Yield ``(key, example)`` pairs for one split.

        Args:
          label_images: a dict mapping wnid to
            ``{'images': [paths], 'id': label index}``, as built by
            ``_process_train_ds`` or ``_process_test_ds``.

        Yields:
          A key of the form ``"<wnid>/<file name>"`` and an example dict holding
          the image path, the wnid string (``id``) and the integer ``label``.
        """
        for label, image_info in label_images.items():
            for image_path in image_info['images']:
                key = "%s/%s" % (label, os.path.basename(image_path))
                yield key, {
                    "image": image_path,
                    "id": label,
                    "label": image_info['id'],
                }