-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlosses.py
executable file
·92 lines (69 loc) · 2.55 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
from utils import *
import cv2 as cv
import numpy as np
import torch
from torch import nn
from scipy.ndimage.morphology import distance_transform_edt as edt
from scipy.ndimage import convolve
"""
Hausdorff loss implementation based on paper:
https://arxiv.org/pdf/1904.10030.pdf
"""


class HausdorffDTLoss(nn.Module):
    """Binary Hausdorff loss based on distance transform.

    The squared error between the predicted probability and the target is
    weighted by ``pred_dt ** alpha + target_dt ** alpha`` and averaged over
    all elements. Here ``pred_dt`` / ``target_dt`` are the distance fields
    produced by :meth:`distance_field` for the prediction and the target.

    Args:
        alpha: Exponent applied to both distance fields (default 2.0).
        **kwargs: Accepted for interface compatibility; unused.
    """

    def __init__(self, alpha=2.0, **kwargs):
        super(HausdorffDTLoss, self).__init__()
        self.alpha = alpha

    @torch.no_grad()
    def distance_field(self, img: np.ndarray) -> np.ndarray:
        """Return the per-sample distance field of ``img`` thresholded at 0.5.

        For every batch item with at least one element > 0.5, each element
        gets its Euclidean distance to the nearest element of the opposite
        class (foreground -> nearest background and vice versa). Batch items
        with no element > 0.5 are left as all zeros.

        The field is always float64, whatever ``img``'s dtype, so integer or
        boolean inputs do not truncate fractional distances.
        """
        field = np.zeros(img.shape, dtype=np.float64)
        for batch in range(len(img)):
            fg_mask = img[batch] > 0.5
            if fg_mask.any():
                bg_mask = ~fg_mask
                # edt(mask) is non-zero only on True elements (distance to
                # the nearest False one), so summing the two transforms
                # covers foreground and background elements alike.
                field[batch] = edt(fg_mask) + edt(bg_mask)
        return field

    def forward(
        self, pred: torch.Tensor, target: torch.Tensor
    ) -> torch.Tensor:
        """Compute the distance-transform Hausdorff loss.

        Uses one binary channel: 1 - fg, 0 - bg.

        Args:
            pred: Logits of shape (b, 1, x, y, z) or (b, 1, x, y); a sigmoid
                is applied here.
            target: (b, 1, x, y, z) or (b, 1, x, y), same number of
                dimensions as ``pred``.

        Returns:
            Scalar loss tensor on ``pred``'s device.
        """
        assert pred.dim() == 4 or pred.dim() == 5, "Only 2D and 3D supported"
        assert (
            pred.dim() == target.dim()
        ), "Prediction and target need to be of same dimension"
        pred = torch.sigmoid(pred)
        # Distance fields are computed by SciPy on CPU and carry no gradient.
        pred_dt = torch.from_numpy(
            self.distance_field(pred.detach().cpu().numpy())
        ).float()
        target_dt = torch.from_numpy(
            self.distance_field(target.detach().cpu().numpy())
        ).float()
        pred_error = (pred - target) ** 2
        distance = pred_dt ** self.alpha + target_dt ** self.alpha
        # Follow the prediction's device rather than hard-coding .cuda(),
        # which fails without a GPU and ignores the prediction's GPU index.
        distance = distance.to(pred.device)
        dt_field = pred_error * distance
        return dt_field.mean()
class DiceLoss(nn.Module):
    """Calculate dice loss.

    Applies a sigmoid to ``logits`` and returns ``1 - dice``. The Dice
    coefficient is computed over all elements of the batch at once
    (sums are global, not averaged per sample).

    Args:
        eps: Smoothing term added to both numerator and denominator of the
            Dice coefficient, so an empty target matched by an all-zero
            prediction yields a loss of 0 instead of dividing by zero.
    """

    def __init__(self, eps: float = 1e-9):
        super(DiceLoss, self).__init__()
        self.eps = eps

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return ``1 - dice`` between ``sigmoid(logits)`` and ``targets``.

        Both inputs are flattened to (batch, -1); they must have the same
        first dimension and the same number of elements per sample.
        """
        num = targets.size(0)
        # reshape (not view) so non-contiguous inputs, e.g. after
        # transpose/permute, are accepted.
        probability = torch.sigmoid(logits).reshape(num, -1)
        targets = targets.reshape(num, -1)
        assert probability.shape == targets.shape, (
            "logits and targets must have the same number of elements per sample"
        )
        intersection = 2.0 * (probability * targets).sum()
        union = probability.sum() + targets.sum()
        # eps in the denominator too: strongly negative logits make sigmoid
        # underflow to exactly 0, so union can be 0 for an empty target.
        dice_score = (intersection + self.eps) / (union + self.eps)
        return 1.0 - dice_score