-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathstubI_LSTMMessl3.m
73 lines (61 loc) · 2.48 KB
/
stubI_LSTMMessl3.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
function [Y data mask Xp] = stubI_LSTMMessl3(X, fail, fs, inFile, I, refMic, d, useHardMask, beamformer, lstm_dir, varargin)
% Multichannel MESSL mask, initialized from a precomputed LSTM mask, followed
% by beamforming.
%
% Inputs
%   X           3-D array indexed (row, column, channel); rows 2:end-1 of the
%               non-failed channels are passed to MESSL. Presumably an STFT
%               (freq x time x mic) -- NOTE(review): confirm against callers.
%   fail        per-channel flags; channels where fail is true are excluded.
%   fs, d       passed to tauGrid to build the delay grid (d defaults to 0.35).
%   inFile      audio file name; an optional '.CH1' plus the '.wav' suffix is
%               replaced by '.mat' to locate the LSTM mask file in lstm_dir.
%   I           passed to messlMultichannel (default 1).
%   refMic      reference mic; 0 means all pairs (default 0). A failed
%               reference mic is replaced by the first non-failed one.
%   useHardMask true (default): use MESSL's hard masks; false: prob2mask of
%               MESSL's posteriors. When true, varargin must set
%               'mrfHardCompatExp' to a nonzero value.
%   beamformer  'bestMic' (default), 'mvdr', or 'souden'.
%   lstm_dir    directory holding the LSTM mask .mat files.
%   varargin    extra MESSL options, appended after this function's defaults.
%
% Outputs
%   Y     Xp .* mask
%   data  struct with fields mask (single), params (MESSL parameters), and,
%         for the 'mvdr' and 'souden' beamformers, mvdrMask (single)
%   mask  mask applied to Xp, mapped into [maxSup, 1-maxSup] with
%         maxSup = -40 dB; for 'mvdr'/'souden' it is replaced by the
%         beamformer's third output
%   Xp    beamformer output
if ~exist('I', 'var') || isempty(I), I = 1; end
if ~exist('refMic', 'var') || isempty(refMic), refMic = 0; end
if ~exist('d', 'var') || isempty(d), d = 0.35; end
if ~exist('useHardMask', 'var') || isempty(useHardMask), useHardMask = true; end
if ~exist('beamformer', 'var') || isempty(beamformer), beamformer = 'bestMic'; end

% Hard-mask mode needs a nonzero 'mrfHardCompatExp' MESSL option. Every
% occurrence of the name is checked, and each must be followed by a value, so
% that a trailing or repeated name produces this error instead of an
% indexing / comma-separated-list error.
if useHardMask
    nameIdx = find(strcmp(varargin, 'mrfHardCompatExp'));
    isNonzeroScalar = @(v) (isnumeric(v) || islogical(v)) && isscalar(v) && v ~= 0;
    hasValidExp = ~isempty(nameIdx) && all(nameIdx < numel(varargin)) ...
        && all(cellfun(isNonzeroScalar, varargin(nameIdx + 1)));
    if ~hasValidExp
        error('Must set "mrfHardCompatExp" to nonzero value with useHardMask')
    end
end

% The LSTM mask for this recording is stored in lstm_dir under the audio file
% name, with an optional '.CH1' and the '.wav' suffix replaced by '.mat'.
lstm_file = fullfile(lstm_dir, regexprep(inFile, '(\.CH1)?\.wav$', '.mat'));
lstm = load(lstm_file);
% .' is only defined for 2-D arrays, so the stored mask is a single 2-D mask.
% After transposing, drop the first and last rows to match X(2:end-1,:,:).
lstmMask = lstm.mask.';
lstmMask = lstmMask(2:end-1, :);
% Two-source initialization: the LSTM mask and its complement. Assigning into
% the repmat'ed array keeps the class of the loaded mask.
maskInit = repmat(lstmMask, [1 1 2]);
maskInit(:,:,2) = 1 - lstmMask;

% Floor/ceiling applied to the final mask: -40 dB as a linear amplitude.
maxSup_db = -40;
maxSup = 10^(maxSup_db/20);
tau = tauGrid(d, fs, 31);
fprintf('Max ITD: %g samples\n', tau(end));

% See if reference mic has failed. refMic = 0 means all pairs.
if (refMic > 0) && fail(refMic)
    refMic = find(~fail, 1, 'first');
    if isempty(refMic)
        error('All potential reference mics have failed')
    end
end

% MESSL for mask; caller options in varargin come after the defaults.
messlOpts = [{'GarbageSrc', 1, 'fixIPriors', 1, 'maskInit', maskInit, 'refMic', refMic, 'maskHold', Inf} varargin];
[p_lr_iwt params hardMasks] = messlMultichannel(X(2:end-1,:,~fail), tau, I, messlOpts{:});
if useHardMask
    mask = squeeze(hardMasks(1,:,:,:));
else
    mask = prob2mask(squeeze(p_lr_iwt(1,:,:,:)));
end

% Restore the first and last rows (excluded from MESSL) as zeros, then map the
% mask from [0, 1] into [maxSup, 1 - maxSup].
z = zeros([1 size(X,2) size(mask,3)]);
mask = cat(1, z, mask, z);
mask = maxSup + (1 - 2*maxSup) * mask;

switch beamformer
    case 'bestMic'
        Xp = pickChanWithBestSnr(X, mask, fail);
    case 'mvdr'
        [Xp mvdrMask mask] = maskDrivenMvdrMulti(X, mask, fail, params.perMicTdoa);
        data.mvdrMask = single(mvdrMask);
    case 'souden'
        [Xp mvdrMask mask] = mvdrSoudenMulti(X, mask, fail);
        data.mvdrMask = single(mvdrMask);
    otherwise
        error('Unknown beamformer: %s', beamformer)
end
data.mask = single(mask);
data.params = params;

% Output spectrogram(s)
Y = Xp .* mask;