-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathdemo_mnist_svmplus.m
50 lines (41 loc) · 1.65 KB
/
demo_mnist_svmplus.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
% Reset the workspace and command window, put the helper functions shipped
% in ./utils on the search path, and bring the variables stored in
% ./data/mnist_plus.mat into the workspace. The rest of the script reads
% train_features, test_features, train_PFfeatures, train_labels and
% test_labels without defining them, so they presumably come from this file.
clear
clc
addpath ./utils;
load ./data/mnist_plus.mat;
% L1-normalize every feature matrix with the L1_normalization helper from
% ./utils. Each matrix is transposed (') before normalization -- presumably
% to switch from one-sample-per-row storage to the layout the helper and
% getKernel expect; TODO confirm against utils.
% Only the training set carries privileged (PF) features.
train_features   = L1_normalization(train_features');
train_PFfeatures = L1_normalization(train_PFfeatures');
test_features    = L1_normalization(test_features');
% Binarize the labels: digit 5 -> +1, every other label -> -1.
% The mask is taken before any value is overwritten. The previous in-place
% remap (5 -> 1, then everything ~= 1 -> -1) also put any original label 1
% into the positive class. Writing -1 into an unsigned-integer label array
% would also have saturated to 0.
% The result is a double array with the same size as the input.
train_labels = 2 * double(train_labels == 5) - 1;
test_labels  = 2 * double(test_labels == 5) - 1;
% Gaussian kernels built with getKernel from ./utils:
%   K     - on the training features,
%   testK - between test and training features (its columns are indexed by
%           training-sample indices when predicting),
%   tK    - on the privileged (PF) training features.
% The second output of the training call (train_kparam) is reused for the
% test kernel -- presumably the kernel width fitted on the training data;
% confirm in utils/getKernel.
kparam = struct('kernel_type', 'gaussian');
[K, train_kparam] = getKernel(train_features, kparam);
testK = getKernel(test_features, train_features, train_kparam);
% Structs are passed by value, so getKernel cannot alter kparam and the same
% settings serve the privileged-feature kernel.
tK = getKernel(train_PFfeatures, kparam);
% ================ train SVM+ ====================
% Hyper-parameters are fixed for the demo; in practice choose them by
% validation. svm_C and gamma are consumed by svm_plus_train (./utils);
% presumably the regularization constant and the weight of the correcting
% (privileged) space -- confirm in the solver.
svmplus_param = struct('svm_C', 1, 'gamma', 1);
tic;
model = svm_plus_train(train_labels, K, tK, svmplus_param);
tt = toc;  % training time only; prediction below is not timed
% Decision values: kernel columns of the support vectors (model.SVs is used
% as training-sample indices) times their coefficients, minus the offset rho.
decs = testK(:, model.SVs) * model.sv_coef - model.rho;
% Predict +1 where the decision value is strictly positive, -1 otherwise.
% Compare as column vectors so a row-shaped test_labels cannot trigger
% implicit expansion into an n-by-n comparison matrix.
pred = 2 * (decs > 0) - 1;
acc = mean(pred(:) == test_labels(:));
fprintf(2, 'Original SVM+, time = %f, Acc = %.4f.\n', tt, acc);
% ================ train l2-SVM+ ====================
% Reuses svm_C and gamma from svmplus_param (in practice choose them by
% validation).
tic;
model = solve_l2svmplus_kernel(train_labels, K, tK, svmplus_param.svm_C, svmplus_param.gamma);
tt = toc;  % training time only; prediction below is not timed
% Scatter the support-vector coefficients into a dense per-training-sample
% column vector (model.SVs indexes training samples). abs() drops the sign
% of sv_coef -- presumably sv_coef = y_i * alpha_i, so the label is
% re-applied below; TODO confirm against solve_l2svmplus_kernel.
alpha = zeros(length(train_labels), 1);
alpha(model.SVs) = full(model.sv_coef);
alpha = abs(alpha);
% Decision values use the kernel plus a constant 1 and no rho term --
% presumably the bias is absorbed into the kernel in this formulation.
% train_labels is forced to a column so a row-shaped label vector cannot
% broadcast alpha .* y into an n-by-n matrix.
decs = (testK + 1) * (alpha .* train_labels(:));
% Predict +1 where the decision value is strictly positive, -1 otherwise,
% comparing as column vectors to avoid implicit expansion.
pred = 2 * (decs > 0) - 1;
acc = mean(pred(:) == test_labels(:));
fprintf(2, 'L2-SVM+, time=%f, Acc = %.4f.\n', tt, acc);