function [ model ] = bestRBF( data, labels )
%BESTRBF Tune and train an RBF-kernel SVM with LIBSVM.
%   model = BESTRBF(data, labels) runs a coarse and then a fine grid
%   search over log2(C) and log2(gamma), scores each pair with 10-fold
%   cross-validation, and returns a LIBSVM model trained on all of the
%   data with the best pair found.
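%
%   Example (illustrative only; assumes the LIBSVM MATLAB interface is on
%   the path and that X is an n-by-d feature matrix with label vector y):
%       model = bestRBF(X, y);
%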
% Use 10-fold cross-validation
folds = 10;
% Make sure the standard LIBSVM defaults lie on the grid:
% C = 1 (i.e. 2^0) and gamma = 1/num_features (i.e. 2^(-log2(num_features)))
nf = -log2(size(data, 2));
[C,gamma] = meshgrid(-5:5:15, (nf-12):6:(nf+12));
% Coarse grid search over (log2(C), log2(gamma)), scored by cross-validation
coarseAcc = zeros(numel(C), 1);
parfor (i=1:numel(C), 4)
coarseAcc(i) = svmtrain(labels, data, ...
sprintf('-q -c %f -g %f -v %d -m 512', 2^C(i), 2^gamma(i), folds));
end
% Pick the (C, gamma) pair with the best cross-validation accuracy
[~,idx] = max(coarseAcc);
fprintf('--------------------\nBest C-value: 2^%g\nBest gamma-value: 2^%g\n--------------------\n', C(idx), gamma(idx));
% Contour plot of the coarse parameter search
contour(C, gamma, reshape(coarseAcc,size(C))), colorbar
hold on
plot(C(idx), gamma(idx), 'rx')
text(C(idx), gamma(idx), sprintf('Acc = %.2f %%',coarseAcc(idx)), ...
'HorizontalAlignment','left', 'VerticalAlignment','top')
hold off
xlabel('log_2(C)'), ylabel('log_2(\gamma)'), title('Cross-Validation Accuracy with coarse grid-search')
% Refine the search on a finer grid centred on the best coarse pair
bestC = C(idx);
bestG = gamma(idx);
[Cf,gammaf] = meshgrid((bestC-1):0.5:(bestC+1), (bestG-0.75):0.5:(bestG+0.75));
fineAcc = zeros(numel(Cf), 1);
parfor (i = 1:numel(Cf), 4)
fineAcc(i) = svmtrain(labels, data, ...
sprintf('-q -c %f -g %f -v %d -m 512', 2^Cf(i), 2^gammaf(i), folds));
end
[~,idx] = max(fineAcc);
fprintf('--------------------\nBest C-value: 2^%g\nBest gamma-value: 2^%g\n--------------------\n', Cf(idx), gammaf(idx));
% Contour plot of the fine parameter search
contour(Cf, gammaf, reshape(fineAcc,size(Cf))), colorbar
hold on
plot(Cf(idx), gammaf(idx), 'rx')
text(Cf(idx), gammaf(idx), sprintf('Acc = %.2f %%',fineAcc(idx)), ...
'HorizontalAlignment','left', 'VerticalAlignment','top')
hold off
xlabel('log_2(C)'), ylabel('log_2(\gamma)'), title('Cross-Validation Accuracy with fine grid-search')
% Retrain on all of the data (no cross-validation) with the best parameters
model = svmtrain(labels, data, ...
sprintf('-q -c %f -g %f -m 512', 2^Cf(idx), 2^gammaf(idx)));
fprintf('--------------------\nBest C-value: 2^%g\nBest gamma-value: 2^%g\n--------------------\n', Cf(idx), gammaf(idx));
end
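
% Usage sketch (illustrative, not part of the function). Assumes the LIBSVM
% MATLAB interface also provides svmpredict, and that Xtest/ytest are
% hypothetical held-out data matching the training features and labels:
%   [pred, acc, ~] = svmpredict(ytest, Xtest, model);
%   fprintf('Held-out accuracy: %.2f %%\n', acc(1));   % acc(1) is accuracy in percent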