-
Notifications
You must be signed in to change notification settings - Fork 2
/
estimator.py
79 lines (59 loc) · 2.18 KB
/
estimator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from spleeter_pytorch.separator import Separator
class Estimator(nn.Module):
    """STFT-domain source separator built around a ``Separator`` network.

    The input waveform is transformed with an STFT, ``Separator`` is applied to
    the magnitude spectrogram to obtain masks, and each masked spectrogram is
    turned back into a waveform with an inverse STFT.
    """

    def __init__(self, num_instruments: int, checkpoint_path: Path):
        """
        Args:
            num_instruments (int): forwarded to ``Separator``.
            checkpoint_path (Path): forwarded to ``Separator``.
        """
        super().__init__()

        # stft config
        self.F = 1024  # number of low-frequency bins kept by compute_stft
        self.T = 512  # NOTE(review): not read inside this class; presumably used by callers — verify
        self.win_length = 4096  # also used as n_fft
        self.hop_length = 1024
        # A frozen Parameter (rather than a plain tensor) so the window follows
        # the module across .to(device) and is stored in the state dict as "win".
        self.win = nn.Parameter(
            torch.hann_window(self.win_length),
            requires_grad=False,
        )

        self.separator = Separator(num_instruments=num_instruments, checkpoint_path=checkpoint_path)

    def compute_stft(self, wav):
        """
        Computes stft feature from wav

        Args:
            wav (Tensor): B x L (a 1-D tensor of length L is also accepted)

        Returns:
            tuple(Tensor, Tensor): ``(stft, mag)``.
            - ``stft`` is the real view of the complex STFT, shape ``... x F x T x 2``
              (last axis = real, imaginary).
            - ``mag`` is its magnitude, shape ``... x F x T``.
            - Only the first ``self.F`` frequency bins are kept.
            - T is the number of STFT frames.
        """
        stft = torch.stft(
            wav,
            n_fft=self.win_length,
            hop_length=self.hop_length,
            window=self.win,
            center=True,
            return_complex=True,
            pad_mode='constant',
        )
        # only keep freqs smaller than self.F. Frequency is always the
        # second-to-last axis of the complex STFT, so index from the back to
        # support both batched (B x L) and unbatched (L,) input.
        stft = stft[..., :self.F, :]
        mag = stft.abs()
        return torch.view_as_real(stft), mag

    def inverse_stft(self, stft, length: Optional[int] = None):
        """Inverses stft to wave form

        Args:
            stft (Tensor): real view of a (possibly frequency-truncated) complex
                STFT, shape ``... x F' x T x 2``, e.g. as produced by
                ``compute_stft``.
            length (int, optional): exact number of output samples. Pass the
                length of the waveform given to ``compute_stft`` to get an output
                of the same size. With the default ``None``, ``torch.istft`` returns
                ``hop_length * (T - 1)`` samples. That drops the trailing samples
                whenever the original length is not a multiple of ``hop_length``.

        Returns:
            Tensor: reconstructed waveform(s), detached from the autograd graph.
        """
        # Zero-fill the frequency bins removed by compute_stft so the spectrum
        # has the full n_fft // 2 + 1 bins again. The pad tuple is ordered
        # (last axis, time axis, frequency axis); only the frequency axis
        # (third from last) is padded, at its high end.
        pad = self.win_length // 2 + 1 - stft.size(-3)
        stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
        stft = torch.view_as_complex(stft)
        wav = torch.istft(
            stft,
            self.win_length,
            hop_length=self.hop_length,
            center=True,
            window=self.win,
            length=length,
        )
        return wav.detach()

    def forward(self, wav):
        """Delegate to :meth:`separate` so the module can be called directly."""
        return self.separate(wav)

    def separate(self, wav):
        """
        Separates stereo wav into different tracks corresponding to different instruments

        Args:
            wav (tensor): 2 x L

        Returns:
            list[Tensor]: one waveform per mask produced by ``self.separator``.
            Each has a time axis of exactly L samples, matching ``wav``.
        """
        num_samples = wav.size(-1)

        # Compute the STFT from the mixed wav
        # stft:     real view of the complex STFT - 2 x F x T x 2 (last axis = real, imag)
        # stft_mag: magnitude spectrogram         - 2 x F x T
        # (T is the number of STFT frames, not the number of samples L)
        stft, stft_mag = self.compute_stft(wav)

        # Perform the actual stem separation
        masks = self.separator(stft_mag)

        # Recover the wavs via an inverse STFT. Passing the input length keeps
        # every stem as long as the input; without it istft would cut the
        # output down to a multiple of hop_length.
        wavs = []
        for mask in masks:
            # NOTE(review): assumes each mask broadcasts against ``stft``
            # (e.g. 2 x F x T x 1) — confirm against Separator.
            stft_masked = stft * mask
            wav_masked = self.inverse_stft(stft_masked, length=num_samples)
            wavs.append(wav_masked)
        return wavs