Skip to content

Commit

Permalink
Making min dist a little faster
Browse files Browse the repository at this point in the history
  • Loading branch information
fakufaku committed Apr 30, 2020
1 parent ced8d81 commit 8e1488b
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 12 deletions.
16 changes: 8 additions & 8 deletions bss_scale/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def minimum_distortion_l2(Y, ref):


def minimum_distortion(
Y, ref, p=None, q=None, rtol=1e-3, max_iter=100,
Y, ref, p=None, q=None, rtol=1e-2, max_iter=100,
):
"""
This function computes the frequency-domain filter that minimizes the sum
Expand Down Expand Up @@ -135,7 +135,7 @@ def minimum_distortion(

eps = 1e-15

prev_res = None
prev_c = None

epoch = 0
while epoch < max_iter:
Expand All @@ -145,23 +145,23 @@ def minimum_distortion(
# the current error
error = ref[:, :, None] - c * Y
if q is None or p == q:
res, weights = lp_norm(error, p=p)
weights = lp_norm(error, p=p)
else:
res, weights = lpq_norm(error, p=p, q=q, axis=1)
weights = lpq_norm(error, p=p, q=q, axis=1)

# minimize
num = np.sum(ref[:, :, None] * np.conj(Y) * weights, axis=0)
denom = np.sum(np.abs(Y) ** 2 * weights, axis=0)
c = num / np.maximum(eps, denom)

# condition for termination
if prev_res is None:
prev_res = res
if prev_c is None:
prev_c = c
continue

# relative step length
delta = (prev_res - res) / prev_res
prev_res = res
delta = np.linalg.norm(c - prev_c) / np.linalg.norm(prev_c)
prev_c = c
if delta < rtol:
break

Expand Down
6 changes: 2 additions & 4 deletions bss_scale/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,14 @@

def lp_norm(E, p=1):
assert p > 0 and p < 2
cost = np.sum(np.abs(E) ** p)
weights = p / np.maximum(eps, 2.0 * np.abs(E) ** (2 - p))
return cost, weights
return weights


def lpq_norm(E, p=1, q=2, axis=1):
assert p > 0 and q >= p and q <= 2.0

cost = np.sum(np.sum(np.abs(E) ** q, axis=axis, keepdims=True) ** (p / q))
rn = np.sum(np.abs(E) ** q, axis=axis, keepdims=True) ** (1 - p / q)
qfn = np.abs(E) ** (2 - q)
weights = p / np.maximum(eps, 2.0 * rn * qfn)
return cost, weights
return weights
32 changes: 32 additions & 0 deletions experiment2_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
{
"metadata_fn": "./bss_speech_dataset/data/metadata.json",

"stft": {
"nfft": 512,
"hop": 256,
"window": "hamming"
},

"ref_mic": 0,
"si_metric": false,
"snr": 40,

"minimum_distortion": {
"p_list": [0.1, 2.0, 20],
"kwargs": {
"rtol": 1e-5,
"max_iter": 100
}
},

"bss_algorithms": {
"ilrma_t": {
"name": "ilrma_t",
"kwargs": {
"n_taps": 6,
"n_delays": 1
},
"n_iter_per_channel": 15
}
}
}
File renamed without changes.

0 comments on commit 8e1488b

Please sign in to comment.