-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataloader.py
executable file
·134 lines (93 loc) · 3.96 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import time
import glob
class prepareData(Dataset):
    """
    Slice-wise dataset over a group of pre-processed volumes.

    Every regular file directly inside ``<path>/<view>/`` is loaded with
    ``np.load`` in sorted filename order. The LAST file in that order is
    used as the annotation volume; all files before it are image volumes.
    Item ``idx`` is the ``idx``-th slice of every volume along the axis
    selected by ``view``.

    Args:
        path: dataset root directory (a trailing separator is optional)
        view: 'sagital', 'frontal' or 'coronal'; names the sub-directory
            to read and selects the slicing axis (0, 1 and 2 respectively)
    Raises:
        ValueError: if ``view`` is not one of the three supported names
        FileNotFoundError: if no file is found in ``<path>/<view>/``
    """
    # position in this tuple == axis that is sliced for the view
    # (the spelling 'sagital' is the on-disk directory name; keep it)
    _VIEWS = ('sagital', 'frontal', 'coronal')

    def __init__(self, path='./dataPreprocessed/', view='coronal'):
        super().__init__()
        if view not in self._VIEWS:
            raise ValueError(
                "view must be one of {}, got {!r}".format(self._VIEWS, view))
        # os.path.join makes the trailing separator on `path` optional
        pattern = os.path.join(path, view, '*')
        fileList = sorted(f for f in glob.glob(pattern) if os.path.isfile(f))
        if not fileList:
            raise FileNotFoundError(
                "no data files found matching {!r}".format(pattern))
        # NOTE(security): allow_pickle=True executes arbitrary code on load;
        # only point this at trusted, locally produced files.
        self.files = [np.load(file, allow_pickle=True) for file in fileList]
        self.view = self._VIEWS.index(view)
        # number of slices available along the selected axis
        self.len = self.files[0].shape[self.view]

    def fetchImage(self, idx):
        """
        Return ``(imgList, annot)`` for slice ``idx``.

        ``imgList`` holds one slice per image volume (all files but the
        last); ``annot`` is the matching slice of the last file.
        """
        # full slices on every axis except the one selected by the view
        index = [slice(None)] * 3
        index[self.view] = idx
        index = tuple(index)
        imgList = [volume[index] for volume in self.files[:-1]]
        annot = self.files[-1][index]
        return imgList, annot

    def __getitem__(self, idx):
        """Return ``(imgs, annot)``; ``imgs`` stacks the image slices on a new axis 0."""
        imgList, annot = self.fetchImage(idx)
        imgs = np.stack(imgList)
        return imgs, annot

    def __len__(self):
        """Number of slices along the selected view axis."""
        return self.len
class NpToTensor(object):
    """
    Turn an `np.array` into a float `torch.Tensor`.

    Unlike `ToTensor()` from `torchvision`, the values are only cast,
    never rescaled.
    """

    def __call__(self, arr):
        # from_numpy needs a well-formed buffer, so force a contiguous one
        contiguous = np.ascontiguousarray(arr)
        tensor = torch.from_numpy(contiguous)
        return tensor.float()

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)
class SegToTensor(object):
    """
    Convert a segmentation `np.array` to a float `torch.Tensor` and add a
    leading axis of size 1 in front of the existing dimensions.
    """

    def __call__(self, seg):
        # np.float was removed in NumPy 1.24; np.float64 is the type that
        # alias referred to, so the conversion is unchanged.
        seg = torch.from_numpy(seg.astype(np.float64)).float()
        return seg.unsqueeze(0)

    def __repr__(self):
        return self.__class__.__name__ + '()'
class TransformData(Dataset):
    """
    Wrap a dataset and apply a transform to its inputs and to its target.
    Skip either transformation by setting that transform to None.

    Args:
        dataset: the wrapped dataset; ``dataset[idx]`` must return an
            ``(inputs, target)`` pair
        input_transforms: callable applied to ``inputs``, or None
        target_transform: callable applied to ``target``, or None

    Each item is ``(inputs, target, count)`` where ``count`` is the number
    of non-zero elements of the target BEFORE it is transformed.
    """

    def __init__(self, dataset, input_transforms=None, target_transform=None):
        super().__init__()
        self.ds = dataset
        self.input_transforms = input_transforms
        self.target_transform = target_transform

    def __getitem__(self, idx):
        # extract data from inner dataset
        inputs, target = self.ds[idx]
        # None means "leave untouched", for inputs as well as for the target
        if self.input_transforms is not None:
            inputs = self.input_transforms(inputs)
        # count on the raw target, before any transform changes its type
        count = np.count_nonzero(target)
        if self.target_transform is not None:
            target = self.target_transform(target)
        # repackage data
        return inputs, target, count

    def __len__(self):
        return len(self.ds)
def prepare_data(path, view, testSplit=0.2, valSplit=0.2, batch_size=1):
    """
    Load the dataset for one view, wrap it with tensor transforms and
    split it randomly into train / test / validation loaders.

    Args:
        path: dataset root directory, forwarded to `prepareData`
        view: view name, forwarded to `prepareData`
        testSplit: fraction of the items assigned to the test split
        valSplit: fraction of the items assigned to the validation split
        batch_size: batch size used by all three loaders

    Returns:
        ``(train_loader, test_loader, val_loader)`` -- note the order:
        test comes before validation. All three loaders shuffle.
    """
    # load the data
    ds = prepareData(path, view)
    # transforms for the input and target
    image_transform = NpToTensor()
    target_transform = SegToTensor()
    # apply transforms; each item also carries the target's non-zero count
    TransformedDS = TransformData(ds, input_transforms=image_transform, target_transform=target_transform)
    # get the size of the splits; train takes whatever the int() truncation leaves
    total = len(TransformedDS)
    testSize = int(testSplit * total)
    valSize = int(valSplit * total)
    trainSize = total - testSize - valSize
    # split dataset here
    train_ds, test_ds, val_ds = random_split(TransformedDS, [trainSize, testSize, valSize])
    # honour the caller's batch_size instead of a hard-coded 1
    return (
        DataLoader(train_ds, batch_size=batch_size, shuffle=True),
        DataLoader(test_ds, batch_size=batch_size, shuffle=True),
        DataLoader(val_ds, batch_size=batch_size, shuffle=True),
    )