-
Notifications
You must be signed in to change notification settings - Fork 0
/
load_dataset.py
117 lines (98 loc) · 2.77 KB
/
load_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
'''
Provided by
Hemanth Venkateswara
hkdv1@asu.edu
Feb 2018
'''
import numpy as np
import os
import pdb
# Root directory for dataset folders; mnist() reads from its 'mnist/' subdirectory.
datasets_dir = './data/'
def one_hot(x, n):
    """Return a one-hot encoding of integer class labels.

    Args:
        x: Class labels as a list, tuple, scalar or array of any shape; the
            input is flattened before encoding. Float labels are accepted
            and converted to integers (fractional parts are truncated).
        n: Number of classes, i.e. the number of columns in the encoding.

    Returns:
        A float array of shape ``(number_of_labels, n)`` holding a single 1
        per row, in the column selected by the corresponding label.

    Raises:
        IndexError: If a label falls outside the valid index range for ``n``
            columns (negative labels within ``[-n, 0)`` wrap around, as with
            ordinary NumPy indexing).
    """
    # np.asarray handles any sequence (not only lists); the integer cast is
    # required because NumPy refuses float arrays as indices.
    labels = np.asarray(x).ravel().astype(np.intp)
    o_h = np.zeros((labels.size, n))
    o_h[np.arange(labels.size), labels] = 1
    return o_h
def _read_idx_payload(path, header_bytes):
    """Return the uint8 payload of an IDX file, skipping its header bytes."""
    # Handing the path to NumPy lets it open the file in binary mode and
    # close it again, rather than leaking a text-mode file handle.
    return np.fromfile(path, dtype=np.uint8)[header_bytes:]


def mnist(ntrain=60000, ntest=10000, onehot=False, subset=True,
          digit_range=(0, 2), shuffle=True, data_dir=None):
    """Load MNIST images and labels from the raw IDX files.

    Args:
        ntrain: Keep only the first ``ntrain`` training samples (applied
            before any digit subsetting).
        ntest: Keep only the first ``ntest`` test samples (applied before
            any digit subsetting).
        onehot: If true, return labels one-hot encoded over the 10 digit
            classes instead of as integer-valued floats.
        subset: If true, keep only samples whose digit lies in
            ``range(digit_range[0], digit_range[1])``.
        digit_range: Two-element ``[low, high)`` digit interval used for
            subsetting and for the label offset.
        shuffle: If true, permute training and test samples independently
            using ``np.random.permutation``.
        data_dir: Directory containing the four IDX files. Defaults to the
            ``mnist/`` folder under the module-level ``datasets_dir``.

    Returns:
        ``(trX, trY, teX, teY)`` where the image arrays have shape
        ``(784, N)`` with pixel values scaled to ``[0, 1]`` (one sample per
        column, also when ``N == 1``). With ``onehot=False`` the label
        arrays have shape ``(1, N)`` and hold the digit minus
        ``digit_range[0]`` as floats. With ``onehot=True`` they have shape
        ``(10, N)``, indexed by the original digit (no offset applied).

        Without shuffling, a subset is ordered digit by digit, keeping the
        file order inside each digit.

    Raises:
        ValueError: If ``subset`` is true and ``digit_range`` is empty.
    """
    if data_dir is None:
        data_dir = os.path.join(datasets_dir, 'mnist/')

    image_size = 28 * 28

    # Image files carry a 16-byte header, label files an 8-byte header.
    # Truncate before converting so only the requested samples become floats.
    trX = _read_idx_payload(
        os.path.join(data_dir, 'train-images-idx3-ubyte'), 16)
    trX = trX.reshape(-1, image_size)[:ntrain].astype(float) / 255.
    trY = _read_idx_payload(
        os.path.join(data_dir, 'train-labels-idx1-ubyte'), 8)
    trY = trY[:ntrain].astype(float)

    teX = _read_idx_payload(
        os.path.join(data_dir, 't10k-images-idx3-ubyte'), 16)
    teX = teX.reshape(-1, image_size)[:ntest].astype(float) / 255.
    teY = _read_idx_payload(
        os.path.join(data_dir, 't10k-labels-idx1-ubyte'), 8)
    teY = teY[:ntest].astype(float)

    if subset:
        digits = np.arange(digit_range[0], digit_range[1])
        if digits.size == 0:
            raise ValueError(
                'digit_range must satisfy digit_range[0] < digit_range[1]')
        # Select on the integer labels (never on a one-hot matrix) and keep
        # the samples grouped by digit, in increasing digit order.
        train_keep = np.concatenate(
            [np.flatnonzero(trY == digit) for digit in digits])
        test_keep = np.concatenate(
            [np.flatnonzero(teY == digit) for digit in digits])
        trX, trY = trX[train_keep], trY[train_keep]
        teX, teY = teX[test_keep], teY[test_keep]

    if shuffle:
        # Draw the training permutation first, then the test permutation,
        # so results for a given random seed stay reproducible.
        train_idx = np.random.permutation(trX.shape[0])
        test_idx = np.random.permutation(teX.shape[0])
        trX, trY = trX[train_idx], trY[train_idx]
        teX, teY = teX[test_idx], teY[test_idx]

    # One sample per column.
    trX = trX.T
    teX = teX.T

    if onehot:
        # Encode last, from integer labels, so subsetting and shuffling
        # operate on plain label vectors.
        trY = one_hot(trY.astype(int), 10).T
        teY = one_hot(teY.astype(int), 10).T
    else:
        trY = trY.reshape(1, -1) - digit_range[0]
        teY = teY.reshape(1, -1) - digit_range[0]

    return trX, trY, teX, teY
# Disabled example entry point: it is wrapped in a string literal, so it is never executed.
'''
def main():
trx, trY, teX, teY = mnist()
if __name__ == "__main__":
main()
'''