-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcags_dataset.py
43 lines (36 loc) · 1.85 KB
/
cags_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import os
import sys
import urllib.request
import tensorflow as tf
class CAGS:
    """Loader for the CAGS image dataset (cat and dog breeds, see ``LABELS``).

    Constructing an instance makes sure the ``train``, ``dev`` and ``test``
    TFRecord files exist in the current working directory (downloading them
    from ``_URL`` when needed) and exposes each one as an attribute of the
    same name holding a ``tf.data.TFRecordDataset`` of raw serialized
    examples. Use ``CAGS.parse`` (for instance through ``dataset.map``) to
    decode the individual records.
    """

    # Image height, width and number of channels.
    # NOTE(review): nothing in this class enforces these on decoded images;
    # presumably every stored image already has this shape -- confirm.
    H, W, C = 224, 224, 3

    # Class names.
    # NOTE(review): presumably the integer "label" feature indexes this
    # list -- confirm against the dataset description.
    LABELS = [
        # Cats
        "Abyssinian", "Bengal", "Bombay", "British_Shorthair", "Egyptian_Mau",
        "Maine_Coon", "Russian_Blue", "Siamese", "Sphynx",
        # Dogs
        "american_bulldog", "american_pit_bull_terrier", "basset_hound",
        "beagle", "boxer", "chihuahua", "english_cocker_spaniel",
        "english_setter", "german_shorthaired", "great_pyrenees", "havanese",
        "japanese_chin", "keeshond", "leonberger", "miniature_pinscher",
        "newfoundland", "pomeranian", "pug", "saint_bernard", "samoyed",
        "scottish_terrier", "shiba_inu", "staffordshire_bull_terrier",
        "wheaten_terrier", "yorkshire_terrier",
    ]

    # Base location of the dataset files (note the trailing slash).
    _URL = "https://ufal.mff.cuni.cz/~straka/courses/npfl114/1920/datasets/"

    @staticmethod
    def parse(example):
        """Decode one serialized example into a dictionary of tensors.

        Args:
            example: A scalar string tensor holding a serialized example with
                the features ``"image"`` (string), ``"mask"`` (string) and
                ``"label"`` (int64).

        Returns:
            A dictionary with the keys ``"image"``, ``"mask"`` and
            ``"label"``. The image is decoded as a 3-channel JPEG and the
            mask as a 1-channel PNG; both are converted to ``tf.float32``.
            The label is returned as parsed.
        """
        example = tf.io.parse_single_example(example, {
            "image": tf.io.FixedLenFeature([], tf.string),
            "mask": tf.io.FixedLenFeature([], tf.string),
            "label": tf.io.FixedLenFeature([], tf.int64)})
        example["image"] = tf.image.convert_image_dtype(
            tf.image.decode_jpeg(example["image"], channels=3), tf.float32)
        example["mask"] = tf.image.convert_image_dtype(
            tf.image.decode_png(example["mask"], channels=1), tf.float32)
        return example

    @classmethod
    def _file_url(cls, filename):
        """Return the download URL of ``filename`` below ``_URL``."""
        # ``_URL`` already ends with a slash; strip it so that the joined URL
        # does not contain an accidental "//" before the file name.
        return "{}/{}".format(cls._URL.rstrip("/"), filename)

    @staticmethod
    def _download(url, path):
        """Download ``url`` to ``path`` without ever leaving a partial file.

        The data is first written to ``path + ".part"`` and only moved to
        ``path`` after the transfer finished. This matters because a file
        without a known size (the ``test`` split) is accepted purely on
        existence, so a truncated file at ``path`` would never be replaced.

        Raises:
            Whatever ``urllib.request.urlretrieve`` raises; the temporary
            file is removed before the exception propagates.
        """
        partial = path + ".part"
        try:
            urllib.request.urlretrieve(url, filename=partial)
        except BaseException:
            # Also clean up on KeyboardInterrupt, then let the error through.
            if os.path.exists(partial):
                os.remove(partial)
            raise
        os.replace(partial, path)

    def __init__(self):
        """Ensure the dataset files are present and open them.

        For each of ``train``, ``dev`` and ``test`` the file
        ``cags.<name>.tfrecord`` is downloaded when it is missing or, where
        an expected size is known, when its size differs. Afterwards the
        attribute ``<name>`` is set to a ``tf.data.TFRecordDataset`` reading
        that file.
        """
        # The expected size is ``None`` for the test split: only existence of
        # the file is checked there.
        for dataset, size in [("train", 57463494), ("dev", 8138328), ("test", None)]:
            path = "cags.{}.tfrecord".format(dataset)
            if not os.path.exists(path) or (size is not None and os.path.getsize(path) != size):
                print("Downloading file {}...".format(path), file=sys.stderr)
                self._download(self._file_url(path), path)
            setattr(self, dataset, tf.data.TFRecordDataset(path))