-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 197994e
Showing
10 changed files
with
727 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# BSS scale disambiguation algorithms | ||
# Copyright (C) 2020 Robin Scheibler | ||
from .algorithms import projection_back, minimum_distortion | ||
|
||
# Registry of the scale-restoration routines exported by this package,
# keyed by the name of each routine.
algorithms = dict(
    projection_back=projection_back,
    minimum_distortion=minimum_distortion,
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
# BSS scale disambiguation algorithms | ||
# Copyright (C) 2020 Robin Scheibler | ||
|
||
import numpy as np | ||
|
||
from .surrogate import lp_norm, lpq_norm | ||
|
||
|
||
def projection_back(Y, X, ref_mic=0, **kwargs):
    """
    Solves the scale ambiguity according to Murata et al., 2001.

    For every frequency bin, one scaling factor per channel is obtained by
    solving the square linear system ``Y[I, f, :] @ c[f] = X[I, f]``, where
    ``I`` are the ``n_channels`` frames of ``Y`` with the largest norm.

    Parameters
    ----------
    Y: array_like (n_frames, n_bins, n_channels)
        The STFT data to project back on the reference signal
    X: array_like (n_frames, n_bins)
        The reference signal
    ref_mic: int, optional
        Not used by this function (``X`` is already a single channel)
    **kwargs:
        Ignored

    Returns
    -------
    Y: array_like (n_frames, n_bins, n_channels)
        The projected data
    """
    _, _, n_chan = Y.shape

    # pick the n_chan frames with the largest norm so that the systems
    # below are built from non-zero data
    I_nz = np.argsort(np.linalg.norm(Y, axis=(1, 2)))[-n_chan:]

    # Now we only need to solve a linear system of size n_chan x n_chan
    # per frequency band
    A = Y[I_nz, :, :].transpose([1, 0, 2])  # (n_bins, n_chan, n_chan)
    b = X[I_nz, :].T  # (n_bins, n_chan)

    # The right-hand side is passed as a stack of explicit column vectors.
    # NumPy >= 2.0 only interprets `b` as a vector when it is 1-D, so the
    # bare 2-D array would be taken for a single (M, K) matrix.
    c = np.linalg.solve(A, b[:, :, None])[:, :, 0]

    return c[None, :, :] * Y
|
||
|
||
def minimum_distortion_l2(Y, ref):
    r"""
    Rescale every channel of `Y` with the frequency-domain filter that
    minimizes the squared error to a reference signal. This is commonly
    used to solve the scale ambiguity in BSS.

    Derivation of the projection
    ----------------------------

    The optimal filter `z` minimizes the squared error.

    .. math::

        \min E[|z^* y - x|^2]

    It should thus satisfy the orthogonality condition
    and can be derived as follows

    .. math::

        0 & = E[y^*\, (z^* y - x)]

        0 & = z^*\, E[|y|^2] - E[y^* x]

        z^* & = \frac{E[y^* x]}{E[|y|^2]}

        z & = \frac{E[y x^*]}{E[|y|^2]}

    In practice, the expectations are replaced by sums over the frames.

    Parameters
    ----------
    Y: array_like (n_frames, n_bins, n_channels)
        The STFT data to project back on the reference signal
    ref: array_like (n_frames, n_bins)
        The reference signal

    Returns
    -------
    Y: array_like (n_frames, n_bins, n_channels)
        The rescaled data
    """
    # cross-term between the data and the conjugated reference, per bin/channel
    cross = (np.conj(ref[:, :, None]) * Y).sum(axis=0)

    # energy of each channel in each bin, floored to keep the division safe
    energy = (np.abs(Y) ** 2).sum(axis=0)
    z = cross / np.maximum(1e-15, energy)

    return np.conj(z[None, :, :]) * Y
|
||
|
||
def minimum_distortion(
    Y, ref, p=None, q=None, rtol=1e-3, max_iter=100,
):
    r"""
    This function computes the frequency-domain filter that minimizes the sum
    of errors to a reference signal. This is a sparse version of the projection
    back that is commonly used to solve the scale ambiguity in BSS.

    Derivation of the projection
    ----------------------------

    The optimal filter `z` minimizes the expected absolute deviation.

    .. math::

        \min E[|z^* y - x|]

    The optimization is done via the MM algorithm (i.e. IRLS).

    Parameters
    ----------
    Y: array_like (n_frames, n_freq, n_channels)
        The STFT data to project back on the reference signal
    ref: array_like (n_frames, n_freq)
        The reference signal
    p: float (0 < p <= 2)
        The norm to use to measure distortion. When omitted, the
        closed-form squared-error solution is used.
    q: float (0 < p <= q <= 2)
        The other exponent when using a mixed norm to measure distortion
    max_iter: int, optional
        Maximum number of iterations
    rtol: float, optional
        Stop the optimization when the algorithm makes less than rtol relative progress

    Returns
    -------
    The projected signal
    The number of iterations
    """

    # p = 2 (alone or together with q = 2) is the squared error, which has a
    # closed-form minimizer; the surrogate weights are not needed there
    if p is None or (p == 2.0 and (q is None or q == 2.0)):
        return minimum_distortion_l2(Y, ref), 1

    _, n_freq, n_channels = Y.shape

    # one scale per (frequency, channel), starting from the identity. This is
    # the shape the updates below produce, so the final broadcast is valid
    # even when no iteration is run.
    c = np.ones((n_freq, n_channels), dtype=Y.dtype)

    eps = 1e-15

    prev_res = None

    epoch = 0
    while epoch < max_iter:

        epoch += 1

        # the current error
        error = ref[:, :, None] - c[None, :, :] * Y
        if q is None or p == q:
            res, weights = lp_norm(error, p=p)
        else:
            res, weights = lpq_norm(error, p=p, q=q, axis=1)

        # minimize the weighted squared error (closed form per bin/channel)
        num = np.sum(ref[:, :, None] * np.conj(Y) * weights, axis=0)
        denom = np.sum(np.abs(Y) ** 2 * weights, axis=0)
        c = num / np.maximum(eps, denom)

        # condition for termination: needs the cost of two iterations
        if prev_res is None:
            prev_res = res
            continue

        # a zero cost leaves no relative progress to measure
        # (and would divide by zero below)
        if prev_res == 0:
            break

        # relative step length
        delta = (prev_res - res) / prev_res
        prev_res = res
        if delta < rtol:
            break

    return c[None, :, :] * Y, epoch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Some norm functions with the weights for MM algorithms | ||
# Copyright (C) 2020 Robin Scheibler | ||
import numpy as np | ||
|
||
# Floor applied to the denominator of the weights to avoid dividing by zero
eps = 1e-15


def lp_norm(E, p=1):
    """
    Evaluate the sum of ``|E| ** p`` together with the element-wise weights
    ``p / (2 |E| ** (2 - p))`` (denominator floored at ``eps``).

    Parameters
    ----------
    E: ndarray
        The values to measure
    p: float (0 < p <= 2)
        The exponent; ``p = 2`` gives the sum of squares with unit weights

    Returns
    -------
    cost: float
        The sum of ``|E| ** p`` over all the elements
    weights: ndarray, same shape as E
        The weights
    """
    # p = 2 is accepted, in line with lpq_norm which allows p = q = 2
    assert 0 < p <= 2, "p must satisfy 0 < p <= 2"
    magnitude = np.abs(E)
    cost = np.sum(magnitude ** p)
    weights = p / np.maximum(eps, 2.0 * magnitude ** (2 - p))
    return cost, weights
|
||
|
||
def lpq_norm(E, p=1, q=2, axis=1):
    """
    Evaluate a mixed-norm cost and the matching element-wise weights.

    The magnitudes of `E` raised to the power `q` are summed along `axis`,
    each sum is raised to the power ``p / q``, and the results are added up
    to form the cost.

    Parameters
    ----------
    E: ndarray
        The values to measure
    p: float (0 < p <= q)
        The exponent applied across the groups
    q: float (p <= q <= 2)
        The exponent applied within a group
    axis: int
        The axis along which the elements are grouped

    Returns
    -------
    cost: float
        The value of the mixed norm
    weights: ndarray, same shape as E
        The weights, with their denominator floored at the module-level `eps`
    """
    assert 0 < p <= q <= 2.0

    magnitude = np.abs(E)

    # sum of the q-th powers of each group, kept broadcastable against E
    group_power = np.sum(magnitude ** q, axis=axis, keepdims=True)

    cost = np.sum(group_power ** (p / q))

    denominator = 2.0 * group_power ** (1 - p / q) * magnitude ** (2 - q)
    weights = p / np.maximum(eps, denominator)

    return cost, weights
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import argparse | ||
import json | ||
import numpy as np | ||
from scipy.io import wavfile | ||
import pyroomacoustics as pra | ||
from pyroomacoustics.transform import stft | ||
|
||
from metrics import si_bss_eval | ||
import bss_scale | ||
|
||
# Separation algorithms selectable with --algo, keyed by the option value.
# The first entry is the command-line default.
algorithms = dict(
    auxiva=pra.bss.auxiva,
    ilrma=pra.bss.ilrma,
    sparseauxiva=pra.bss.sparseauxiva,
    fastmnmf=pra.bss.fastmnmf,
)

# JSON file describing the available rooms
DATA_META = "data/metadata.json"
# Index of the channel used as the reference
REF_MIC = 0
# NOTE(review): not referenced anywhere in the visible script
RTOL = 1e-5
|
||
if __name__ == "__main__":
    # Separate one mixture with the chosen BSS algorithm, restore the scale
    # of the separated signals and print the resulting SDR.

    np.random.seed(0)

    with open(DATA_META, "r") as f:
        metadata = json.load(f)

    # The keys have the form "<n>_channels" (see the lookup below); parse the
    # whole leading number so that counts of 10 and above are handled too.
    mics_choices = [int(key.split("_", 1)[0]) for key in metadata]
    algo_choices = list(algorithms.keys())

    parser = argparse.ArgumentParser(description="Separation example")
    parser.add_argument(
        "-a",
        "--algo",
        type=str,
        choices=algo_choices,
        default=algo_choices[0],
        help="BSS algorithm",
    )
    parser.add_argument(
        "-m",
        "--mics",
        type=int,
        choices=mics_choices,
        default=mics_choices[0],
        help="Number of channels",
    )
    parser.add_argument(
        "-p", type=float, help="Outer norm",
    )
    parser.add_argument(
        "-q", type=float, help="Inner norm",
    )
    parser.add_argument("-r", "--room", default=0, type=int, help="Room number")
    parser.add_argument("-b", "--block", default=4096, type=int, help="STFT frame size")
    args = parser.parse_args()

    rooms = metadata[f"{args.mics}_channels"]

    # Both bounds must hold; the check is explicit (rather than an assert) so
    # that it is also active when Python runs with -O.
    if not 0 <= args.room < len(rooms):
        parser.error(f"Room must be between 0 and {len(rooms) - 1}")

    # choose and read the audio files

    # the mixtures
    fn_mix = rooms[args.room]["mix_filename"]
    fs, mix = wavfile.read(fn_mix)
    mix = mix.astype(np.float64) / 2 ** 15

    # the reference
    fn_ref = rooms[args.room]["src_filenames"][REF_MIC]
    fs, ref = wavfile.read(fn_ref)
    ref = ref.astype(np.float64) / 2 ** 15

    # STFT parameters
    hop = args.block // 4
    win_a = pra.hamming(args.block)
    win_s = pra.transform.stft.compute_synthesis_window(win_a, hop)

    # STFT
    X = stft.analysis(mix, args.block, hop, win=win_a)

    # Separation
    if args.algo != "fastmnmf":
        Y = algorithms[args.algo](X, n_iter=30, proj_back=False)
    else:
        Y = algorithms[args.algo](X, n_iter=30)

    # Projection back
    if args.p is None:
        Y = bss_scale.projection_back(Y, X[:, :, REF_MIC])
    else:
        # minimum_distortion returns (signal, number of iterations);
        # only the signal goes on to the synthesis.
        # NOTE(review): RTOL is defined at module level but not passed here,
        # so the function's default rtol applies — confirm this is intended.
        Y, n_iter = bss_scale.minimum_distortion(
            Y, X[:, :, REF_MIC], p=args.p, q=args.q
        )

    # iSTFT
    y = stft.synthesis(Y, args.block, hop, win=win_s)
    y = y[args.block - hop :]
    if y.ndim == 1:
        y = y[:, None]

    # Evaluate over the common length of the reference and the output
    m = min(ref.shape[0], y.shape[0])
    sdr, sir, sar, perm = si_bss_eval(ref[:m, :], y[:m, :])

    # Report the SDR
    print(sdr)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import argparse | ||
import json | ||
from multiprocessing import Pool | ||
from pathlib import Path | ||
|
||
import numpy | ||
|
||
import pyroomacoustics as pra | ||
from process import bss_algorithms, process | ||
|
||
|
||
def _channel_count(label):
    """Return the integer spelled by the leading decimal digits of *label*."""
    n_digits = len(label) - len(label.lstrip("0123456789"))
    return int(label[:n_digits])


def gen_args(config_fn):
    """
    Build the list of argument lists, one per simulation to run.

    The configuration file is a JSON document whose ``"metadata_fn"`` entry
    is the path of a second JSON file mapping a label to a list of rooms.
    One entry is produced for every (label, room index, BSS algorithm)
    combination.

    Parameters
    ----------
    config_fn: str or path-like
        Path to the JSON configuration file

    Returns
    -------
    list of ``[n_channels, room_id, bss_algo, config_fn]`` lists, where
    ``n_channels`` is the number at the start of the label and ``config_fn``
    is converted to ``str``

    Raises
    ------
    ValueError
        If a label does not start with a decimal digit
    """
    with open(config_fn, "r") as f:
        config = json.load(f)

    with open(config["metadata_fn"], "r") as f:
        metadata = json.load(f)

    config_path = str(config_fn)
    args = []

    for label, room_list in metadata.items():

        # read the whole leading number, not only its first digit, so that
        # e.g. "10_channels" gives 10 rather than 1
        n_channels = _channel_count(label)

        for room_id in range(len(room_list)):
            for bss_algo in bss_algorithms:
                args.append([n_channels, room_id, bss_algo, config_path])

    return args
|
||
|
||
if __name__ == "__main__":
    # Run `process` on every argument list produced by `gen_args`, either
    # one after the other or through a pool of worker processes, and
    # gather everything that is returned in `all_results`.

    cli = argparse.ArgumentParser(description="Run experiment in parallel")
    cli.add_argument("config_file", type=str, help="Path to the configuration file")
    cli.add_argument(
        "-t",
        "--test",
        action="store_true",
        help="Fix number of iterations to two for test purposes",
    )
    cli.add_argument(
        "-s", "--seq", action="store_true", help="Run the experiment sequentially",
    )
    args = cli.parse_args()

    sim_args = gen_args(args.config_file)
    if args.test:
        # test mode: only the first two argument lists are kept
        sim_args = sim_args[:2]

    all_results = []

    if not args.seq:
        with Pool() as workers:
            for partial in workers.map(process, sim_args):
                all_results.extend(partial)
    else:
        for job in sim_args:
            all_results.extend(process(job))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
{ | ||
"metadata_fn": "./data/metadata.json", | ||
|
||
"stft": { | ||
"nfft": 4096, | ||
"hop": 1024, | ||
"window": "hamming" | ||
}, | ||
|
||
"ref_mic": 0, | ||
|
||
"minimum_distortion": { | ||
"p_list": [0.1, 2.0, 0.1], | ||
"kwargs": { | ||
"rtol": 1e-5, | ||
"max_iter": 100 | ||
} | ||
} | ||
} |
Oops, something went wrong.