data_loader.py
import torch
import numpy as np

from log_utils import save_images, make_samples_batche


class DataLoader(object):

    def __init__(self, data, args):
        '''
        The dataset is indexed as dataset[task][field]:
            dataset[0][1]  # all data from task 0
            dataset[0][2]  # all labels from task 0
        '''
        self.dataset = data
        self.batch_size = args.batch_size
        self.length = args.num_task
        self.current_sample = 0
        self.current_task = 0
        # kept so callers expecting a .sampler attribute still work
        self.sampler = self
    def __iter__(self):
        return self

    def next(self):
        # Python 2 compatibility alias for __next__
        return self.__next__()
    def __next__(self):
        '''
        :return: (data, label) batch from the current task; the last batch of
            an epoch may be smaller than batch_size
        '''
        task_data = self.dataset[self.current_task]
        n = task_data[1].shape[0]
        if self.current_sample == n:
            # epoch finished: reset the cursor, reshuffle the task, stop iterating
            self.current_sample = 0
            self.shuffle_task()
            raise StopIteration
        # truncate the final batch if fewer than batch_size samples remain
        size = min(self.batch_size, n - self.current_sample)
        j = torch.arange(self.current_sample, self.current_sample + size)
        self.current_sample += size
        return task_data[1][j], task_data[2][j]
    def __len__(self):
        return len(self.dataset[self.current_task][1])

    def __getitem__(self, key):
        # select task `key` and restart iteration from its first sample
        self.current_sample = 0
        self.current_task = key
        return self
    def shuffle_task(self):
        # apply the same random permutation to the current task's data and labels
        indices = torch.randperm(len(self.dataset[self.current_task][1]))
        self.dataset[self.current_task][1] = self.dataset[self.current_task][1][indices].clone()
        self.dataset[self.current_task][2] = self.dataset[self.current_task][2][indices].clone()

    def get_sample(self, number):
        # draw `number` samples from the current task without replacement
        indices = torch.randperm(len(self))[:number]
        return self.dataset[self.current_task][1][indices], self.dataset[self.current_task][2][indices]
    def concatenate(self, new_data, task=0):
        '''
        :param new_data: DataLoader whose data are appended to the current task
        :param task: index of the task to read from new_data
        :return: self, with the supplementary data appended to the current task
        '''
        self.dataset[self.current_task][1] = torch.cat(
            (self.dataset[self.current_task][1], new_data.dataset[task][1]), 0).clone()
        self.dataset[self.current_task][2] = torch.cat(
            (self.dataset[self.current_task][2], new_data.dataset[task][2]), 0).clone()
        return self

    def get_current_task(self):
        return self.current_task
    def save(self, path):
        torch.save(self.dataset, path)

    def visualize_sample(self, path, number, shape):
        data, target = self.get_sample(number)
        # sort by label so samples appear in class order (e.g. 0 to 9)
        target, order = target.sort()
        data = data[order]
        image_frame_dim = int(np.floor(np.sqrt(number)))
        if shape[2] == 1:
            # single-channel images: lay them out on a square grid
            data = data.numpy().reshape(number, shape[0], shape[1], shape[2])
            save_images(data[:image_frame_dim * image_frame_dim, :, :, :],
                        [image_frame_dim, image_frame_dim], path)
        else:
            # multi-channel images: reshape with channels first before batching
            data = data.numpy().reshape(number, shape[2], shape[1], shape[0])
            make_samples_batche(data[:number], number, path)
    def increase_size(self, increase_factor):
        # repeat the current task's data increase_factor times (simple oversampling)
        self.dataset[self.current_task][1] = torch.cat([self.dataset[self.current_task][1]] * increase_factor, 0)
        self.dataset[self.current_task][2] = torch.cat([self.dataset[self.current_task][2]] * increase_factor, 0)
        return self
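

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of driving this loader, assuming an `args` namespace with
# `batch_size` and `num_task`, and a dataset indexable as dataset[task][1]
# (images) and dataset[task][2] (labels). The shapes below are assumptions
# chosen for the demo, not values taken from the original repository.
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(batch_size=4, num_task=2)
    # two tasks of 10 synthetic 1x8x8 images with integer labels; index 0 of
    # each task entry is unused here, matching the dataset[task][1] / [2] layout
    data = [[None, torch.randn(10, 1, 8, 8), torch.randint(0, 10, (10,))]
            for _ in range(args.num_task)]

    loader = DataLoader(data, args)
    for task in range(args.num_task):
        loader[task]  # __getitem__ selects the task and resets the cursor
        for x, y in loader:
            print(task, x.shape, y.shape)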