loss.py
import numpy as np

# Machine epsilon, used to keep the ratios below finite when a denominator is zero
EPSILON = np.spacing(1)

# Spatial axes of a 5D tensor shaped (batch, channels, depth, height, width)
SPATIAL_DIMENSIONS = 2, 3, 4


class TverskyLoss:
    """Tversky loss (https://arxiv.org/pdf/1706.05721.pdf).

    ``alpha`` weights false positives and ``beta`` weights false negatives.
    """
    def __init__(self, *, alpha, beta, epsilon=None):
        self.alpha = alpha
        self.beta = beta
        self.epsilon = EPSILON if epsilon is None else epsilon

    def __call__(self, output, target):
        loss = get_tversky_loss(
            output,
            target,
            self.alpha,
            self.beta,
            epsilon=self.epsilon,
        )
        return loss


class DiceLoss(TverskyLoss):
    """Dice loss, i.e. Tversky loss with alpha = beta = 0.5."""
    def __init__(self, epsilon=None):
        super().__init__(alpha=0.5, beta=0.5, epsilon=epsilon)
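

# Usage sketch (illustrative, not part of the original module): `output` is
# assumed to hold probabilities (e.g. after a sigmoid/softmax) and `target` a
# matching binary/one-hot array, shaped (D, H, W) or (batch, channels, D, H, W).
#
#     criterion = TverskyLoss(alpha=0.3, beta=0.7)  # beta > alpha penalizes false negatives more
#     loss = criterion(probabilities, one_hot_target)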


def get_confusion(output, target):
    if output.shape != target.shape:
        message = (
            f'Shape of output {output.shape} and target {target.shape} differ')
        raise ValueError(message)
    num_dimensions = output.ndim
    if num_dimensions == 3:  # 3D image, typically during testing
        kwargs = {}
    else:  # 5D tensor, typically during training
        is_torch_tensor = not isinstance(output, np.ndarray)
        key = 'dim' if is_torch_tensor else 'axis'
        kwargs = {key: SPATIAL_DIMENSIONS}
    p0 = output
    g0 = target
    p1 = 1 - p0
    g1 = 1 - g0
    tp = (p0 * g0).sum(**kwargs)
    fp = (p0 * g1).sum(**kwargs)
    fn = (p1 * g0).sum(**kwargs)
    return tp, fp, fn


def get_tversky_score(output, target, alpha, beta, epsilon=None):
    """
    Tversky index: https://arxiv.org/pdf/1706.05721.pdf
    """
    epsilon = EPSILON if epsilon is None else epsilon
    tp, fp, fn = get_confusion(output, target)
    numerator = tp + epsilon
    denominator = tp + alpha * fp + beta * fn + epsilon
    score = numerator / denominator
    return score


def get_tversky_loss(*args, **kwargs):
    losses = 1 - get_tversky_score(*args, **kwargs)
    return losses
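
# Sanity check (illustrative): with a perfect prediction output == target,
# fp == fn == 0, so the score reduces to (tp + epsilon) / (tp + epsilon) == 1
# and the loss to 0; with no overlap, tp == 0 and the score approaches 0.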


def get_f_score(output, target, beta, epsilon=None):
    """
    F-beta score: https://en.wikipedia.org/wiki/F1_score#Definition
    """
    epsilon = EPSILON if epsilon is None else epsilon
    confusion = get_confusion(output, target)
    precision = get_precision(confusion)
    recall = get_recall(confusion)
    # F_beta = (1 + beta^2) * P * R / (beta^2 * P + R); reduces to 2PR/(P + R) for beta = 1
    numerator = (1 + beta**2) * precision * recall
    denominator = beta**2 * precision + recall + epsilon
    score = numerator / denominator
    return score


def get_f_loss(*args, **kwargs):
    losses = 1 - get_f_score(*args, **kwargs)
    return losses


def get_f_score_alternative(output, target, beta, epsilon=None):
    """
    F-beta as a special case of the Tversky index.
    See https://brenocon.com/blog/2012/04/f-scores-dice-and-jaccard-set-similarity/
    """
    # F_beta weights false positives by 1 / (1 + beta^2)
    # and false negatives by beta^2 / (1 + beta^2)
    alpha_tversky = 1 / (1 + beta**2)
    beta_tversky = 1 - alpha_tversky
    score = get_tversky_score(
        output, target, alpha_tversky, beta_tversky, epsilon=epsilon)
    return score


def get_dice_score(output, target):
    # Dice is the Tversky index with alpha = beta = 0.5
    alpha = beta = 0.5
    return get_tversky_score(output, target, alpha, beta)


def get_dice_loss(output, target):
    losses = 1 - get_dice_score(output, target)
    return losses


def get_iou_score(output, target):
    # IoU (Jaccard) is the Tversky index with alpha = beta = 1
    alpha = beta = 1
    return get_tversky_score(output, target, alpha, beta)


def get_iou_loss(output, target):
    losses = 1 - get_iou_score(output, target)
    return losses


def get_precision_(output, target):
    confusion = get_confusion(output, target)
    return get_precision(confusion)


def get_recall_(output, target):
    confusion = get_confusion(output, target)
    return get_recall(confusion)


def get_precision(confusion, epsilon=None):
    epsilon = EPSILON if epsilon is None else epsilon
    tp, fp, _ = confusion
    precision = (tp + epsilon) / (tp + fp + epsilon)
    return precision


def get_recall(confusion, epsilon=None):
    epsilon = EPSILON if epsilon is None else epsilon
    tp, _, fn = confusion
    recall = (tp + epsilon) / (tp + fn + epsilon)
    return recall


def get_dice_from_precision_and_recall(precision, recall):
    # Harmonic mean of precision and recall
    return 2 / (1/precision + 1/recall)
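

# Minimal smoke test (an assumed usage example, not part of the original file):
# exercises the losses on random NumPy arrays with arbitrary shapes.
if __name__ == '__main__':
    rng = np.random.default_rng(0)
    # Fake predictions and binary targets shaped (batch, channels, D, H, W)
    probabilities = rng.random((2, 1, 8, 8, 8))
    targets = (rng.random((2, 1, 8, 8, 8)) > 0.5).astype(float)

    print('Dice loss:', DiceLoss()(probabilities, targets))
    print('Tversky loss:', TverskyLoss(alpha=0.3, beta=0.7)(probabilities, targets))
    print('IoU loss on a perfect prediction:', get_iou_loss(targets, targets))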