-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmyDataLoader.py
40 lines (30 loc) · 1.08 KB
/
myDataLoader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms, datasets
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
def getDataLoader(train, download=True, permutation=None, args=None):
    """Build a DataLoader over MNIST, optionally with permuted pixels.

    Args:
        train: If true, load the MNIST training split and batch with
            ``args.batch_size``; otherwise load the test split and batch
            with ``args.test_batch_size``.
        download: Forwarded to ``torchvision.datasets.MNIST``.
        permutation: Optional pixel permutation applied to every normalized
            image via ``_permutate_image_pixels``; ``None`` leaves images
            unpermuted.
        args: Configuration namespace. Required despite the ``None``
            default. Reads ``origin`` (sub-directory under ``./datasets``),
            ``batch_size`` and ``test_batch_size``.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``shuffle=True`` and one
        worker process.

    Raises:
        ValueError: If ``args`` is not provided.
    """
    from functools import partial

    if args is None:
        raise ValueError(
            "getDataLoader requires `args` (origin, batch_size, test_batch_size)")

    transform = transforms.Compose([
        # MNIST yields PIL images, so a single ToTensor suffices. A
        # ToPILImage/ToTensor round trip would only re-quantise the values
        # through uint8 (float -> mul(255).byte() truncates).
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        # partial() of a module-level function is picklable, unlike a
        # lambda; DataLoader must pickle the transform when worker
        # processes are started with the "spawn" method.
        transforms.Lambda(partial(_permutate_image_pixels, permutation=permutation)),
    ])
    dataset = datasets.MNIST(
        './datasets/{name}'.format(name=args.origin),
        train=train,
        download=download,
        transform=transform,
    )
    batch_size = args.batch_size if train else args.test_batch_size
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=1,
    )
def _permutate_image_pixels(image, permutation):
if permutation is None:
return image
c, h, w = image.size()
image = image.view(-1, c)
image = image[permutation, :]
image.view(c, h, w)
return image