metrics.py
import pickle
import warnings
from collections import Counter  # needed by lgb_precision below
from pathlib import Path

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn import metrics as skmetrics

warnings.filterwarnings("ignore")

def lgb_MaF(preds, dtrain):
    """LightGBM custom eval: macro-averaged F1 over multiclass predictions."""
    Y = np.array(dtrain.get_label(), dtype=np.int32)
    preds = preds.reshape(-1, len(Y))   # (n_classes, n_samples)
    Y_pre = np.argmax(preds, axis=0)    # predicted class per sample
    return 'macro_f1', float(skmetrics.f1_score(Y, Y_pre, average='macro')), True
def lgb_precision(preds, dtrain):
    """LightGBM custom eval: fraction of correct predictions (overall accuracy)."""
    Y = dtrain.get_label()
    preds = preds.reshape(-1, len(Y))
    Y_pre = np.argmax(preds, axis=0)
    return 'precision', float(Counter(Y == Y_pre)[True] / len(Y)), True
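
# Illustrative sketch (not from the original file): both functions follow LightGBM's
# custom-eval signature `(preds, train_data) -> (name, value, is_higher_better)`, so
# they would typically be supplied through the `feval` argument, e.g.:
#
#     booster = lgb.train(params, train_set, valid_sets=[valid_set], feval=lgb_MaF)
#
# where `params`, `train_set`, and `valid_set` are assumed to be defined elsewhere.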
id2lab = [[-1, -1]]*20
for a in range(1, 11):
for s in [1, 2]:
id2lab[a-1+(s-1)*10] = [a, s]
class Metrictor:
def __init__(self):
self._reporter_ = {"ACC": self.ACC, "AUC": self.AUC, "Precision": self.Precision,
"Recall": self.Recall, "F1": self.F1, "LOSS": self.LOSS}
def __call__(self, report, end='\n'):
res = {}
for mtc in report:
v = self._reporter_[mtc]()
print(f" {mtc}={v:6.3f}", end=';')
res[mtc] = v
print(end=end)
return res
def set_data(self, Y_prob_pre, Y, threshold=0.5):
self.Y = Y.astype('int')
if len(Y_prob_pre.shape) > 1:
self.Y_prob_pre = Y_prob_pre[:, 1]
self.Y_pre = Y_prob_pre.argmax(axis=-1)
else:
self.Y_prob_pre = Y_prob_pre
self.Y_pre = (Y_prob_pre > threshold).astype('int')
@staticmethod
def table_show(resList, report, rowName='CV'):
lineLen = len(report)*8 + 6
print("="*(lineLen//2-6) + "FINAL RESULT" + "="*(lineLen//2-6))
print(f"{'-':^6}" + "".join([f"{i:>8}" for i in report]))
for i, res in enumerate(resList):
print(f"{rowName+'_'+str(i+1):^6}" +
"".join([f"{res[j]:>8.3f}" for j in report]))
print(f"{'MEAN':^6}" +
"".join([f"{np.mean([res[i] for res in resList]):>8.3f}" for i in report]))
print("======" + "========"*len(report))
    def each_class_indictor_show(self, id2lab):
        # Placeholder: per-class metric reporting has not been implemented yet.
        print('Per-class metrics are not implemented yet.')
def ACC(self):
return ACC(self.Y_pre, self.Y)
def AUC(self):
return AUC(self.Y_prob_pre, self.Y)
def Precision(self):
return Precision(self.Y_pre, self.Y)
def Recall(self):
return Recall(self.Y_pre, self.Y)
def F1(self):
return F1(self.Y_pre, self.Y)
def LOSS(self):
return LOSS(self.Y_prob_pre, self.Y)
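
# Illustrative usage sketch (assumes `Y_prob` holds positive-class probabilities and
# `Y` holds 0/1 ground-truth labels; neither name comes from the original file):
#
#     metrictor = Metrictor()
#     metrictor.set_data(Y_prob, Y)
#     results = metrictor(["ACC", "AUC", "Precision", "Recall", "F1", "LOSS"])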
def calc_ROC(y_test, y_score, savePath, timestamp, plot=True):
    """Compute ROC curve(s) and AUC, pickle them to logs/, and optionally plot."""
    Path("logs").mkdir(parents=True, exist_ok=True)
    picklefile = f'logs/ROC_{savePath}_{timestamp}.pkl'      # log with unique timestamp
    plotfile = f'logs/plot_ROC_{savePath}_{timestamp}.png'   # plot with unique timestamp
all_metrics = dict()
fpr = dict()
tpr = dict()
roc_auc = dict()
fpr['class'], tpr['class'], _ = skmetrics.roc_curve(y_test, y_score)
roc_auc['class'] = skmetrics.auc(fpr['class'], tpr['class'])
# Compute micro-average ROC curve and ROC area
fpr["micro"], tpr["micro"], _ = skmetrics.roc_curve(y_test.ravel(), y_score.ravel())
roc_auc["micro"] = skmetrics.auc(fpr["micro"], tpr["micro"])
all_metrics['fpr'] = fpr
all_metrics['tpr'] = tpr
all_metrics['roc_auc'] = roc_auc
    with open(picklefile, 'wb') as fh:
        pickle.dump(all_metrics, fh)
if plot:
plt.figure()
lw = 2
plt.plot(fpr['class'], tpr['class'], color='darkorange',
lw=lw, label='ROC curve (area = %0.2f)' % roc_auc['class'])
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate', fontsize=12)
plt.ylabel('True Positive Rate', fontsize=12)
plt.title('Receiver operating characteristic', fontsize=16)
plt.legend(loc="lower right", fontsize=8)
plt.savefig(plotfile, dpi=300)
plt.clf() # clear the plot object
def calc_conf_matrix(y_true, y_pred, savePath, timestamp, plot=True):
    """Write the binary confusion matrix to logs/ and optionally plot it as a heatmap."""
    Path("logs").mkdir(parents=True, exist_ok=True)
    logfile = f'logs/CM_{savePath}_{timestamp}.txt'          # log with unique timestamp
    plotfile = f'logs/plot_CM_{savePath}_{timestamp}.png'    # plot with unique timestamp
    y_pred = np.round(np.clip(y_pred, 0, 1))  # map continuous predictions to {0, 1}
tn, fp, fn, tp = skmetrics.confusion_matrix(y_true, y_pred).ravel()
header = ['TN', 'FP', 'FN', 'TP', '\n']
with open(logfile, 'a') as out:
out.write(','.join(header))
out.write(f'{tn},{fp},{fn},{tp}\n')
if plot:
sns.set(rc={'figure.figsize':(8,6), 'axes.labelsize': 14})
        cm = skmetrics.confusion_matrix(y_true, y_pred, normalize=None)
        ax = sns.heatmap(cm, annot=True, fmt='g', cmap=plt.cm.cividis)
        # confusion_matrix puts true labels on rows and predictions on columns
        ax.set(xlabel='Predicted', ylabel='Actual')
plt.savefig(plotfile, dpi=300)
plt.clf() # clear the plot object
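
# Illustrative usage sketch (the names `y_true` and `y_score` are assumptions, not
# from the original file); both helpers write their outputs under logs/:
#
#     calc_ROC(y_true, y_score, savePath='run1', timestamp='20240101', plot=True)
#     calc_conf_matrix(y_true, (y_score > 0.5).astype(int), savePath='run1',
#                      timestamp='20240101', plot=True)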
class MetricLog:
"""
log train and validation loss
"""
def __init__(self, savePath, timestamp, to_report):
Path("logs").mkdir(parents=True, exist_ok=True)
self.logger = f'logs/train_val_{savePath}_{timestamp}.txt' # create log with unique timestamp
self.best_results = f'logs/best_{savePath}_{timestamp}.txt' # create log with unique timestamp
self.plot_log = f'logs/learn_curve_{savePath}_{timestamp}.png' # create plot file with unique timestamp
self.to_report = to_report
self.save_train = list()
self.save_val = list()
header = [f'{mtc}_train' for mtc in to_report] + [f'{mtc}_valid' for mtc in to_report]
self.header_best = header + [f'{mtc}_test' for mtc in to_report]
self.write_header(header)
def log_train_val(self, train, val):
train_temp = [train[mtc] for mtc in self.to_report] # log LOSS and additional params in to_report param
val_temp = [val[mtc] for mtc in self.to_report] # log LOSS and additional params in to_report param
self.save_train.append(train_temp)
self.save_val.append(val_temp)
train_formatted = [f'{train[mtc]:.3f}' for mtc in self.to_report] # format to 3 digit floats
val_formatted = [f'{val[mtc]:.3f}' for mtc in self.to_report] # format to 3 digit floats
self.write_log(train_formatted, val_formatted)
def write_header(self, header):
with open(self.logger, 'a') as out:
out.write(f'{",".join(header)}\n')
def write_log(self, train_mtc, val_mtc):
"""
write all metrics in to_report for train and test
"""
with open(self.logger, 'a') as out:
out.write(f'{",".join(train_mtc)},{",".join(val_mtc)}\n')
def write_best(self, train, val, test):
test_form = [f'{test[mtc]:.3f}' for mtc in self.to_report] # format to 3 digit floats
train_form = [f'{train[mtc]:.3f}' for mtc in self.to_report] # format to 3 digit floats
val_form = [f'{val[mtc]:.3f}' for mtc in self.to_report] # format to 3 digit floats
with open(self.best_results, 'w') as out:
out.write(f'{",".join(self.header_best)}\n')
out.write(f'{",".join(train_form)},{",".join(val_form)},{",".join(test_form)}\n')
def plot_curve(self):
"""
default learn curve plotting with just LOSS
"""
idx = self.to_report.index('LOSS')
x = [i for i in range(1, len(self.save_train)+1)]
        fig, ax = plt.subplots(figsize=(10, 8))
        ax.plot(x, [item[idx] for item in self.save_train], label='train loss', c='blue')
        ax.plot(x, [item[idx] for item in self.save_val], label='validation loss', c='orange')
ax.tick_params(axis='both', which='major', labelsize=14)
ax.tick_params(axis='both', which='minor', labelsize=12)
plt.xticks(np.arange(0, max(x)+1, 16))
plt.legend(fontsize=12)
plt.title('Learning curve', fontsize=18)
plt.xlabel('Epochs', fontsize=16)
plt.ylabel('Loss', fontsize=16)
plt.savefig(self.plot_log, dpi=300)
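
# Illustrative usage sketch (the `train_metrics`/`val_metrics`/`test_metrics` dicts are
# assumptions; in practice they would come from Metrictor calls during training):
#
#     log = MetricLog(savePath='run1', timestamp='20240101', to_report=['LOSS', 'ACC'])
#     for epoch in range(n_epochs):
#         log.log_train_val(train_metrics, val_metrics)   # dicts keyed by metric name
#     log.write_best(train_metrics, val_metrics, test_metrics)
#     log.plot_curve()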
def ACC(Y_pre, Y):
return (Y_pre == Y).sum() / len(Y)
def AUC(Y_prob_pre, Y):
return skmetrics.roc_auc_score(Y, Y_prob_pre)
def Precision(Y_pre, Y):
return skmetrics.precision_score(Y, Y_pre)
def Recall(Y_pre, Y):
return skmetrics.recall_score(Y, Y_pre)
def F1(Y_pre, Y):
return skmetrics.f1_score(Y, Y_pre)
def LOSS(Y_prob_pre, Y):
Y_prob_pre, Y = Y_prob_pre.reshape(-1), Y.reshape(-1)
Y_prob_pre[Y_prob_pre > 0.99] -= 1e-3
Y_prob_pre[Y_prob_pre < 0.01] += 1e-3
return -np.mean(Y*np.log(Y_prob_pre) + (1-Y)*np.log(1-Y_prob_pre))
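
# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not part of the original pipeline):
# it fabricates random binary labels and scores just to exercise the helpers
# above, and writes its outputs under logs/ with a dummy timestamp.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    Y = rng.integers(0, 2, size=200)             # ground-truth 0/1 labels
    Y_prob = Y * 0.6 + rng.random(200) * 0.4     # noisy scores correlated with Y

    metrictor = Metrictor()
    metrictor.set_data(Y_prob, Y)
    metrictor(["ACC", "AUC", "Precision", "Recall", "F1", "LOSS"])

    Path("logs").mkdir(parents=True, exist_ok=True)  # calc_* helpers write into logs/
    calc_ROC(Y, Y_prob, savePath="demo", timestamp="0", plot=True)
    calc_conf_matrix(Y, (Y_prob > 0.5).astype(int), savePath="demo", timestamp="0", plot=True)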