-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathDataloader.py
79 lines (63 loc) · 2.59 KB
/
Dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import cv2
from os.path import join
import torch
import numpy as np
from sklearn import preprocessing
# Exclusive end of the point-value columns in a CSV row: MMIDataset reads the
# image name from column 0 and point values from columns 1 .. dataindex - 1.
dataindex = 201

# Module-wide min-max scaler targeting [-1, 1]. MMIDataset calls
# fit_transform on every sample, so it is re-fitted per row rather than once.
Scaler = preprocessing.MinMaxScaler(
    feature_range=(-1, 1),
)
class MMIUnseenDataset(Dataset):
    """Point-only dataset: one noise-padded latent vector per CSV row.

    Every column of every row in ``points_path`` (read with ``header=None``)
    is treated as a point value. No image is loaded and no scaling is applied.
    """

    def __init__(self, z_dim, points_path):
        """
        Args:
            z_dim: length each row is padded to with Gaussian noise.
            points_path: path of a header-less CSV file of numeric values.
        """
        self.data = pd.read_csv(points_path, header=None).to_numpy()
        self.z_dim = z_dim

    def __len__(self):
        # Map-style datasets need __len__ for DataLoader's default samplers.
        return len(self.data)

    def __getitem__(self, index):
        """Return row ``index`` as a ``[z_dim, 1, 1]`` float64 tensor.

        The row's values come first. The remaining ``z_dim - n_values``
        entries are filled with standard-normal noise.

        Raises:
            ValueError: if the row holds more values than ``z_dim``.
        """
        values = torch.from_numpy(self.data[index].astype(np.float64))
        n_pad = self.z_dim - len(values)
        if n_pad < 0:
            raise ValueError(
                'row {} has {} values, more than z_dim={}'.format(
                    index, len(values), self.z_dim))
        # Draw the noise in the points' dtype so hstack needs no promotion.
        noise = torch.randn(n_pad, dtype=values.dtype)
        return torch.hstack([values, noise]).reshape([self.z_dim, 1, 1])
class MMIDataset(Dataset):
    """Paired (point vector, grayscale image) dataset.

    Rows come from a CSV whose first line is a header (``header=0``). In each
    row, column 0 names the image ``<name>.png`` inside ``img_folder``, and
    columns ``1 .. dataindex - 1`` hold the point values.
    """

    def __init__(self, img_size, z_dim, points_path, img_folder):
        """
        Args:
            img_size: side length the image is resized to (square output).
            z_dim: length the point vector is padded to with Gaussian noise.
            points_path: path of the CSV file described in the class docstring.
            img_folder: directory containing the ``<name>.png`` images.
        """
        self.data = pd.read_csv(points_path, header=0, index_col=None).to_numpy()
        self.img_folder = img_folder
        self.img_size = img_size
        self.z_dim = z_dim

    def _image_path(self, name):
        """Return the path of image ``<name>.png`` inside ``img_folder``."""
        # os.path.join uses the platform separator; a hard-coded '\\' only
        # resolves as a directory separator on Windows.
        return join(self.img_folder, '{}.png'.format(name))

    def __getitem__(self, index):
        """Return ``(points, img, points21)`` for CSV row ``index``.

        - ``img``: ``[1, img_size, img_size]`` float64 tensor with pixel
          values mapped from [0, 255] to [-1, 1].
        - ``points21``: 1-D float64 tensor of the row's point values,
          min-max scaled to [-1, 1] using this row's own min/max.
        - ``points``: ``points21`` followed by standard-normal noise up to
          ``z_dim`` entries, reshaped to ``[z_dim, 1, 1]``.

        Raises:
            FileNotFoundError: if the image cannot be read by OpenCV.
            ValueError: if the row has more point values than ``z_dim``.
        """
        item = self.data[index]

        # NOTE(review): if every CSV column is numeric, to_numpy() yields a
        # float array and item[0] formats as e.g. '1.0' -- confirm the name
        # column is stored as text.
        path = self._image_path(item[0])
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError('could not read image: {}'.format(path))
        img = cv2.resize(img, (self.img_size, self.img_size))[:, :, np.newaxis]
        img = img / 255.0 * 2 - 1
        img = torch.from_numpy(img.transpose(2, 0, 1))  # HWC -> CHW

        # Scale once; the padded latent vector reuses the same values.
        raw = item[1:dataindex].astype(np.float64).reshape(-1, 1)
        points21 = torch.from_numpy(Scaler.fit_transform(raw)).flatten(0)

        n_pad = self.z_dim - len(points21)
        if n_pad < 0:
            raise ValueError(
                'row {} has {} point values, more than z_dim={}'.format(
                    index, len(points21), self.z_dim))
        # hstack copies, so `points` never aliases the returned `points21`.
        noise = torch.randn(n_pad, dtype=points21.dtype)
        points = torch.hstack([points21, noise])
        # the shape of points should be [Z_DIM, CHANNELS_IMG, FEATURES_GEN]
        points = points.reshape([self.z_dim, 1, 1])
        return points, img, points21

    def __len__(self):
        """Number of CSV data rows."""
        return len(self.data)
def get_loader(
    img_size,
    batch_size,
    z_dim,
    points_path='',
    img_folder='',
    shuffle=True,
):
    """Wrap an ``MMIDataset`` in a ``DataLoader``.

    Args:
        img_size: side length images are resized to.
        batch_size: number of samples per batch.
        z_dim: length of the noise-padded point vector.
        points_path: CSV file with image names and point values.
        img_folder: directory holding the ``.png`` images.
        shuffle: whether the loader reshuffles every epoch (default True).

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(points, img, points21)``.
    """
    dataset = MMIDataset(
        img_size=img_size,
        z_dim=z_dim,
        points_path=points_path,
        img_folder=img_folder,
    )
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    return loader
if __name__ == "__main__":
    # Intentionally empty: running this file directly does nothing.
    ...