-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataloader.py
54 lines (40 loc) · 2.71 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from torch.utils.data import DataLoader
import dataset
def get_noisy_dataset(train, transform, target_transform, args, exist=False):
    """Instantiate the base dataset selected by ``args.dataset``.

    Dispatch on ``args.dataset``:

    * ``'cifar10'`` (exact) -> ``dataset.CIFAR10``
    * any name starting with ``'cifar10N'`` -> ``dataset.CIFAR10N``
    * ``'cifar100'`` (exact) -> ``dataset.CIFAR100``
    * ``'Clothing1M'`` -> ``dataset.Clothing1M_Dataset``

    Args:
        train: Forwarded unchanged to the dataset constructor.
        transform: Input transform forwarded to the dataset constructor.
        target_transform: Label transform forwarded to the dataset constructor.
        args: Config namespace; ``args.dataset`` selects the class, and
            ``args`` itself is forwarded to the CIFAR constructors.
        exist: Forwarded to the CIFAR constructors only. NOTE(review):
            presumably signals that cached data already exists -- confirm
            in the ``dataset`` module.

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: If ``args.dataset`` matches none of the names above.
    """
    name = args.dataset
    if name == 'cifar10':
        data = dataset.CIFAR10(args, train=train, transform=transform,
                               target_transform=target_transform, exist=exist)
    elif name.startswith('cifar10N'):
        data = dataset.CIFAR10N(args, train=train, transform=transform,
                                target_transform=target_transform, exist=exist)
    elif name == 'cifar100':
        data = dataset.CIFAR100(args, train=train, transform=transform,
                                target_transform=target_transform, exist=exist)
    elif name == 'Clothing1M':
        # Clothing1M's constructor takes neither `args` nor `exist`.
        data = dataset.Clothing1M_Dataset(train=train, transform=transform,
                                          target_transform=target_transform)
    else:
        # Previously this fell through to `return data` and raised an
        # opaque UnboundLocalError; fail loudly with the offending name.
        raise ValueError(f"Unsupported dataset: {name!r}")
    return data
def get_processed_dataset(train, transform, target_transform, args, exist=False):
    """Instantiate the processed-dataset variant selected by ``args.dataset``.

    Dispatch on ``args.dataset``:

    * ``'cifar10'`` or ``'cifar100'`` (exact) -> ``dataset.processed_dataset``
    * any name starting with ``'cifar10N'`` -> ``dataset.CIFAR10N_processed_dataset``
    * ``'Clothing1M'`` -> ``dataset.Clothing1M_processed``

    Args:
        train: Forwarded unchanged to the dataset constructor.
        transform: Input transform forwarded to the dataset constructor.
        target_transform: Label transform forwarded to the dataset constructor.
        args: Config namespace; ``args.dataset`` selects the class, and
            ``args`` itself is forwarded to the CIFAR constructors.
        exist: Forwarded to the CIFAR constructors only. NOTE(review):
            meaning is defined in the ``dataset`` module -- confirm there.

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: If ``args.dataset`` matches none of the names above.
    """
    name = args.dataset
    if name in ('cifar10', 'cifar100'):
        data = dataset.processed_dataset(args, train=train, transform=transform,
                                         target_transform=target_transform,
                                         exist=exist)
    elif name.startswith('cifar10N'):
        data = dataset.CIFAR10N_processed_dataset(args, train=train,
                                                  transform=transform,
                                                  target_transform=target_transform,
                                                  exist=exist)
    elif name == 'Clothing1M':
        # Clothing1M's constructor takes neither `args` nor `exist`.
        data = dataset.Clothing1M_processed(train=train, transform=transform,
                                            target_transform=target_transform)
    else:
        # Previously this fell through to `return data` and raised an
        # opaque UnboundLocalError; fail loudly with the offending name.
        raise ValueError(f"Unsupported dataset: {name!r}")
    return data
def get_distilled_dataset(train, transform, target_transform, dir, args):
    """Instantiate the distilled dataset for ``args.dataset``, loaded from ``dir``.

    Dispatch on ``args.dataset``:

    * any name starting with ``'cifar10'`` (including ``'cifar100…'`` and
      ``'cifar10N…'``) -> ``dataset.distilled_dataset``
    * ``'Clothing1M'`` -> ``dataset.distilled_dataset_Clothing1M``

    Args:
        train: Forwarded unchanged to the dataset constructor.
        transform: Input transform forwarded to the dataset constructor.
        target_transform: Label transform forwarded to the dataset constructor.
        dir: Forwarded as ``dir=`` to the dataset constructor. (This name
            shadows the builtin, but it is kept so keyword callers keep
            working.)
        args: Config namespace; ``args.dataset`` selects the class, and
            ``args`` is forwarded to the Clothing1M constructor only.

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: If ``args.dataset`` matches none of the names above.
    """
    name = args.dataset
    # The original also tested startswith('cifar100'), which is redundant:
    # every 'cifar100...' name already starts with 'cifar10'.
    if name.startswith('cifar10'):
        data = dataset.distilled_dataset(train=train, transform=transform,
                                         target_transform=target_transform,
                                         dir=dir)
    elif name == 'Clothing1M':
        data = dataset.distilled_dataset_Clothing1M(args, train=train,
                                                    transform=transform,
                                                    target_transform=target_transform,
                                                    dir=dir)
    else:
        # Previously this fell through to `return data` and raised an
        # opaque UnboundLocalError; fail loudly with the offending name.
        raise ValueError(f"Unsupported dataset: {name!r}")
    return data
def get_test_loader(transform, target_transform, args, exist=False):
    """Build a non-shuffled ``DataLoader`` over the test split of ``args.dataset``.

    Dispatch on ``args.dataset``:

    * names starting with ``'cifar100'`` -> ``dataset.CIFAR100_test``
    * other names starting with ``'cifar10'`` (e.g. ``'cifar10'``,
      ``'cifar10N…'``) -> ``dataset.CIFAR10_test``
    * ``'Clothing1M'`` -> ``dataset.Clothing1M_test``

    Args:
        transform: Input transform forwarded to the test dataset.
        target_transform: Label transform forwarded to the test dataset.
        args: Config namespace; reads ``dataset``, ``batch_size`` and
            ``num_workers``.
        exist: Forwarded to the CIFAR test constructors only.

    Returns:
        ``DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, drop_last=False)``.

    Raises:
        ValueError: If ``args.dataset`` matches none of the names above.
    """
    name = args.dataset
    # Bug fix: the 'cifar100' prefix must be tested before 'cifar10'.
    # 'cifar100'.startswith('cifar10') is True, so the old order sent
    # CIFAR-100 to CIFAR10_test and made the CIFAR-100 branch unreachable.
    if name.startswith('cifar100'):
        test_data = dataset.CIFAR100_test(transform=transform,
                                          target_transform=target_transform,
                                          exist=exist)
    elif name.startswith('cifar10'):
        test_data = dataset.CIFAR10_test(transform=transform,
                                         target_transform=target_transform,
                                         exist=exist)
    elif name == 'Clothing1M':
        test_data = dataset.Clothing1M_test(transform=transform,
                                            target_transform=target_transform)
    else:
        # Previously this fell through and raised an opaque UnboundLocalError.
        raise ValueError(f"Unsupported dataset: {name!r}")
    # Evaluation loader: keep the order deterministic and keep every sample.
    test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
                             num_workers=args.num_workers, drop_last=False)
    return test_loader