-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpreprocessor.py
104 lines (91 loc) · 3.63 KB
/
preprocessor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch.nn as nn
import librosa
import torch
import math
class Preprocessor(nn.Module):
    """Featurizer that converts batched waveforms to log-Mel spectrograms.

    See AudioToMelSpectrogramPreprocessor for the meaning of the arguments.

    Args:
        sample_rate: Sampling rate used to convert ``window_size``,
            ``window_stride`` and ``max_duration`` into sample counts.
        window_size: STFT window length; multiplied by ``sample_rate`` to get
            ``win_length`` in samples.
        window_stride: STFT hop; multiplied by ``sample_rate`` to get
            ``hop_length`` in samples.
        normalize: Stored on the instance; not used by ``forward``.
        n_fft: FFT size. A falsy value selects the next power of two that is
            greater than or equal to ``win_length``.
        preemph: Pre-emphasis coefficient; ``None`` disables pre-emphasis.
        nfilt: Number of Mel filters (rows of the filterbank).
        lowfreq: Lower edge of the Mel filterbank.
        log_zero_guard_value: Constant added before the logarithm.
        dither: Stored on the instance; not used by ``forward``.
        pad_to: Only used to round ``max_length`` up; ``forward`` does not pad
            the time axis.
        max_duration: Longest expected input, used to derive ``max_length``.
        pad_value: Fill value for frames past each sequence's valid length.
        mag_power: Exponent applied to the STFT magnitude (skipped when 1.0).
        mel_norm: ``norm`` argument forwarded to ``librosa.filters.mel``.
    """

    def __init__(
        self,
        sample_rate=16000,
        window_size=0.025,
        window_stride=0.01,
        normalize="NA",
        n_fft=512,
        preemph=0.97,
        nfilt=80,
        lowfreq=0,
        log_zero_guard_value=2 ** -24,
        dither=0,
        pad_to=0,
        max_duration=16.7,
        pad_value=0,
        mag_power=2.0,
        mel_norm="slaney",
    ):
        super().__init__()
        self.log_zero_guard_value = log_zero_guard_value
        self.win_length = int(window_size * sample_rate)
        self.hop_length = int(window_stride * sample_rate)
        self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))

        # Registered as a buffer so it follows the module across devices.
        self.register_buffer(
            "window", torch.hann_window(self.win_length, periodic=False)
        )

        self.normalize = normalize
        self.dither = dither
        self.nfilt = nfilt
        self.preemph = preemph
        self.pad_to = pad_to

        highfreq = sample_rate / 2
        filterbanks = torch.tensor(
            librosa.filters.mel(
                sr=sample_rate,
                n_fft=self.n_fft,
                n_mels=nfilt,
                fmin=lowfreq,
                fmax=highfreq,
                norm=mel_norm,
            ),
            dtype=torch.float,
        ).unsqueeze(0)  # leading axis of size 1 broadcasts over the batch
        self.register_buffer("fb", filterbanks)

        # Calculate maximum sequence length (in frames), rounded up with pad_to.
        max_length = self.get_seq_len(
            torch.tensor(max_duration * sample_rate, dtype=torch.float)
        )
        max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0
        self.max_length = max_length + max_pad
        self.pad_value = pad_value
        self.mag_power = mag_power

    def stft(self, x):
        """Return the complex STFT of ``x`` using this module's window settings.

        This is a regular method rather than a per-instance lambda so that the
        module can be pickled and so that copies use their own state.
        """
        return torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            center=True,
            window=self.window.to(dtype=torch.float),
            return_complex=True,
        )

    def get_seq_len(self, seq_len):
        """Convert lengths in samples to lengths in STFT frames.

        Args:
            seq_len: Tensor of sample counts.

        Returns:
            ``torch.long`` tensor with the corresponding number of frames.
        """
        # The STFT is centered, so n_fft // 2 samples are added on each side.
        pad_amount = self.n_fft // 2 * 2
        frames = torch.floor_divide(seq_len + pad_amount - self.n_fft, self.hop_length) + 1
        return frames.to(dtype=torch.long)

    @torch.no_grad()
    def forward(self, x, seq_len):
        """Compute log-Mel features for a batch of waveforms.

        Args:
            x: Waveforms indexed as ``(batch, time)``.
            seq_len: Tensor of shape ``(batch,)`` with valid sample counts.

        Returns:
            Tuple ``(features, frame_lengths)``: ``features`` has shape
            ``(batch, n_mels, frames)`` with frames beyond each valid length
            set to ``pad_value``; ``frame_lengths`` holds the valid frame
            count of every item.
        """
        seq_len = self.get_seq_len(seq_len)

        # do preemphasis (the first sample is passed through unchanged)
        if self.preemph is not None:
            x = torch.cat(
                (x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1
            )

        # disable autocast to get full range of stft values
        with torch.cuda.amp.autocast(enabled=False):
            x = self.stft(x)

        # torch stft returns a complex tensor of shape [B, N, T]; convert it
        # to magnitude
        x = torch.view_as_real(x)
        x = torch.sqrt(x.pow(2).sum(-1))

        # get power spectrum
        if self.mag_power != 1.0:
            x = x.pow(self.mag_power)

        # dot with filterbank energies
        x = torch.matmul(self.fb.to(x.dtype), x)

        # log features
        x = torch.log(x + self.log_zero_guard_value)

        # Overwrite frames at or beyond each item's valid length with
        # pad_value. No padding of the time axis is performed here.
        frame_idx = torch.arange(x.size(-1), device=x.device)
        invalid = frame_idx.unsqueeze(0) >= seq_len.unsqueeze(1)
        x = x.masked_fill(invalid.unsqueeze(1), self.pad_value)
        return x, seq_len