-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
40 lines (33 loc) · 1.12 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import torchvision
from torch.utils.data import DataLoader
def _flatten(image):
    """Flatten one image tensor to 1-D, e.g. (1, 28, 28) -> (784,).

    Defined at module level instead of as a lambda so the transform pipeline
    stays picklable, which DataLoader worker processes need when the
    multiprocessing start method is "spawn".
    """
    return torch.flatten(image)


def load_mnist(
    flatten=True,
    *,
    root="~/.torchdata/",
    train_batch_size=100,
    test_batch_size=500,
    num_workers=0,
    download=True,
):
    """Build DataLoaders for the MNIST train and test splits.

    Args:
        flatten: If truthy, each image is flattened to a 1-D tensor of 784
            values; otherwise images keep their (1, 28, 28) shape.
        root: Directory passed to ``torchvision.datasets.MNIST`` as the
            dataset location.
        train_batch_size: Batch size of the training loader (shuffled).
        test_batch_size: Batch size of the test loader (not shuffled).
        num_workers: Worker processes used by each DataLoader
            (0 loads in the main process, as before).
        download: Passed through to ``torchvision.datasets.MNIST``.

    Returns:
        ``(train_loader, test_loader)``. With the default batch size, a
        training batch ``(images, labels)`` has shapes
        ``(torch.Size([100, 784]), torch.Size([100]))`` when flattening, or
        ``(torch.Size([100, 1, 28, 28]), torch.Size([100]))`` otherwise.
    """
    # The dataset yields PIL images natively; ToTensor converts them to tensors.
    steps = [torchvision.transforms.ToTensor()]
    if flatten:
        steps.append(torchvision.transforms.Lambda(_flatten))
    dataset_transform = torchvision.transforms.Compose(steps)

    train = torchvision.datasets.MNIST(
        root=root, train=True, download=download, transform=dataset_transform
    )
    test = torchvision.datasets.MNIST(
        root=root, train=False, download=download, transform=dataset_transform
    )

    train_loader = DataLoader(
        train,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    # Evaluation order is kept deterministic: no shuffling for the test split.
    test_loader = DataLoader(
        test,
        batch_size=test_batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return train_loader, test_loader