First commit
fakufaku committed Apr 7, 2020
0 parents commit 197994e
Showing 10 changed files with 727 additions and 0 deletions.
8 changes: 8 additions & 0 deletions bss_scale/__init__.py
@@ -0,0 +1,8 @@
# BSS scale disambiguation algorithms
# Copyright (C) 2020 Robin Scheibler
from .algorithms import projection_back, minimum_distortion

algorithms = {
    "projection_back": projection_back,
    "minimum_distortion": minimum_distortion,
}
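A minimal usage sketch of this registry, assuming the package is importable from the repository root (the random STFT arrays and the choice p=1.0 are illustrative placeholders):

import numpy as np
import bss_scale

# placeholder STFT data: 64 frames, 129 bins, 2 separated channels
Y = np.random.randn(64, 129, 2) + 1j * np.random.randn(64, 129, 2)
X_ref = np.random.randn(64, 129) + 1j * np.random.randn(64, 129)

# look up a scale-fixing routine by name
fix_scale = bss_scale.algorithms["minimum_distortion"]
Y_fixed, n_epochs = fix_scale(Y, X_ref, p=1.0)  # rescaled STFT and iteration count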
168 changes: 168 additions & 0 deletions bss_scale/algorithms.py
@@ -0,0 +1,168 @@
# BSS scale disambiguation algorithms
# Copyright (C) 2020 Robin Scheibler

import numpy as np

from .surrogate import lp_norm, lpq_norm


def projection_back(Y, X, ref_mic=0, **kwargs):
"""
Solves the scale ambiguity according to Murata et al., 2001.
This technique uses the steering vector value corresponding
to the demixing matrix obtained during separation.
Parameters
----------
Y: array_like (n_frames, n_bins, n_channels)
The STFT data to project back on the reference signal
ref: array_like (n_frames, n_bins)
The reference signal
Returns
-------
Y: array_like (n_frames, n_bins, n_channels)
The projected data
"""
n_frames, n_freq, n_chan = Y.shape

# find a bunch of non-zero frames
I_nz = np.argsort(np.linalg.norm(Y, axis=(1, 2)))[-n_chan:]

# Now we only need to solve a linear system of size n_chan x n_chan
# per frequency band
A = Y[I_nz, :, :].transpose([1, 0, 2])
b = X[I_nz, :].T
c = np.linalg.solve(A, b)

return c[None, :, :] * Y


def minimum_distortion_l2(Y, ref):
"""
This function computes the frequency-domain filter that minimizes
the squared error to a reference signal. This is commonly used
to solve the scale ambiguity in BSS.
Derivation of the projection
----------------------------
The optimal filter `z` minimizes the squared error.
.. math::
\min E[|z^* y - x|^2]
It should thus satsify the orthogonality condition
and can be derived as follows
.. math::
0 & = E[y^*\\, (z^* y - x)]
0 & = z^*\\, E[|y|^2] - E[y^* x]
z^* & = \\frac{E[y^* x]}{E[|y|^2]}
z & = \\frac{E[y x^*]}{E[|y|^2]}
In practice, the expectations are replaced by the sample
mean.
Parameters
----------
Y: array_like (n_frames, n_bins, n_channels)
The STFT data to project back on the reference signal
ref: array_like (n_frames, n_bins)
The reference signal
"""

num = np.sum(np.conj(ref[:, :, None]) * Y, axis=0)
denom = np.sum(np.abs(Y) ** 2, axis=0)
c = num / np.maximum(1e-15, denom)

return np.conj(c[None, :, :]) * Y


def minimum_distortion(
    Y, ref, p=None, q=None, rtol=1e-3, max_iter=100,
):
    """
    This function computes the frequency-domain filter that minimizes the sum
    of errors to a reference signal. This is a sparse version of the projection
    back that is commonly used to solve the scale ambiguity in BSS.

    Derivation of the projection
    ----------------------------
    The optimal filter `z` minimizes the expected absolute deviation.

    .. math::

        \min E[|z^* y - x|]

    The optimization is done via the MM algorithm (i.e. IRLS).

    Parameters
    ----------
    Y: array_like (n_frames, n_freq, n_channels)
        The STFT data to project back on the reference signal
    ref: array_like (n_frames, n_freq)
        The reference signal
    p: float (0 < p <= 2)
        The norm to use to measure distortion
    q: float (0 < p <= q <= 2)
        The other exponent when using a mixed norm to measure distortion
    max_iter: int, optional
        Maximum number of iterations
    rtol: float, optional
        Stop the optimization when the algorithm makes less than rtol relative progress

    Returns
    -------
    The projected signal
    The number of iterations
    """

    # by default we do the regular minimum distortion
    if p is None or (p == 2.0 and q is None):
        return minimum_distortion_l2(Y, ref), 1

    n_frames, n_freq, n_channels = Y.shape

    c = np.ones(Y.shape, dtype=Y.dtype)

    eps = 1e-15

    prev_res = None

    epoch = 0
    while epoch < max_iter:

        epoch += 1

        # the current error
        error = ref[:, :, None] - c * Y
        if q is None or p == q:
            res, weights = lp_norm(error, p=p)
        else:
            res, weights = lpq_norm(error, p=p, q=q, axis=1)

        # minimize
        num = np.sum(ref[:, :, None] * np.conj(Y) * weights, axis=0)
        denom = np.sum(np.abs(Y) ** 2 * weights, axis=0)
        c = num / np.maximum(eps, denom)

        # condition for termination
        if prev_res is None:
            prev_res = res
            continue

        # relative step length
        delta = (prev_res - res) / prev_res
        prev_res = res
        if delta < rtol:
            break

    return c[None, :, :] * Y, epoch
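As a concrete illustration of the MM/IRLS update above, the sketch below applies minimum_distortion to synthetic data in which each separated channel equals the reference up to an unknown complex gain per frequency bin, plus a small amount of noise; the shapes, gain model, and noise level are assumptions made for the illustration:

import numpy as np
from bss_scale.algorithms import minimum_distortion

rng = np.random.RandomState(0)
n_frames, n_freq, n_chan = 200, 129, 2

# reference channel, and separated channels off by a random complex gain per bin
ref = rng.randn(n_frames, n_freq) + 1j * rng.randn(n_frames, n_freq)
gains = (1.0 + rng.rand(n_freq, n_chan)) * np.exp(2j * np.pi * rng.rand(n_freq, n_chan))
noise = 0.01 * (rng.randn(n_frames, n_freq, n_chan) + 1j * rng.randn(n_frames, n_freq, n_chan))
Y = ref[:, :, None] * gains[None, :, :] + noise

# mixed norm: inner q-norm across frequency, outer p-exponent over frames/channels
Y_fixed, n_epochs = minimum_distortion(Y, ref, p=1.0, q=2.0, rtol=1e-5)

# the rescaled channels should match the reference up to roughly the noise level
print(n_epochs, np.max(np.abs(Y_fixed - ref[:, :, None])))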
22 changes: 22 additions & 0 deletions bss_scale/surrogate.py
@@ -0,0 +1,22 @@
# Some norm functions with the weights for MM algorithms
# Copyright (C) 2020 Robin Scheibler
import numpy as np

eps = 1e-15


def lp_norm(E, p=1):
    assert p > 0 and p < 2
    cost = np.sum(np.abs(E) ** p)
    weights = p / np.maximum(eps, 2.0 * np.abs(E) ** (2 - p))
    return cost, weights


def lpq_norm(E, p=1, q=2, axis=1):
    assert p > 0 and q >= p and q <= 2.0

    cost = np.sum(np.sum(np.abs(E) ** q, axis=axis, keepdims=True) ** (p / q))
    rn = np.sum(np.abs(E) ** q, axis=axis, keepdims=True) ** (1 - p / q)
    qfn = np.abs(E) ** (2 - q)
    weights = p / np.maximum(eps, 2.0 * rn * qfn)
    return cost, weights
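These weights are what drive the IRLS updates in algorithms.py: for 0 < p < 2, lp_norm returns the coefficients of the usual quadratic majorizer of the lp cost, tight at the expansion point. A small numerical check of that property (the test arrays, the constant term of the surrogate, and the choice p = 1 are assumptions of this sketch):

import numpy as np
from bss_scale.surrogate import lp_norm

rng = np.random.RandomState(0)
E0 = rng.randn(6, 4) + 1j * rng.randn(6, 4)  # expansion point
E1 = rng.randn(6, 4) + 1j * rng.randn(6, 4)  # some other point
p = 1.0

cost0, w = lp_norm(E0, p=p)

def quad_surrogate(E):
    # quadratic function built from the weights computed at E0
    return np.sum(w * np.abs(E) ** 2) + (1.0 - p / 2.0) * cost0

print(np.isclose(quad_surrogate(E0), cost0))      # tight at the expansion point
print(quad_surrogate(E1) >= lp_norm(E1, p=p)[0])  # upper-bounds the cost elsewhere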
109 changes: 109 additions & 0 deletions example.py
@@ -0,0 +1,109 @@
import argparse
import json
import numpy as np
from scipy.io import wavfile
import pyroomacoustics as pra
from pyroomacoustics.transform import stft

from metrics import si_bss_eval
import bss_scale

algorithms = {
    "auxiva": pra.bss.auxiva,
    "ilrma": pra.bss.ilrma,
    "sparseauxiva": pra.bss.sparseauxiva,
    "fastmnmf": pra.bss.fastmnmf,
}

DATA_META = "data/metadata.json"
REF_MIC = 0
RTOL = 1e-5

if __name__ == "__main__":

np.random.seed(0)

with open(DATA_META, "r") as f:
metadata = json.load(f)

mics_choices = [int(key[0]) for key in metadata]
algo_choices = list(algorithms.keys())

parser = argparse.ArgumentParser(description="Separation example")
parser.add_argument(
"-a",
"--algo",
type=str,
choices=algo_choices,
default=algo_choices[0],
help="BSS algorithm",
)
parser.add_argument(
"-m",
"--mics",
type=int,
choices=mics_choices,
default=mics_choices[0],
help="Number of channels",
)
parser.add_argument(
"-p", type=float, help="Outer norm",
)
parser.add_argument(
"-q", type=float, help="Inner norm",
)
parser.add_argument("-r", "--room", default=0, type=int, help="Room number")
parser.add_argument("-b", "--block", default=4096, type=int, help="STFT frame size")
args = parser.parse_args()

rooms = metadata[f"{args.mics}_channels"]

    assert args.room >= 0 and args.room < len(
        rooms
    ), f"Room must be between 0 and {len(rooms) - 1}"

    # choose and read the audio files

    # the mixtures
    fn_mix = rooms[args.room]["mix_filename"]
    fs, mix = wavfile.read(fn_mix)
    mix = mix.astype(np.float64) / 2 ** 15

    # the reference
    fn_ref = rooms[args.room]["src_filenames"][REF_MIC]
    fs, ref = wavfile.read(fn_ref)
    ref = ref.astype(np.float64) / 2 ** 15

    # STFT parameters
    hop = args.block // 4
    win_a = pra.hamming(args.block)
    win_s = pra.transform.stft.compute_synthesis_window(win_a, hop)

    # STFT
    X = stft.analysis(mix, args.block, hop, win=win_a)

    # Separation
    if args.algo != "fastmnmf":
        Y = algorithms[args.algo](X, n_iter=30, proj_back=False)
    else:
        Y = algorithms[args.algo](X, n_iter=30)

    # Projection back

    if args.p is None:
        Y = bss_scale.projection_back(Y, X[:, :, REF_MIC])
    else:
        Y, _ = bss_scale.minimum_distortion(Y, X[:, :, REF_MIC], p=args.p, q=args.q)

    # iSTFT
    y = stft.synthesis(Y, args.block, hop, win=win_s)
    y = y[args.block - hop :]
    if y.ndim == 1:
        y = y[:, None]

    # Evaluate
    m = np.minimum(ref.shape[0], y.shape[0])
    sdr, sir, sar, perm = si_bss_eval(ref[:m, :], y[:m, :])

    # Reorder the signals
    print(sdr)
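For reference, a possible invocation of this script, assuming data/metadata.json lists a 2-channel condition and that the metrics module providing si_bss_eval is present (neither file is shown above, so the exact values are assumptions):

python example.py -a auxiva -m 2 -r 0 -b 4096 -p 1.0 -q 2.0

Omitting -p selects projection_back for the scale fixing; passing -p (and optionally -q) switches to minimum_distortion, as in the code above.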
66 changes: 66 additions & 0 deletions experiment1.py
@@ -0,0 +1,66 @@
import argparse
import json
from multiprocessing import Pool
from pathlib import Path

import numpy

import pyroomacoustics as pra
from process import bss_algorithms, process


def gen_args(config_fn):

    with open(config_fn, "r") as f:
        config = json.load(f)

    with open(config["metadata_fn"], "r") as f:
        metadata = json.load(f)

    args = []

    for label, room_list in metadata.items():

        n_rooms = len(room_list)
        n_channels = int(label[0])

        for room_id in range(n_rooms):

            for bss_algo in bss_algorithms.keys():

                args.append([n_channels, room_id, bss_algo, str(config_fn)])

    return args


if __name__ == "__main__":

parser = argparse.ArgumentParser(description="Run experiment in parallel")
parser.add_argument("config_file", type=str, help="Path to the configuration file")
parser.add_argument(
"-t",
"--test",
action="store_true",
help="Fix number of iterations to two for test purposes",
)
parser.add_argument(
"-s", "--seq", action="store_true", help="Run the experiment sequentially",
)
args = parser.parse_args()

sim_args = gen_args(args.config_file)

if args.test:
sim_args = sim_args[:2]

all_results = []

if args.seq:
for this_args in sim_args:
all_results += process(this_args)

else:
with Pool() as p:
results = p.map(process, sim_args)
for r in results:
all_results += r
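Assuming the process module imported at the top (not shown above) provides bss_algorithms and process, a run could look like

python experiment1.py experiment1_config.json

with -s to run the conditions sequentially instead of in a multiprocessing Pool, and -t to keep only the first two argument sets for a quick test.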
19 changes: 19 additions & 0 deletions experiment1_config.json
@@ -0,0 +1,19 @@
{
    "metadata_fn": "./data/metadata.json",

    "stft": {
        "nfft": 4096,
        "hop": 1024,
        "window": "hamming"
    },

    "ref_mic": 0,

    "minimum_distortion": {
        "p_list": [0.1, 2.0, 0.1],
        "kwargs": {
            "rtol": 1e-5,
            "max_iter": 100
        }
    }
}
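This configuration is read by gen_args in experiment1.py through the metadata_fn entry; the remaining fields are presumably consumed by the processing code, which is not shown above. A minimal sketch of loading it directly (the relative path assumes the file sits next to the scripts):

import json

with open("experiment1_config.json", "r") as f:
    config = json.load(f)

print(config["metadata_fn"])                           # ./data/metadata.json
print(config["stft"]["nfft"], config["stft"]["hop"])   # 4096 1024
print(config["minimum_distortion"]["kwargs"]["rtol"])  # 1e-05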