-
Notifications
You must be signed in to change notification settings - Fork 2
/
Training_and_Model_Accuracy.m
65 lines (51 loc) · 1.88 KB
/
Training_and_Model_Accuracy.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
%-----------------------ENVIRONMENT SETUP------------------------%
% Start from a clean command window, workspace and figure state.
% `clear` (rather than `clear all`) removes the workspace variables without
% also flushing cached functions/MEX files, which only slows MATLAB down.
clc
clear
close all
% NOTE(review): this disables *every* warning for the rest of the MATLAB
% session, not only for this script -- consider turning off specific IDs.
warning off
%----------------LOAD THE EXTRACTED FEATURES----------------%
% The sections below expect `hogFeatures` (one row per sample) and
% `trainingLabels` (one label per row of hogFeatures) to be defined by
% these two MAT-files.
load('hogfeatures')
load('trainingLabels')
%-----------------PARTITIONING THE DATASET------------------%
% Hold out 30% of the samples for testing. The split is stratified on the
% labels so each class keeps (approximately) its proportion in both sets.
% A purely size-based holdout can leave a rare class out of the test set,
% which makes the per-class metrics and curves computed later undefined.
assert(size(hogFeatures,1) == numel(trainingLabels), ...
    'hogFeatures and trainingLabels must have the same number of samples.');
cv = cvpartition(trainingLabels,'HoldOut',0.3);
idx = test(cv);                         % logical mask: true = test sample
dataTrain  = hogFeatures(~idx,:);
dataTrainL = trainingLabels(~idx,:);
dataTest   = hogFeatures(idx,:);
dataTestL  = trainingLabels(idx,:);
%-------------------TRAINING THE MODEL----------------------%
% Fit a bagged ensemble of 700 classification trees on the training split
% (out-of-bag prediction enabled), then write the fitted model to
% classifier.mat using MAT-file version 7.3.
nTrees = 700;
model = TreeBagger(nTrees, dataTrain, dataTrainL, ...
    'Method', 'classification', ...
    'OOBPrediction', 'On');
save classifier.mat model -v7.3
%---------------------ACCURACY CALCULATION-----------------------%
% Classify the held-out samples. `scores` (per-class scores returned by
% predict) is reused by the ROC / PR curve section further down.
[prediction, scores] = predict(model, dataTest);
% Accuracy = percentage of test samples whose predicted label matches
% the true label.
nCorrect = sum(prediction == dataTestL);
nTest    = size(dataTest, 1);
Accuracy = nCorrect / nTest * 100;
%--------------------CALCULATING MODEL PERFORMANCE CURVES--------%
%-----------------------PRECISION vs RECALL----------------------%
% Convert the predicted labels to categorical before building the
% confusion matrix and chart.
% NOTE(review): presumably `prediction` arrives from predict as a cell array
% of character vectors and `dataTestL` is categorical -- confirm.
prediction = categorical(prediction);
confmat = confusionmat(dataTestL, prediction);
confchart = confusionchart(dataTestL,prediction);
% Rows of confmat are true classes and columns are predicted classes, so
% the diagonal holds the per-class true positives.
truePos   = diag(confmat).';
recall    = truePos ./ sum(confmat, 2).';   % TP / (TP + FN), per class
precision = truePos ./ sum(confmat, 1);     % TP / (TP + FP), per class
% A class with no true samples (recall) or no predictions (precision)
% gives 0/0 = NaN, which would make the macro averages NaN as well.
% Score such classes as 0 instead.
recall(isnan(recall)) = 0;
precision(isnan(precision)) = 0;
% Macro-averaged metrics: unweighted mean over all classes.
Recall    = mean(recall);
Precision = mean(precision);
% Harmonic mean of the macro precision and recall (0 when both are 0).
if Precision + Recall > 0
    F_score = 2*Recall*Precision/(Precision+Recall);
else
    F_score = 0;
end
%---------------------ROC CURVE--------------------------------%
% Column k of `scores` is used as the score of the positive class.
% NOTE(review): the positive class is passed as k-1 (= 1), which assumes
% numeric 0/1 labels with model.ClassNames{k} being that class -- confirm.
k = 2;
posClass = k - 1;
[X,Y,~,AUC] = perfcurve(dataTestL,scores(:,k),posClass);
figure;
plot(X,Y,'LineWidth',1.25,'Color','b');
xlabel('False positive rate');
ylabel('True positive rate');
% One concatenated title string: the two-argument title(txt, subtxt) form
% is only available in newer MATLAB releases.
title(['ROC for classification - AUC = ', num2str(AUC)]);
%---------------------PRECISION-RECALL CURVE-------------------%
% Same scores/positive class, with recall on X and precision on Y;
% AUC_pr is the area under this precision-recall curve.
[Xr,Yp,~,AUC_pr] = perfcurve(dataTestL,scores(:,k),posClass,'Xcrit','reca','YCrit','prec');
figure;
plot(Xr,Yp,'LineWidth',1.25,'Color','r');
xlabel('Recall');
ylabel('Precision');
title(['PR curve - AUC = ', num2str(AUC_pr)]);