-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathloss.py
65 lines (50 loc) · 2.13 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import torch.nn as nn
from loss_utils import mse_loss, stftm_loss, reg_loss
# Module-level loss callables built once at import time from loss_utils.
# time_loss and freq_loss are invoked in RegularizedLoss.forward as
# fn(prediction, target, loss_mask).
time_loss = mse_loss()
freq_loss = stftm_loss()
# NOTE(review): this rebinds the imported name `reg_loss` to the object returned
# by calling it, so the original `reg_loss` from loss_utils is no longer
# reachable under that name in this module. The resulting object is not used
# anywhere in the visible code.
reg_loss = reg_loss()
# NOTE(review): the three constants below are not referenced in the visible
# code; presumably audio/STFT settings (sample rate in Hz, FFT size, hop size in
# samples) shared with loss_utils or the training script — confirm before
# changing them.
SAMPLE_RATE = 48000
N_FFT = 1022
HOP_LENGTH = 256
def compLossMask(inp, nframes):
    """Build a 0/1 mask that marks the leading valid frames of each batch item.

    Args:
        inp: Tensor indexed as ``[batch, channel, time, ...]``. Only its shape,
            dtype and device are used; its values are never read.
        nframes: Iterable of valid lengths. Entry ``j`` is the number of
            leading positions along dim 2 that are valid for ``inp[j]``.

    Returns:
        A new tensor with the same shape, dtype and device as ``inp`` that is
        1 on ``[j, :, :nframes[j]]`` and 0 everywhere else. Batch items beyond
        ``len(nframes)`` stay all-zero. The mask never requires grad.
    """
    # zeros_like always returns a tensor with requires_grad=False, so the mask
    # can be written directly; going through `.data` is unnecessary.
    loss_mask = torch.zeros_like(inp)
    for j, seq_len in enumerate(nframes):
        # Plain assignment equals "+= 1.0" on a fresh zero tensor and also
        # works when `inp` has an integer or bool dtype.
        loss_mask[j, :, 0:seq_len] = 1
    return loss_mask
class RegularizedLoss(nn.Module):
    """Composite training loss: masked time/frequency terms + wSDR + regulariser.

    ``forward`` returns

        (0.8 * time_loss + 0.2 * freq_loss) / 600
        + wsdr_fn(g1_wav, fg1_wav, g2_wav)
        + gamma * regloss(fg1_wav, g2_wav, g1fx, g2fx)

    where ``time_loss`` and ``freq_loss`` are the module-level callables built
    from ``loss_utils`` and ``compLossMask`` is the module-level mask helper.

    Args:
        gamma: Weight applied to the regularisation term (default 1).
    """

    def __init__(self, gamma=1):
        super().__init__()
        self.gamma = gamma

    def wsdr_fn(self, x_, y_pred_, y_true_, eps=1e-8):
        """Return the batch-mean weighted SDR loss.

        All inputs are flattened from dim 1 onward, so each batch item is
        treated as one vector. ``forward`` calls this as
        ``wsdr_fn(g1_wav, fg1_wav, g2_wav)``.

        Args:
            x_: Reference input; the "noise" components are ``x - y_true`` and
                ``x - y_pred``.
            y_pred_: Predicted signal.
            y_true_: Target signal.
            eps: Small constant added to every denominator to avoid division
                by zero.

        Returns:
            Scalar tensor: mean over the batch of
            ``a * sdr(y_true, y_pred) + (1 - a) * sdr(z_true, z_pred)``, where
            ``a`` is the target's share of the total (target + noise) energy
            and ``sdr`` is the negative cosine similarity.
        """
        y_pred = y_pred_.flatten(1)
        y_true = y_true_.flatten(1)
        x = x_.flatten(1)

        def sdr_fn(true, pred):
            # Negative cosine similarity per batch item; uses the caller's eps
            # so one value controls every denominator in this loss.
            num = torch.sum(true * pred, dim=1)
            den = torch.norm(true, p=2, dim=1) * torch.norm(pred, p=2, dim=1)
            return -(num / (den + eps))

        # True and estimated noise.
        z_true = x - y_true
        z_pred = x - y_pred

        signal_energy = torch.sum(y_true ** 2, dim=1)
        noise_energy = torch.sum(z_true ** 2, dim=1)
        a = signal_energy / (signal_energy + noise_energy + eps)
        wSDR = a * sdr_fn(y_true, y_pred) + (1 - a) * sdr_fn(z_true, z_pred)
        return torch.mean(wSDR)

    def regloss(self, g1, g2, G1, G2):
        """Return ``mean((g1 - g2 - G1 + G2) ** 2)`` over all elements."""
        return torch.mean((g1 - g2 - G1 + G2) ** 2)

    @staticmethod
    def _frame_counts(wav):
        """Return the full length of dim 2, repeated once per batch item."""
        return [wav.shape[2]] * wav.shape[0]

    def forward(self, g1_wav, fg1_wav, g2_wav, g1fx, g2fx):
        """Compute the total loss for one batch.

        Args:
            g1_wav: Passed to ``wsdr_fn`` as the reference input ``x_``.
            fg1_wav: Prediction compared against ``g2_wav`` in every term.
            g2_wav: Target. Must have at least three dims; dim 0 is the batch
                and dim 2 the time axis (the original comment shows shape
                ``[2, 1, 32512]``).
            g1fx: Third argument of ``regloss``.
            g2fx: Fourth argument of ``regloss``.

        Returns:
            Scalar tensor with the combined loss.
        """
        # One full-length entry per batch item. The previous code special-cased
        # batch sizes 1 and 2 only, leaving the mask all-zero for every item
        # after the first in larger batches.
        nframes = self._frame_counts(g2_wav)
        # Keep the mask on the target's device instead of forcing CUDA, so the
        # loss also runs on CPU.
        loss_mask = compLossMask(g2_wav, nframes).float().to(g2_wav.device)

        loss_time = time_loss(fg1_wav, g2_wav, loss_mask)
        loss_freq = freq_loss(fg1_wav, g2_wav, loss_mask)
        # NOTE(review): 0.8 / 0.2 and the divisor 600 are unexplained weights
        # inherited from the original code — confirm their origin.
        loss1 = (0.8 * loss_time + 0.2 * loss_freq) / 600

        return (
            loss1
            + self.wsdr_fn(g1_wav, fg1_wav, g2_wav)
            + self.gamma * self.regloss(fg1_wav, g2_wav, g1fx, g2fx)
        )