Commit c925912: Data analysis and figures
fakufaku committed May 7, 2020 (1 parent: 041866c)
Showing 4 changed files with 215 additions and 135 deletions.
259 changes: 149 additions & 110 deletions analysis.py
@@ -50,8 +50,48 @@ def find_best_pq_sdr(df, config):

# find the best variants under all conditions
for bss_algo in df["bss_algo"].unique():
        # Find the best parameters over all channels
df_ba = df[
(df["proj_algo"] == "minimum_distortion") & (df["bss_algo"] == bss_algo)
]

pt_sdr_fix = df_ba.pivot_table(
index="q", columns="p", values="sdr", aggfunc=np.mean,
)

pt_sir_fix = df_ba.pivot_table(
index="q", columns="p", values="sir", aggfunc=np.mean,
)

strategy_fix = {
"sir": {
"val": pt_sir_fix[CLASSIC_P][CLASSIC_Q],
"pq": [CLASSIC_P, CLASSIC_Q],
},
"sdr": {
"val": pt_sdr_fix[CLASSIC_P][CLASSIC_Q],
"pq": [CLASSIC_P, CLASSIC_Q],
},
}

for ip, p in enumerate(p_vals):
for q in p_vals[ip:]:

if pt_sdr_fix[p][q] > strategy_fix["sdr"]["val"]:
strategy_fix["sdr"]["val"] = pt_sdr_fix[p][q]
strategy_fix["sdr"]["pq"] = [p, q]

if (
pt_sdr_fix[p][q] >= pt_sdr_fix[CLASSIC_P][CLASSIC_Q]
and pt_n_iter[p][q] <= BEST_MAX_ITER
and pt_sir_fix[p][q] > strategy_fix["sir"]["val"]
):
strategy_fix["sir"]["val"] = pt_sir_fix[p][q]
strategy_fix["sir"]["pq"] = [p, q]

for n_chan in df["n_channels"].unique():

# Find best parameters for this number of channels
df_loc = df[
(df["proj_algo"] == "minimum_distortion")
& (df["bss_algo"] == bss_algo)
@@ -73,101 +113,74 @@ def find_best_pq_sdr(df, config):
# TBD replace by 2.0!
sdr_l2 = pt_sdr[CLASSIC_P][CLASSIC_Q]

-            max_sdr_lbl = "md_best_sdr"
-            max_sdr = sdr_l2
-            max_sdr_id = [CLASSIC_P, CLASSIC_Q]
-
-            max_sir_lbl = "md_best_sir"
-            max_sir = sdr_l2
-            max_sir_id = [CLASSIC_P, CLASSIC_Q]
-
-            max_n_iter_lbl = "md_best_sir_m10"
-            max_n_iter = sdr_l2
-            max_n_iter_id = [CLASSIC_P, CLASSIC_Q]
+            strategies = {
+                "sdr": {"val": sdr_l2, "pq": [CLASSIC_P, CLASSIC_Q],},
+                "sir": {
+                    "val": pt_sir[CLASSIC_P][CLASSIC_Q],
+                    "pq": [CLASSIC_P, CLASSIC_Q],
+                },
+                "sir_10": {
+                    "val": pt_sir[CLASSIC_P][CLASSIC_Q],
+                    "pq": [CLASSIC_P, CLASSIC_Q],
+                },
+                "sdr_fix": strategy_fix["sdr"],
+                "sir_fix": strategy_fix["sir"],
+            }

            for ip, p in enumerate(p_vals):
                for q in p_vals[ip:]:

-                    if pt_sdr[p][q] > max_sdr:
-                        max_sdr = pt_sdr[p][q]
-                        max_sdr_id = [p, q]
+                    if pt_sdr[p][q] > strategies["sdr"]["val"]:
+                        strategies["sdr"]["val"] = pt_sdr[p][q]
+                        strategies["sdr"]["pq"] = [p, q]

-                    if pt_sdr[p][q] >= sdr_l2 and pt_sir[p][q] > max_sir:
-                        max_sir = pt_sir[p][q]
-                        max_sir_id = [p, q]
+                    if (
+                        pt_sdr[p][q] >= sdr_l2
+                        and pt_sir[p][q] > strategies["sir"]["val"]
+                    ):
+                        strategies["sir"]["val"] = pt_sir[p][q]
+                        strategies["sir"]["pq"] = [p, q]

                    if (
                        pt_sdr[p][q] >= sdr_l2
                        and pt_n_iter[p][q] <= BEST_MAX_ITER
-                        and pt_sir[p][q] > max_n_iter
+                        and pt_sir[p][q] > strategies["sir_10"]["val"]
                    ):
-                        max_n_iter = pt_sir[p][q]
-                        max_n_iter_id = [p, q]
-
-            sub_data.append(
-                df_loc[
-                    (df_loc["proj_algo"] == "minimum_distortion")
-                    & (df_loc["p"] == max_sdr_id[0])
-                    & (df_loc["q"] == max_sdr_id[1])
-                ].replace({"proj_algo": {"minimum_distortion": max_sdr_lbl}})
-            )
-
-            best_params.append(
-                {
-                    "bss_algo": bss_algo,
-                    "n_channels": n_chan,
-                    "proj_algo": max_sdr_lbl,
-                    "p": max_sdr_id[0],
-                    "q": max_sdr_id[1],
-                    "n_iter": pt_n_iter[max_sdr_id[0]][max_sdr_id[1]],
-                }
-            )
-
-            sub_data.append(
-                df_loc[
-                    (df_loc["proj_algo"] == "minimum_distortion")
-                    & (df_loc["p"] == max_sir_id[0])
-                    & (df_loc["q"] == max_sir_id[1])
-                ].replace({"proj_algo": {"minimum_distortion": max_sir_lbl}})
-            )
-
-            best_params.append(
-                {
-                    "bss_algo": bss_algo,
-                    "n_channels": n_chan,
-                    "proj_algo": max_sir_lbl,
-                    "p": max_sir_id[0],
-                    "q": max_sir_id[1],
-                    "n_iter": pt_n_iter[max_sir_id[0]][max_sir_id[1]],
-                }
-            )
-
-            sub_data.append(
-                df_loc[
-                    (df_loc["proj_algo"] == "minimum_distortion")
-                    & (df_loc["p"] == max_n_iter_id[0])
-                    & (df_loc["q"] == max_n_iter_id[1])
-                ].replace({"proj_algo": {"minimum_distortion": max_n_iter_lbl}})
-            )
-
-            best_params.append(
-                {
-                    "bss_algo": bss_algo,
-                    "n_channels": n_chan,
-                    "proj_algo": max_n_iter_lbl,
-                    "p": max_n_iter_id[0],
-                    "q": max_n_iter_id[1],
-                    "n_iter": pt_n_iter[max_n_iter_id[0]][max_n_iter_id[1]],
-                }
-            )
+                        strategies["sir_10"]["val"] = pt_sir[p][q]
+                        strategies["sir_10"]["pq"] = [p, q]

+            for label, strat in strategies.items():
+
+                new_label = f"gmd_{label}"
+
+                sub_data.append(
+                    df_loc[
+                        (df_loc["proj_algo"] == "minimum_distortion")
+                        & (df_loc["p"] == strat["pq"][0])
+                        & (df_loc["q"] == strat["pq"][1])
+                    ].replace({"proj_algo": {"minimum_distortion": new_label}})
+                )
+
+                best_params.append(
+                    {
+                        "bss_algo": bss_algo,
+                        "n_channels": n_chan,
+                        "proj_algo": new_label,
+                        "p": strat["pq"][0],
+                        "q": strat["pq"][1],
+                        "n_iter": pt_n_iter[strat["pq"][0]][strat["pq"][1]],
+                    }
+                )

sub_data = pd.concat(sub_data)
best_params_df = pd.DataFrame(best_params)

return sub_data, best_params_df
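The rewrite above collapses three sets of parallel max_* variables into a single strategies dict in which every selection strategy carries its running best value and the (p, q) pair that achieved it, so the relabeling loop can treat all strategies uniformly. A toy sketch of that bookkeeping pattern (hypothetical scores, not results from the experiments):

    scores = {(1.0, 1.0): 5.0, (1.0, 2.0): 6.0, (2.0, 2.0): 7.0}

    strategies = {"sdr": {"val": float("-inf"), "pq": None}}
    for (p, q), val in scores.items():
        # keep the running best value together with the (p, q) that achieved it
        if val > strategies["sdr"]["val"]:
            strategies["sdr"]["val"] = val
            strategies["sdr"]["pq"] = [p, q]

    print(strategies["sdr"])  # {'val': 7.0, 'pq': [2.0, 2.0]}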


-def print_table(sub_data, best_params, metrics=[]):
+def print_table(sub_data, best_params, proj_algos=None, metrics=[]):
"""
Merge the two data frames and print the result as a nice latex table
"""
Expand Down Expand Up @@ -200,28 +213,33 @@ def print_table(sub_data, best_params, metrics=[]):

    # now print this as a table that we can import into LaTeX

+    # default to all projection algorithms if none are specified
+    if proj_algos is None:
+        proj_algos = avgmet["proj_algo"].unique()

# First row has only the name of the algorithms
print(" & ", end="")
-    for proj_algo in avgmet["proj_algo"].unique():
+    for proj_algo in proj_algos:
print(f" & \\multicolumn{{3}}{{c}}{{ {proj_algo} }} ", end="")
print(" \\\\")

# Second row has the parameter names (algo/channels) and the metric names
print("\\text{Algo} & \\text{Channels} ", end="")
print("Algo. & Mics ", end="")
for proj_algo in avgmet["proj_algo"].unique():
for metric in metrics:
print(f" & \\text{{ {metric} }}", end="")
print(f" & {metric}", end="")
print(" \\\\")

for bss_algo in avgmet["bss_algo"].unique():
for m, n_chan in enumerate(avgmet["n_channels"].unique()):

if m == 0:
print(f"\\text{{ {bss_algo} }} & \\text{{ {n_chan} }} ", end="")
print(f" {bss_algo} & {n_chan} ", end="")
else:
print(f" & \\text{{ {n_chan} }} ", end="")
print(f" & {n_chan} ", end="")

-            for proj_algo in avgmet["proj_algo"].unique():
+            for proj_algo in proj_algos:
for metric in metrics:
val = avgmet[
(avgmet["bss_algo"] == bss_algo)
@@ -311,43 +329,64 @@ def draw_heatmap(*args, **kwargs):

ax = plt.gca()

for metric in ["sdr", "sir", "n_iter"]:
fg = sns.FacetGrid(
df,
col="n_channels",
row="bss_algo",
# row_order=algo_order[n_targets],
# margin_titles=True,
# aspect=aspect,
# height=height,
)
fg.map_dataframe(
draw_heatmap,
"p",
"q",
metric,
# cbar=False,
# vmin=0.0,
# vmax=1.0,
# xticklabels=[1, "", "", 4, "", "", 7, "", "", 10],
# yticklabels=[10, "", "", 40, "", "", 70, "", "", 100],
# yticklabels=yticklabels,
square=True,
)
for bss_algo in params["bss_algorithms"].keys():
for metric in ["sdr", "sir"]:
fg = sns.FacetGrid(
df[df["bss_algo"] == bss_algo],
col="n_channels",
row="bss_algo",
# row_order=algo_order[n_targets],
# margin_titles=True,
# aspect=aspect,
# height=height,
)
fg.map_dataframe(
draw_heatmap,
"p",
"q",
metric,
# cbar=False,
# vmin=0.0,
# vmax=1.0,
# xticklabels=[1, "", "", 4, "", "", 7, "", "", 10],
# yticklabels=[10, "", "", 40, "", "", 70, "", "", 100],
# yticklabels=yticklabels,
square=True,
)

fg.set_titles(template=f"{metric} " + "| {row_name} | mics={col_name}")
fg.set_titles(template=f"{metric} " + "| {row_name} | mics={col_name}")

for suffix in ["png", "pdf"]:
plt.savefig(output_dir / f"heatmaps_{metric}.{suffix}")
for suffix in ["png", "pdf"]:
plt.savefig(output_dir / f"heatmaps_{bss_algo}_{metric}.{suffix}")

"""
Now we will plot the box plots for key algorithms
"""

sub_df, best_params = find_best_pq_sdr(df, params)

-    print_table(sub_df, best_params, metrics=["sdr", "sir"])
+    print_table(
+        sub_df,
+        best_params,
+        metrics=["sdr", "sir"],
+        proj_algos=[
+            "projection_back",
+            "minimum_distortion_l2",
+            "gmd_sdr",
+            "gmd_sir",
+            "gmd_sir_10",
+            "gmd_sdr_fix",
+            # "gmd_sir_fix",
+        ],
+    )

print("")

-    print_table(sub_df, best_params, metrics=["p", "q", "n_iter_med"])
+    print_table(
+        sub_df,
+        best_params,
+        metrics=["p", "q", "n_iter_med"],
+        proj_algos=["gmd_sdr", "gmd_sir_10", "gmd_sdr_fix"],
+    )

plt.show()
2 changes: 1 addition & 1 deletion bss_speech_dataset
34 changes: 19 additions & 15 deletions dereverb_separation.py
@@ -174,20 +174,6 @@ def ilrma_t_iteration(
# costOrgByFreq = log_likelihood_ilmra_t_by_frequency(y, P[:, 0, :, :], d, eps)
# costOrg = np.average(costOrgByFreq)

-    if fixB == False:
-        b_temp = ilrma_t_b_iteration(y, b, v, eps)
-        # update d
-
-        if use_increase_constraint == True:
-            for freq in range(fftMax):
-                if costOrgByFreq[freq] > costTempByFreq[freq]:
-                    b[freq, ...] = b_temp[freq, ...]
-        else:
-            b = b_temp
-
-        # time-frequency variance
-        d = np.einsum("bst,fsb->fts", v, b)

if fixV == False:
# y: fftMax,frameNum,micNum
# b: fftMax,sourceNum,basis
@@ -206,6 +192,20 @@

d = np.einsum("bst,fsb->fts", v, b)

+    if fixB == False:
+        b_temp = ilrma_t_b_iteration(y, b, v, eps)
+        # update d
+
+        if use_increase_constraint == True:
+            for freq in range(fftMax):
+                if costOrgByFreq[freq] > costTempByFreq[freq]:
+                    b[freq, ...] = b_temp[freq, ...]
+        else:
+            b = b_temp
+
+        # time-frequency variance
+        d = np.einsum("bst,fsb->fts", v, b)
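The einsum contraction above implements the NMF variance model d[f, t, s] = sum over k of b[f, s, k] * v[k, s, t]. A quick shape check with hypothetical sizes:

    import numpy as np

    n_freq, n_src, n_frames, n_basis = 257, 2, 100, 8  # hypothetical sizes
    v = np.random.rand(n_basis, n_src, n_frames)   # NMF activations
    b = np.random.rand(n_freq, n_src, n_basis)     # NMF basis spectra
    d = np.einsum("bst,fsb->fts", v, b)            # time-frequency variance
    print(d.shape)  # (257, 100, 2)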

    # Compute the filters.
IP1 = True
IP2 = False
@@ -509,12 +509,16 @@ def ilrma_t_dereverb_separation(
weight = np.random.uniform(size=fftMax * source_num * nmf_basis_num)
weight = np.reshape(weight, [fftMax, source_num, nmf_basis_num])

+    # v: basis,sourceNum,frameNum
    v = np.einsum("fst,fsb->bst", v, weight)
    v_ave = np.mean(v, axis=2, keepdims=True)
    v = v / (v_ave + 1.0e-14)
    v = np.abs(v)
+    v = 0.2 * np.random.rand(nmf_basis_num, source_num, frameNum) + 0.8

-    b = np.ones(shape=(fftMax, source_num, nmf_basis_num))
+    # b: fftMax,sourceNum,basis
+    # b = np.ones(shape=(fftMax, source_num, nmf_basis_num))
+    b = 0.2 * np.random.rand(fftMax, source_num, nmf_basis_num) + 0.8

W = np.zeros(shape=(fftMax, channels, (tap_num + 1) * channels), dtype=np.complex)
W[:, :, :channels] = (
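The change above replaces the all-ones initialization of b (and overrides the data-derived v) with uniform draws from [0.8, 1.0). A plausible motivation, stated here as an assumption rather than the author's documented reasoning, is that multiplicative NMF-style updates cannot move a parameter away from zero, so starting all parameters strictly positive and close to one is a safe default:

    import numpy as np

    fftMax, source_num, nmf_basis_num, frameNum = 257, 2, 8, 100  # hypothetical sizes

    # uniform draws in [0.8, 1.0): strictly positive, close to one
    v = 0.2 * np.random.rand(nmf_basis_num, source_num, frameNum) + 0.8
    b = 0.2 * np.random.rand(fftMax, source_num, nmf_basis_num) + 0.8

    print(v.min() >= 0.8, v.max() < 1.0)  # True True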
