-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathstats.py
86 lines (67 loc) · 2.74 KB
/
stats.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import csv
import os
from brokenaxes import brokenaxes
import matplotlib.pyplot as plt
import numpy as np
from utils import acc, ex
# Presumably quiets TensorFlow's native (C++) logging for a TensorFlow dependency
# pulled in elsewhere (e.g. via utils) -- TODO confirm.
# NOTE(review): this executes after `from utils import acc, ex`; if utils imports
# TensorFlow, messages emitted during that import are not affected.
os.environ.update(TF_CPP_MIN_LOG_LEVEL="2")
@ex.capture
def stats(predict_dir):
    """Write prediction-accuracy and uncertainty summary statistics to CSV.

    Loads the Bayesian and dropout predictions / uncertainty maps and the test
    targets from ``predict_dir`` and writes ``predict_dir/stats.csv``: a
    space-delimited table with one row per statistic and the columns
    ``Category``, ``Dropout``, ``Bayesian``.

    Args:
        predict_dir: Directory containing ``test_targets.npy`` plus the
            ``bayesian/`` and ``dropout/`` prediction subdirectories. The
            ``ex.capture`` decorator (presumably Sacred) may supply it from the
            experiment config when the caller omits it.
    """
    bp = np.load(os.path.join(predict_dir, "bayesian", "bayesian_pred.npy")).squeeze()
    bu = np.load(os.path.join(predict_dir, "bayesian", "bayesian_unc.npy")).squeeze()
    dp = np.load(os.path.join(predict_dir, "dropout", "dropout_pred.npy")).squeeze()
    du = np.load(os.path.join(predict_dir, "dropout", "dropout_unc.npy")).squeeze()
    y = np.load(os.path.join(predict_dir, "test_targets.npy")).squeeze()

    # Compute everything before touching the output file so a failure in
    # `acc` (or a bad array) doesn't leave a half-written stats.csv behind.
    rows = [
        ["Category", "Dropout", "Bayesian"],
        ["Pred_Acc", acc(dp, y), acc(bp, y)],
        ["Unc_Mean", du.mean(), bu.mean()],
        ["Unc_Var", du.var(), bu.var()],
        ["Unc_Max", du.max(), bu.max()],
        ["Unc_Min", du.min(), bu.min()],
    ]

    # The csv module writes its own line terminators; newline="" stops the
    # text layer from translating them again (blank lines on Windows).
    with open(os.path.join(predict_dir, "stats.csv"), "w", newline="") as csvfile:
        csv.writer(csvfile, delimiter=" ").writerows(rows)
def _broken_ylims(counts):
    """Choose y-ranges for a two-panel broken y-axis from histogram bin counts.

    The lower panel spans ``(0, second-tallest bin)``; the upper panel spans
    the tallest bin's height minus a fifth of the second-tallest, so one
    dominant bin (e.g. a spike of zero-valued uncertainties) doesn't flatten
    the rest of the histogram.

    Args:
        counts: 1-D array of histogram bin counts.

    Returns:
        ``((0, second), (upper_bottom, tallest))`` for ``brokenaxes(ylims=...)``,
        or ``None`` when a break is meaningless: fewer than two non-empty bins,
        or the two panels would overlap because no bin clearly dominates.
    """
    if counts.size < 2:
        return None
    second, tallest = np.sort(counts)[-2:]
    upper_bottom = tallest - second / 5
    if second <= 0 or upper_bottom <= second:
        return None
    return (0, second), (upper_bottom, tallest)


def _plot_unc_hist(values, title, path, bins=50):
    """Save a histogram of ``values`` to ``path``, breaking the y-axis if one bin dominates.

    Args:
        values: 1-D array of uncertainty values to histogram.
        title: Figure title.
        path: Output image path; its parent directory must already exist.
        bins: Number of histogram bins.
    """
    # Same binning as the plotted histogram (both default to the data range),
    # so these counts are exactly the bar heights that get drawn.
    counts = np.histogram(values, bins=bins)[0]
    fig = plt.figure()
    ylims = _broken_ylims(counts)
    if ylims is None:
        plt.hist(values, bins=bins)
    else:
        brokenaxes(ylims=ylims).hist(values, bins=bins)
    # plt.title() would only title the current axes, which under brokenaxes is
    # a single sub-panel; suptitle titles the figure as a whole.
    fig.suptitle(title)
    fig.savefig(path)
    # Close (not just clear) so pyplot doesn't keep every figure alive.
    plt.close(fig)


@ex.capture
def plots(images_dir, predict_dir):
    """Save histograms of the Bayesian and dropout uncertainty maps.

    Reads ``bayesian/bayesian_unc.npy`` and ``dropout/dropout_unc.npy`` from
    ``predict_dir`` and writes ``bayesian/bayesian_unc_dist.png`` and
    ``dropout/dropout_unc_dist.png`` under ``images_dir``; those output
    subdirectories must already exist.

    Args:
        images_dir: Output directory for the histogram images.
        predict_dir: Directory holding the saved uncertainty arrays.
    """
    bu = np.load(os.path.join(predict_dir, "bayesian", "bayesian_unc.npy")).ravel()
    du = np.load(os.path.join(predict_dir, "dropout", "dropout_unc.npy")).ravel()

    # Drop extreme outliers (above a shared 99.95th percentile) so the plots
    # aren't stretched out. The cutoff is not rounded: rounding to 2 decimals
    # could push it to 0 for small uncertainties and discard the whole map.
    cutoff = max(np.percentile(bu, 99.95), np.percentile(du, 99.95))
    bu = bu[bu <= cutoff]
    du = du[du <= cutoff]

    _plot_unc_hist(
        bu,
        "Distribution of Bayesian uncertainty map",
        os.path.join(images_dir, "bayesian", "bayesian_unc_dist.png"),
    )
    _plot_unc_hist(
        du,
        "Distribution of dropout uncertainty map",
        os.path.join(images_dir, "dropout", "dropout_unc_dist.png"),
    )
@ex.automain
def get_stats_and_plots(images_dir):
    """Script entry point: prepare image output folders, then run both reports.

    Creates ``images_dir/bayesian`` and ``images_dir/dropout`` if missing, then
    calls ``stats()`` followed by ``plots()`` with no explicit arguments.

    Args:
        images_dir: Root directory for the generated histogram images.
    """
    for model_name in ("bayesian", "dropout"):
        os.makedirs(images_dir + "/" + model_name, exist_ok=True)
    stats()
    plots()