-
Notifications
You must be signed in to change notification settings - Fork 2
/
separator.py
106 lines (82 loc) · 2.85 KB
/
separator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import math
import torch
import torch.nn.functional as F
from pathlib import Path
from torch import nn
from spleeter_pytorch.unet import UNet
from spleeter_pytorch.util import tf2pytorch
def load_ckpt(model, ckpt):
    """
    Copy the matching entries of a checkpoint into a model and load them.

    Args:
        model: module exposing ``state_dict()`` and ``load_state_dict()``.
        ckpt: mapping from state-dict key to an array accepted by
            ``torch.from_numpy`` (each value must expose ``.shape``).
    Returns:
        The same ``model`` object, after ``load_state_dict`` was called on it.
    Raises:
        ValueError: if a checkpoint entry has a different shape than the
            model's entry with the same key.
    """
    state_dict = model.state_dict()
    for key, array in ckpt.items():
        if key not in state_dict:
            # Entries the model does not know about are reported and skipped.
            print('Ignore ', key)
            continue

        expected_shape = tuple(state_dict[key].shape)
        actual_shape = tuple(array.shape)
        if expected_shape != actual_shape:
            # Raise explicitly instead of using `assert`: assertions are
            # removed under `python -O` and carry no diagnostic message.
            raise ValueError(
                'Shape mismatch for {!r}: model expects {}, checkpoint has {}'
                .format(key, expected_shape, actual_shape)
            )
        state_dict[key] = torch.from_numpy(array)

    # Nothing is written to the model until every entry has been validated.
    model.load_state_dict(state_dict)
    return model
def pad_and_partition(tensor, T):
    """
    Zero-pad the last axis up to a multiple of T, then cut it into
    length-T segments that are concatenated along the first axis.

    Args:
        tensor(Tensor): B x C x F x L
        T(int): segment length along the last axis
    Returns:
        tensor of size (B*ceil(L/T) x C x F x T); segments appear along
        dim 0 in their original left-to-right order
    """
    length = tensor.size(3)
    # (-length) % T is the distance to the next multiple of T (0 if aligned).
    padded = F.pad(tensor, [0, (-length) % T])
    return torch.cat(torch.split(padded, T, dim=3), dim=0)
class Separator(nn.Module):
    """
    Holds one UNet per instrument and turns an STFT magnitude into one
    normalized mask per instrument.
    """

    def __init__(self, num_instruments: int, checkpoint_path: Path):
        """
        Args:
            num_instruments (int): number of per-instrument networks to build
            checkpoint_path (Path): passed to ``tf2pytorch`` together with
                ``num_instruments``; presumably a TensorFlow checkpoint
                location -- confirm against ``tf2pytorch``
        """
        super().__init__()

        # STFT configuration
        self.F = 1024
        self.T = 512
        self.win_length = 4096
        self.hop_length = 1024
        # Hann window stored as a non-trainable parameter
        window = torch.hann_window(self.win_length)
        self.win = nn.Parameter(window, requires_grad=False)

        # Indexed by instrument below; each entry is handed to load_ckpt
        converted = tf2pytorch(checkpoint_path, num_instruments)

        # One network per instrument, kept in eval mode
        self.instruments = nn.ModuleList()
        for index in range(num_instruments):
            print('Loading model for instrument {}'.format(index))
            unet = load_ckpt(UNet(2), converted[index])
            unet.eval()
            self.instruments.append(unet)

    def forward(self, stft_mag):
        """
        Compute a normalized mask for every instrument.

        Args:
            stft_mag (tensor): 2 x F x L
        Returns:
            list with one tensor per instrument, each 2 x F x L x 1
            (assumes every network returns a mask shaped like its input)
        """
        num_frames = stft_mag.shape[2]

        # 2 x F x L  ->  1 x 2 x F x L
        batched = stft_mag.unsqueeze(-1).permute([3, 0, 1, 2])
        segments = pad_and_partition(batched, self.T)  # B x 2 x F x T
        segments = segments.transpose(2, 3)  # B x 2 x T x F

        # Raw mask from each instrument's network
        raw_masks = [net(segments) for net in self.instruments]

        # Shared denominator: sum of squared masks plus a small epsilon
        denominator = sum([raw ** 2 for raw in raw_masks])
        denominator += 1e-10

        outputs = []
        for raw in raw_masks:
            ratio = (raw ** 2 + 1e-10 / 2) / denominator
            ratio = ratio.transpose(2, 3)  # B x 2 x F x T
            # Stitch the B segments back together along the time axis
            pieces = torch.split(ratio, 1, dim=0)
            stitched = torch.cat(pieces, dim=3)
            # Drop the batch axis and the padded frames: 2 x F x L
            trimmed = stitched.squeeze(0)[:, :, :num_frames]
            outputs.append(trimmed.unsqueeze(-1))  # 2 x F x L x 1
        return outputs