clcifar.py
import os
import pickle
import urllib.request

import gdown
from torch.utils.data import Dataset
from tqdm import tqdm


class CLCIFAR10(Dataset):
    """CLCIFAR10 training set.

    The CIFAR10 training set with human-annotated complementary labels.
    It contains 50000 samples; each sample comes with one ordinary label and
    three complementary labels, and the first complementary label is used as
    the training target.

    Args:
        root: path under which the dataset is stored
        transform: feature transformation function
    """

    def __init__(self, root="./data", transform=None):
        os.makedirs(os.path.join(root, "clcifar10"), exist_ok=True)
        dataset_path = os.path.join(root, "clcifar10", "clcifar10.pkl")

        # Download the pickled dataset from Google Drive on first use.
        if not os.path.exists(dataset_path):
            gdown.download(
                id="1uNLqmRUkHzZGiSsCtV2-fHoDbtKPnVt2", output=dataset_path
            )

        with open(dataset_path, "rb") as f:
            data = pickle.load(f)

        self.transform = transform
        self.input_dim = 32 * 32 * 3
        self.num_classes = 10
        # Keep only the first of the three complementary labels as the target.
        self.targets = [labels[0] for labels in data["cl_labels"]]
        self.data = data["images"]
        self.ord_labels = data["ord_labels"]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image = self.data[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.targets[index]
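

# A minimal sanity-check sketch added for illustration; this helper is not part
# of the original file. It assumes `targets` holds the chosen complementary
# label and `ord_labels` the ordinary label of each sample: a complementary
# label marks a class the sample does NOT belong to, so the two should never
# coincide.
def _count_label_overlaps(dataset):
    # Count samples whose complementary label equals their ordinary label
    # (expected to be 0 for a correctly annotated dataset).
    return sum(
        int(cl == ord_label)
        for cl, ord_label in zip(dataset.targets, dataset.ord_labels)
    )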


class CLCIFAR20(Dataset):
    """CLCIFAR20 training set.

    The CIFAR20 training set (CIFAR100 grouped into its 20 superclasses) with
    human-annotated complementary labels. It contains 50000 samples; each
    sample comes with one ordinary label and three complementary labels, and
    the first complementary label is used as the training target.

    Args:
        root: path under which the dataset is stored
        transform: feature transformation function
    """

    def __init__(self, root="./data", transform=None):
        os.makedirs(os.path.join(root, "clcifar20"), exist_ok=True)
        dataset_path = os.path.join(root, "clcifar20", "clcifar20.pkl")

        # Download the pickled dataset from Google Drive on first use.
        if not os.path.exists(dataset_path):
            gdown.download(
                id="1PhZsyoi1dAHDGlmB4QIJvDHLf_JBsFeP", output=dataset_path
            )

        with open(dataset_path, "rb") as f:
            data = pickle.load(f)

        self.transform = transform
        self.input_dim = 32 * 32 * 3
        self.num_classes = 20
        # Keep only the first of the three complementary labels as the target.
        self.targets = [labels[0] for labels in data["cl_labels"]]
        self.data = data["images"]
        self.ord_labels = data["ord_labels"]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image = self.data[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.targets[index]
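

# A minimal usage sketch, not part of the original file. It assumes torchvision
# is installed and that the pickled "images" are PIL images or HWC uint8 arrays,
# which transforms.ToTensor() accepts; batch size and worker count are arbitrary.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from torchvision import transforms

    train_set = CLCIFAR10(root="./data", transform=transforms.ToTensor())
    print("samples:", len(train_set))
    print("complementary/ordinary label overlaps:", _count_label_overlaps(train_set))

    loader = DataLoader(train_set, batch_size=256, shuffle=True, num_workers=2)
    images, cl_targets = next(iter(loader))
    print("batch shape:", images.shape, "first targets:", cl_targets[:8])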