-
Notifications
You must be signed in to change notification settings - Fork 72
/
Copy pathplotter.py
44 lines (38 loc) · 1.5 KB
/
plotter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
def plot_confusion_matrix(y_true, y_pred,
                          classes=(1, 2, 3, 4, 5),
                          normalize=False,
                          cmap=plt.cm.YlOrBr,
                          labels=None):
    """Plot a confusion matrix as an annotated heatmap.

    Adapted from the scikit-learn documentation example. Nothing is printed;
    the figure and axes are returned so the caller can show or save them.

    Parameters
    ----------
    y_true : array-like
        Ground-truth labels, passed to ``sklearn.metrics.confusion_matrix``.
    y_pred : array-like
        Predicted labels, passed to ``sklearn.metrics.confusion_matrix``.
    classes : sequence, default (1, 2, 3, 4, 5)
        Tick labels for both axes, one per row/column of the matrix.
    normalize : bool, default False
        If True, divide each row by its total so every non-empty row sums
        to 1; rows with no true samples are left as all zeros. Cells are
        then annotated with two decimals instead of integer counts.
    cmap : matplotlib colormap, default ``plt.cm.YlOrBr``
        Colormap used by ``imshow``.
    labels : array-like, optional
        Forwarded to ``confusion_matrix`` to fix the set and order of the
        matrix rows/columns (e.g. to include classes absent from the data so
        the matrix size matches ``classes``). ``None`` keeps the previous
        behavior of letting ``confusion_matrix`` infer labels from the data.

    Returns
    -------
    fig, ax : the matplotlib Figure and Axes containing the plot.
    """
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # A class with no true samples has a zero row total; plain division
        # would yield NaN (which also poisons the text-colour threshold
        # below), so such rows are left at 0 instead.
        cm = np.divide(cm, row_totals,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)

    fig, ax = plt.subplots()
    # origin='lower' places the first matrix row at the bottom of the y axis.
    im = ax.imshow(cm, interpolation='nearest', origin='lower', cmap=cmap)
    fig.colorbar(im, ax=ax)
    # One tick per matrix row/column, labelled with the class names.
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           xticklabels=classes, yticklabels=classes,
           ylabel='True label',
           xlabel='Predicted label')
    # Labels are not rotated, so centre them under their ticks; right
    # alignment (used by the rotated scikit-learn example) would shift each
    # label to the left of the column it names.
    plt.setp(ax.get_xticklabels(), rotation=0, ha="center",
             rotation_mode="anchor")

    # Annotate every cell. Cells above half the maximum get white text,
    # presumably because they fall at the dark end of the colormap (true for
    # the default YlOrBr; may not hold for a custom cmap).
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in np.ndindex(cm.shape):
        ax.text(j, i, format(cm[i, j], fmt),
                ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black")
    return fig, ax