-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrainSVM.m
45 lines (40 loc) · 1.43 KB
/
trainSVM.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
function [classifier] = trainSVM(X, y, valX, valY)
%TRAINSVM Fit RBF, linear and polynomial ECOC-SVMs and return the best one.
%   CLASSIFIER = TRAINSVM(X, Y, VALX, VALY) trains one multiclass
%   error-correcting output codes model (fitcecoc) per SVM kernel
%   ('rbf', 'linear', 'polynomial') on predictors X and labels Y. It scores
%   each model by classification accuracy on VALX / VALY and returns the
%   highest-scoring one as a struct with fields:
%       name  - the character vector 'SVM'
%       model - the selected model returned by fitcecoc
%
%   Every binary learner uses IterationLimit 50000 and standardized
%   predictors. Ties (and the case where every accuracy is NaN) resolve to
%   the earliest kernel in the list, i.e. 'rbf'.
%
%   Labels may be numeric, logical, categorical, string, or a cell array
%   of character vectors. VALY may be a row or a column vector.
%
%   If training or evaluation fails, the error identifier and message are
%   printed and the original error is rethrown.

kernels = {'rbf', 'linear', 'polynomial'};
try
    % Fix the class order once so every model shares the same label set.
    classNames = unique(y);
    models = cell(1, numel(kernels));
    accuracies = zeros(1, numel(kernels));
    for k = 1:numel(kernels)
        learner = templateSVM('KernelFunction', kernels{k}, ...
            'IterationLimit', 50000, 'Standardize', true);
        models{k} = fitcecoc(X, y, 'Learners', learner, ...
            'ClassNames', classNames);
        accuracies(k) = localAccuracy(predict(models{k}, valX), valY);
    end

    % max returns the first maximum and ignores NaN, so ties favour 'rbf'.
    [~, best] = max(accuracies);
    classifier.name = 'SVM';
    classifier.model = models{best};
catch exc
    % Report the cause, then propagate it instead of returning with
    % CLASSIFIER unassigned.
    fprintf('something happened in training %s: %s\n', ...
        exc.identifier, exc.message);
    rethrow(exc);
end
end

function acc = localAccuracy(predicted, actual)
%LOCALACCURACY Fraction of PREDICTED labels equal to the ACTUAL labels.
%   Both inputs are flattened to columns first. predict returns a column,
%   and comparing it with a row vector would otherwise implicitly expand
%   to an N-by-N matrix, yielding a vector instead of a scalar accuracy.
%   Cell arrays of character vectors are compared with strcmp because ==
%   is not defined for cells.
predicted = predicted(:);
actual = actual(:);
if iscell(predicted) || iscell(actual)
    acc = mean(strcmp(predicted, actual));
else
    acc = mean(predicted == actual);
end
end