-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathvisualization.py
61 lines (52 loc) · 1.87 KB
/
visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
# Process-wide matplotlib default figure size (width, height in inches); it
# applies to every figure created after this module is imported, including
# the ones produced by the save_* helpers below.
plt.rcParams.update({'figure.figsize': (10, 5)})
def save_confusion_matrix(actual_labels, pred_labels, output_path):
    """Plot a confusion matrix of actual vs. predicted labels and save it to a file.

    The matrix is built with ``pandas.crosstab`` (rows: actual labels,
    columns: predicted labels) and includes an extra ``All`` row/column
    holding the marginal totals. Each cell is annotated with its count.

    :param actual_labels: sequence of ground-truth labels
    :param pred_labels: sequence of predicted labels, same length as
        ``actual_labels`` (the DataFrame constructor raises otherwise)
    :param output_path: destination passed to ``Figure.savefig``
    """
    df = pd.DataFrame({'actual_labels': actual_labels, 'pred_labels': pred_labels})
    cm_df = pd.crosstab(df['actual_labels'], df['pred_labels'],
                        rownames=['Actual'], colnames=['Predicted'], margins=True)
    # Draw on a dedicated figure and always close it afterwards: plt.clf()
    # alone would leave the figure registered with pyplot, leaking memory and
    # triggering "More than 20 figures" warnings on repeated calls.
    fig, ax = plt.subplots()
    try:
        # fmt="d" because crosstab cells (including margins) are integer counts.
        sns.heatmap(cm_df, annot=True, fmt="d", cmap="Blues", ax=ax)
        fig.savefig(output_path)
    finally:
        plt.close(fig)
def save_train_history(history, output_path):
    """Plot training and validation accuracy curves and save them to a file.

    The y-axis is fixed to the range [0, 1].

    :param history: mapping with ``'train_acc'`` and ``'val_acc'`` entries,
        each a sequence of accuracy values (plotted against its index, labelled
        "Epoch")
    :param output_path: destination passed to ``Figure.savefig``
    """
    # Dedicated figure, always closed afterwards so repeated calls do not
    # accumulate open pyplot figures (plt.clf() alone does not release them).
    fig, ax = plt.subplots()
    try:
        ax.plot(history['train_acc'], label='train accuracy')
        ax.plot(history['val_acc'], label='val accuracy')
        ax.set_title('Training history')
        ax.set_ylabel('Accuracy')
        ax.set_xlabel('Epoch')
        ax.legend()
        ax.set_ylim([0, 1])
        fig.savefig(output_path)
    finally:
        plt.close(fig)
def save_class_distribution(df, output_path):
    """Plot how many rows belong to each class label and save it to a file.

    :param df: DataFrame with a ``'label'`` column
    :param output_path: destination passed to ``Figure.savefig``
    """
    # Dedicated figure, always closed afterwards so repeated calls do not
    # accumulate open pyplot figures (plt.clf() alone does not release them).
    fig, ax = plt.subplots()
    try:
        # Pass the vector as x= explicitly: in newer seaborn the first
        # positional parameter of countplot is `data`, not `x`.
        sns.countplot(x=df['label'], ax=ax)
        ax.set_xlabel('label')
        fig.savefig(output_path)
    finally:
        plt.close(fig)
def save_seq_len_distribution(df, output_path):
    """Plot a 30-bin histogram of text lengths (in words) and save it to a file.

    The length of a text is the number of whitespace-separated tokens
    returned by ``str.split()``.

    :param df: DataFrame with a ``'text'`` column of strings
    :param output_path: destination passed to ``Figure.savefig``
    """
    # Compute lengths before creating the figure so a bad value does not
    # leave a half-built figure behind.
    seq_len = [len(sentence.split()) for sentence in df['text'].tolist()]
    # Explicit integer dtype keeps the Series numeric even for an empty dataset.
    lengths = pd.Series(seq_len, dtype='int64')
    # Dedicated figure, always closed afterwards so repeated calls do not
    # accumulate open pyplot figures (plt.clf() alone does not release them).
    fig, ax = plt.subplots()
    try:
        lengths.hist(bins=30, ax=ax)
        fig.savefig(output_path)
    finally:
        plt.close(fig)