-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmain_kernel.py
193 lines (157 loc) · 5.93 KB
/
main_kernel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import os
import argparse
import time
import pickle
import torch
from utils import format_time, args2train_test_sizes
from datasets import dataset_initialization
from kernels import select_kernel
from sklearn.svm import SVC, LinearSVC
# Run on GPU when available; stored into args.device in run_krr for downstream helpers.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def kernel_regression(ktrtr, ktetr, ytr, yte, ridge):
    """
    Perform kernel ridge regression.
    :param ktrtr: train-train gram matrix, shape (p_tr, p_tr)
    :param ktetr: test-train gram matrix, shape (p_te, p_tr)
    :param ytr: training labels
    :param yte: testing labels
    :param ridge: ridge (regularization) value added to the gram diagonal
    :return: mean square error on the test set (0-dim tensor).
    """
    yte, ytr = yte.float(), ytr.float()
    # Solve (K + ridge * I) @ alpha = y_tr for the dual coefficients.
    alpha = torch.linalg.solve(
        ktrtr + ridge * torch.eye(ytr.size(0), device=ktrtr.device),
        ytr
    )
    f = ktetr @ alpha  # predictor evaluated on the test points
    mse = (f - yte).pow(2).mean()
    return mse
def svc(ktrtr, ktetr, ytr, yte, l, kernel='precomputed'):
    """
    Train a Support Vector Classifier and measure its test error.
    :param ktrtr: train-train gram matrix (raw samples if kernel != 'precomputed')
    :param ktetr: test-train gram matrix (raw test samples if kernel != 'precomputed')
    :param ytr: training labels
    :param yte: testing labels
    :param l: l2 penalty (mapped to sklearn's C = 1 / l)
    :param kernel: kernel type forwarded to sklearn's SVC
    :return: classification error (fraction of misclassified test points).
    """
    # max_iter=-1 means "no iteration limit" for SVC (libsvm backend).
    clf = SVC(C=1/l, kernel=kernel, max_iter=-1)
    clf.fit(ktrtr, ytr)
    y_hat = torch.tensor(clf.predict(ktetr))
    terr = 1 - y_hat.eq(yte).float().mean()
    return terr
def linear_svc(xtr, xte, ytr, yte, l):
    """
    Train a linear Support Vector Classifier on raw samples.
    :param xtr: train samples
    :param xte: test samples
    :param ytr: training labels
    :param yte: testing labels
    :param l: l2 penalty (mapped to sklearn's C = 1 / l)
    :return: classification error (fraction of misclassified test points).
    """
    # BUG FIX: unlike SVC, LinearSVC does not accept max_iter=-1 ("no limit");
    # its max_iter must be a positive integer, so use a large finite budget.
    clf = LinearSVC(C=1/l, max_iter=100000)
    clf.fit(xtr, ytr)
    y_hat = torch.tensor(clf.predict(xte))
    terr = 1 - y_hat.eq(yte).float().mean()
    return terr
def run_krr(args):
    """
    Run the selected kernel algorithm (krr / svc / linear_svc) on the
    hierarchical dataset and yield a results dict with the args and test error.
    """
    ptr, pte = args.ptr, args.pte
    start = time.time()

    def tick(prev):
        # Print the elapsed time since `prev` and return the current timestamp.
        now = time.time()
        print(format_time(now - prev), flush=True)
        return now

    args.device = device

    # Build train/test tensors: flatten (sample, length, channel) -> (sample, features).
    print('Init dataset...', flush=True)
    trainset, testset, _, _ = dataset_initialization(args)
    xtr = trainset.dataset.x[:ptr].permute(0, 2, 1).flatten(1)
    ytr = trainset.dataset.targets[:ptr]
    xte = testset.x[:pte].permute(0, 2, 1).flatten(1)
    yte = testset.targets[:pte]
    start = tick(start)

    if args.algo == 'linear_svc':
        assert args.kernel == 'linear', "Kernel must be linear for linearSVC algo."
        print('Linear SVC...', flush=True)
        err = linear_svc(xtr, xte, ytr, yte, args.l)
        tick(start)
    else:
        # Precompute gram matrices once; both krr and svc consume them.
        gram = select_kernel(args)
        print('Compute gram matrix (train)...', flush=True)
        ktrtr = gram(xtr, xtr)
        start = tick(start)
        print('Compute gram matrix (test)...', flush=True)
        ktetr = gram(xte, xtr)
        start = tick(start)
        if args.algo == 'krr':
            print('KRR...', flush=True)
            err = kernel_regression(ktrtr, ktetr, ytr, yte, args.l)
            tick(start)
        elif args.algo == 'svc':
            print('SVC...', flush=True)
            err = svc(ktrtr, ktetr, ytr, yte, args.l)
            tick(start)
        else:
            raise ValueError('`algo` argument is invalid, must be either svc or krr!')

    yield {
        'args': args,
        'err': err.item(),
    }
def main():
    """
    Parse command-line arguments, run the kernel experiment, and pickle the
    results (args followed by the result dict) to `args.output`.
    """
    parser = argparse.ArgumentParser(
        description="Perform a kernel method on hierarchical dataset."
    )
    # --- dataset args ---
    parser.add_argument("--dataset", type=str, default='hier1')
    parser.add_argument("--num_features", type=int, default=8)
    parser.add_argument("--m", type=int, default=2)
    parser.add_argument("--num_layers", type=int, default=2)
    parser.add_argument("--num_classes", type=int, default=-1)
    parser.add_argument("--input_format", type=str, default="onehot")
    parser.add_argument("--seed_init", type=int, default=-1)
    parser.add_argument("--seed_trainset", type=int, default=-1)
    parser.add_argument("--whitening", type=int, default=0)
    # --- training args ---
    parser.add_argument("--algo", type=str, required=True)
    parser.add_argument("--kernel", type=str, required=True)
    parser.add_argument("--ptr", metavar="P", type=float, help="size of the training set")
    parser.add_argument("--pte", type=float, help="size of the validation set", default=512)
    # ridge / regularisation parameter
    parser.add_argument("--l", metavar="lambda", type=float, help="regularisation parameter")
    # NOTE(review): the default is the literal string "None", which would be
    # opened as a file name below — callers are expected to always pass --output.
    parser.add_argument("--output", type=str, required=False, default="None")
    args = parser.parse_args()

    # Fields expected downstream by the dataset/kernel helpers.
    args.loss = 'none'
    args.auto_regression = 0

    # -1 acts as a sentinel meaning "derive from another argument".
    if args.seed_trainset == -1:
        args.seed_trainset = args.seed_init
    if args.num_classes == -1:
        args.num_classes = args.num_features
    if args.m == -1:
        args.m = args.num_features
    args.ptr, args.pte = args2train_test_sizes(args, max_pte=1000)

    ## PAY ATTENTION TO THIS!!! ##
    # condition to run only if m^L/n < 1. #
    # upper_bound = args.m ** args.num_layers / args.num_features ** (2 * (1 - 2 ** -args.num_layers)) < 0.8
    # lower_bound = args.m ** args.num_layers / args.num_features ** (1 - 2 ** -args.num_layers) > 1.2
    # if upper_bound and lower_bound:
    #     raise ValueError('Parameters outsize range for using couples to classify!')
    # if args.ptr > 30000:
    #     raise ValueError("ptr too large!! (>30k)")

    # Write args up-front so even a crashed run leaves a record on disk.
    with open(args.output, "wb") as handle:
        pickle.dump(args, handle)
    try:
        for data in run_krr(args):
            # Rewrite the file with args followed by the result payload.
            with open(args.output, "wb") as handle:
                pickle.dump(args, handle)
                pickle.dump(data, handle)
    # Explicit BaseException (was a bare `except:`, same semantics): also clean
    # up the partial output file on KeyboardInterrupt/SystemExit, then re-raise.
    except BaseException:
        os.remove(args.output)
        raise
# Script entry point.
if __name__ == "__main__":
    main()