-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathunbiased.py
51 lines (38 loc) · 2.2 KB
/
unbiased.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import torch.nn.functional as F
def split_losses(preds: torch.Tensor, truth: torch.Tensor, magnification: int, eps=1e-32):
    """
    Compute the three binary cross-entropy terms the PU losses are built from.

    preds: logits ([-inf, inf]) of predictions
    truth: binary tensor containing ground truth 1 for positives and 0 for unknowns
    magnification: by how much to magnify positive losses AFTER class imbalance has been corrected
    eps (optional): tiny value to avoid division by zero

    Returns a tuple (positive_plus, positive_minus, unknown_minus) of scalar tensors:
      positive_plus:  BCE of the positive samples against target 1, scaled
      positive_minus: BCE of the positive samples against target 0, scaled
      unknown_minus:  BCE of the unknown samples against target 0, unscaled
    The scale is (#unknowns / #positives) * magnification, which is how the
    class imbalance is corrected automatically.

    A group with no samples contributes 0 instead of NaN (the mean-reduced BCE
    of an empty tensor is NaN, which would poison the whole loss).
    """
    loss = F.binary_cross_entropy_with_logits
    mask = truth.bool()
    positives = preds[mask]
    unknowns = preds[~mask]

    if positives.numel() == 0:
        # Summing an empty selection yields 0 while staying connected to
        # preds in the autograd graph, so a later backward() still works.
        positive_plus = positive_minus = positives.sum()
    else:
        imbalance_factor = unknowns.shape[0] / (positives.shape[0] + eps)
        i_f = imbalance_factor * magnification
        positive_plus = loss(positives, torch.ones_like(positives)) * i_f
        positive_minus = loss(positives, torch.zeros_like(positives)) * i_f

    if unknowns.numel() == 0:
        # Same reasoning as above: a zero that keeps the graph intact.
        unknown_minus = unknowns.sum()
    else:
        unknown_minus = loss(unknowns, torch.zeros_like(unknowns))

    return positive_plus, positive_minus, unknown_minus
def upu(preds: torch.Tensor, truth: torch.Tensor, magnification: int = 2):
    """
    Unbiased positive-unlabeled loss built from the terms of split_losses.

    preds: raw logits ([-inf, inf]); do not apply a sigmoid before this loss
    truth: binary tensor, 1 marks a positive sample and 0 an unknown one
    magnification: extra weight on the positive terms, applied AFTER the
        class imbalance has been corrected

    Returns positive_plus - positive_minus + unknown_minus.
    Equation 16 in https://arxiv.org/pdf/2103.04683.pdf
    """
    terms = split_losses(preds, truth, magnification)
    pos_as_pos, pos_as_neg, unk_as_neg = terms
    # Same evaluation order as the formula: (plus - minus) + unknown.
    risk = pos_as_pos - pos_as_neg
    risk = risk + unk_as_neg
    return risk
def nnpu(preds: torch.Tensor, truth: torch.Tensor, magnification: int = 2):
    """
    Non-negative positive-unlabeled loss built from the terms of split_losses.

    preds: logits ([-inf, inf]) of predictions (do not use sigmoid layer before this loss)
    truth: binary tensor containing ground truth 1 for positives and 0 for unknowns
    magnification: by how much to magnify positive losses AFTER class imbalance has been corrected

    Returns positive_plus + max(0, unknown_minus - positive_minus).
    Equation 17 in https://arxiv.org/pdf/2103.04683.pdf

    From the original author's tests, this does not penalize false-positive
    unknowns as much as upu does.
    NOTE(review): that observation may predate the gradient fix below (the
    clamped term used to carry no gradient) - re-verify.
    """
    positive_plus, positive_minus, unknown_minus = split_losses(preds, truth, magnification)
    # Clamp the tensor itself instead of rebuilding it from Python scalars:
    # a freshly constructed tensor has no autograd history, so the clamped
    # term would contribute nothing to the gradient of preds.
    negative_risk = torch.clamp(unknown_minus - positive_minus, min=0)
    return positive_plus + negative_risk