Skip to content

Commit

Permalink
makeing irlma-t faster
Browse files Browse the repository at this point in the history
  • Loading branch information
fakufaku committed Apr 30, 2020
1 parent f305ce0 commit ced8d81
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 49 deletions.
55 changes: 16 additions & 39 deletions dereverb_separation.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,6 @@ def ilrma_t_iteration(
b_temp = ilrma_t_b_iteration(y, b, v, eps)
# dを更新する

"""
d_temp = np.einsum("bst,fsb->fts", v, b_temp)
costTempByFreq = log_likelihood_ilmra_t_by_frequency(
y, P[:, 0, :, :], d_temp, eps
)
"""

if use_increase_constraint == True:
for freq in range(fftMax):
if costOrgByFreq[freq] > costTempByFreq[freq]:
Expand All @@ -196,11 +188,6 @@ def ilrma_t_iteration(
# 時間周波数分散
d = np.einsum("bst,fsb->fts", v, b)

"""
costBByFreq = log_likelihood_ilmra_t_by_frequency(y, P[:, 0, :, :], d, eps)
costB = np.average(costBByFreq)
"""

if fixV == False:
# y: fftMax,frameNum,micNum
# b: fftMax,sourceNum,basis
Expand All @@ -210,12 +197,6 @@ def ilrma_t_iteration(

d_temp = np.einsum("bst,fsb->fts", v_temp, b)

"""
costTempByFreq = log_likelihood_ilmra_t_by_frequency(
y, P[:, 0, :, :], d_temp, eps
)
"""

if use_increase_constraint == True:
for freq in range(fftMax):
if costBByFreq[freq] > costTempByFreq[freq]:
Expand All @@ -225,11 +206,6 @@ def ilrma_t_iteration(

d = np.einsum("bst,fsb->fts", v, b)

"""
costVByFreq = log_likelihood_ilmra_t_by_frequency(y, P[:, 0, :, :], d, eps)
costV = np.average(costVByFreq)
"""

# フィルタを求める。
IP1 = True
IP2 = False
Expand Down Expand Up @@ -322,17 +298,7 @@ def ilrma_t_iteration(

y = np.einsum("fdnm,ftdn->ftm", np.conjugate(P), x)

# Projection Back
P00_H = np.conjugate(P[:, 0, :, :])
P00_H = np.transpose(P00_H, axes=[0, 2, 1])
# P00_H_eps=condition_covariance(P00_H,eps)
A = np.linalg.pinv(P00_H)
y_pb = np.einsum("fts,fms->fstm", y, A)

costPByFreq = log_likelihood_ilmra_t_by_frequency(y, P[:, 0, :, :], d, eps)
costP = np.average(costPByFreq)

return (y, y_pb, b, v, P)
return (y, b, v, P)


# KagamiICASSP2018を実装する
Expand Down Expand Up @@ -568,7 +534,7 @@ def ilrma_t_dereverb_separation(
P = np.transpose(P, axes=[0, 1, 3, 2])

for iter in range(iter_num):
y, y_pb, b, v, P = ilrma_t_iteration(
y, b, v, P = ilrma_t_iteration(
x_delay,
b,
v,
Expand All @@ -581,7 +547,7 @@ def ilrma_t_dereverb_separation(
)

# print(iter, costB, costV, costP)
return (y, y_pb)
return (y, P)


# x: freq,frame,mic
Expand Down Expand Up @@ -687,14 +653,24 @@ def dereverb_separate(
X = X.transpose([1, 0, 2]).copy()

if algorithm == "ilrma_t":
Y, Y_pb = ilrma_t_dereverb_separation(
Y, P = ilrma_t_dereverb_separation(
X,
iter_num=n_iter,
nmf_basis_num=n_components,
tap_num=n_taps,
delay_num=n_delays,
eps=1.0e-18,
)

# Projection Back
t_pb = time.perf_counter()
P00_H = np.conjugate(P[:, 0, :, :])
P00_H = np.transpose(P00_H, axes=[0, 2, 1])
# P00_H_eps=condition_covariance(P00_H,eps)
A = np.linalg.pinv(P00_H)
Y_pb = np.einsum("fts,fms->fstm", Y, A)
t_pb = time.perf_counter() - t_pb

elif algorithm == "kagami":
Y, Y_pb = kagami_dereverb_separation(
X,
Expand All @@ -704,6 +680,7 @@ def dereverb_separate(
delay_num=n_delays,
eps=1.0e-18,
)
t_pb = -1.0
else:
raise ValueError(f"Invalide algorithm {algorithm}")

Expand All @@ -714,7 +691,7 @@ def dereverb_separate(
Y_pb = Y_pb[:, :, :, 0].transpose([2, 0, 1]).copy()

if proj_back_both:
return Y, Y_pb
return Y, Y_pb, t_pb
if proj_back:
return Y_pb
else:
Expand Down
4 changes: 3 additions & 1 deletion example.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,10 @@
mix = mix.astype(np.float64) / 2 ** 15

# add some noise
sigma_n = np.std(mix) * 10 ** (-args.snr / 20)
sigma_src = np.std(mix)
sigma_n = sigma_src * 10 ** (-args.snr / 20)
mix += np.random.randn(*mix.shape) * sigma_n
print("SNR:", 10 * np.log10(sigma_src ** 2 / sigma_n ** 2))

# the reference
if args.algo in dereverb_algos:
Expand Down
2 changes: 2 additions & 0 deletions experiment1_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
},

"ref_mic": 0,
"si_metric": true,
"snr": 40,

"minimum_distortion": {
"p_list": [0.1, 2.0, 20],
Expand Down
48 changes: 39 additions & 9 deletions process.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import time
from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -51,6 +52,9 @@ def process(args, config):

n_channels, room_id, bss_algo = args

# the name of the algorithm we'll use for bss
bss_algo_name = config["bss_algorithms"][bss_algo]["name"]

if mkl_available:
mkl.set_num_threads_local(1)

Expand All @@ -67,8 +71,17 @@ def process(args, config):
fn_mix = dataset_dir / rooms[room_id]["mix_filename"]
fs, mix = load_audio(fn_mix)

# add the noise
sigma_src = np.std(mix)
sigma_n = sigma_src * 10 ** (-args.snr / 20)
mix += np.random.randn(*mix.shape) * sigma_n

# the reference
fn_ref = dataset_dir / rooms[room_id]["src_filenames"][ref_mic]
if bss_algo_name in dereverb_algos:
# for dereverberation algorithms we use the anechoic reference signal
fn_ref = dataset_dir / rooms[room_id]["anechoic_filenames"][REF_MIC]
else:
fn_ref = dataset_dir / rooms[room_id]["src_filenames"][REF_MIC]
fs, ref = load_audio(fn_ref)

# STFT parameters
Expand All @@ -84,26 +97,31 @@ def process(args, config):
X = stft.analysis(mix, nfft, hop, win=win_a)

# Separation
bss_name = config["bss_algorithms"][bss_algo]["name"]
bss_kwargs = config["bss_algorithms"][bss_algo]["kwargs"]
n_iter_p_ch = config["bss_algorithms"][bss_algo]["n_iter_per_channel"]
if bss_algo == "fastmnmf":

runtime_bss = time.perf_counter()
if bss_algo_name == "fastmnmf":
Y = bss_algorithms[bss_name](X, n_iter=n_iter_p_ch * n_channels, **bss_kwargs)
elif bss_algo in dereverb_algos:
Y, Y_pb = bss_algorithms[bss_name](
elif bss_algo_name in dereverb_algos:
Y, Y_pb, runtime_pb = bss_algorithms[bss_name](
X, n_iter=n_iter_p_ch * n_channels, proj_back_both=True, **bss_kwargs
)
# adjust start time to remove the projection back
runtime_bss += runtime_pb
else:
Y = bss_algorithms[bss_name](
X, n_iter=n_iter_p_ch * n_channels, proj_back=False, **bss_kwargs
)
runtime_bss = time.perf_counter() - runtime_bss

results = []
results = [{"bss_runtime": {"bss_algo": bss_algo, "runtime": runtime_bss,}}]
t = {
"room_id": room_id,
"n_channels": n_channels,
"bss_algo": bss_algo,
"proj_algo": None,
"runtime": 0.0,
"sdr": None,
"sir": None,
"p": None,
Expand All @@ -113,7 +131,9 @@ def process(args, config):

# Evaluation of raw signal
t["proj_algo"] = "None"
y, sdr, sir, _ = reconstruct_evaluate(ref, Y, nfft, hop, win=win_s)
y, sdr, sir, _ = reconstruct_evaluate(
ref, Y, nfft, hop, win=win_s, si_metric=config["si_metric"]
)
t["sdr"], t["sir"] = sdr.tolist(), sir.tolist()
results.append(t.copy())

Expand All @@ -122,9 +142,14 @@ def process(args, config):
if bss_algo in dereverb_algos:
Z = Y_pb
else:
runtime_pb = time.perf_counter()
Z = bss_scale.projection_back(Y, X[:, :, ref_mic])
y, sdr, sir, _ = reconstruct_evaluate(ref, Z, nfft, hop, win=win_s)
runtime_pb = time.perf_counter() - runtime_pb
y, sdr, sir, _ = reconstruct_evaluate(
ref, Z, nfft, hop, win=win_s, si_metric=config["si_metric"]
)
t["sdr"], t["sir"] = sdr.tolist(), sir.tolist()
t["runtime"] = runtime_pb
results.append(t.copy())

# minimum distortion
Expand All @@ -139,13 +164,18 @@ def process(args, config):
"minimum_distortion",
)

runtime_md = time.perf_counter()
Z, t["n_iter"] = bss_scale.minimum_distortion(
Y, X[:, :, ref_mic], p=p, q=q, **kwargs
)
runtime_md = time.perf_counter() - runtime_md

y, sdr, sir, _ = reconstruct_evaluate(ref, Z, nfft, hop, win=win_s)
y, sdr, sir, _ = reconstruct_evaluate(
ref, Z, nfft, hop, win=win_s, si_metric=config["si_metric"]
)
t["sdr"] = sdr.tolist()
t["sir"] = sir.tolist()
t["runtime"] = runtime_md
results.append(t.copy())

return results

0 comments on commit ced8d81

Please sign in to comment.