Skip to content

Commit

Permalink
egs: add new exp egs using new trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
haoxiangsnr committed Dec 21, 2023
1 parent 7241dff commit 6b30648
Show file tree
Hide file tree
Showing 7 changed files with 505 additions and 143 deletions.
32 changes: 8 additions & 24 deletions recipes/intel_ndns/spike_fsb/model_low_freq.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ def deepfiltering(complex_spec, coefs, frame_size: int):
complex_spec = complex_spec.unsqueeze(-1) # [B, C, F, T, 1]

complex_coefs = torch.complex(coefs[..., 0], coefs[..., 1]) # [B, C, F, T]
complex_coefs = rearrange(
complex_coefs, "b (c df) f t -> b c df f t", df=frame_size
)
complex_coefs = rearrange(complex_coefs, "b (c df) f t -> b c df f t", df=frame_size)

# df
out = torch.einsum("...ftn,...nft->...ft", complex_spec, complex_coefs)
Expand Down Expand Up @@ -91,9 +89,7 @@ def __init__(
elif output_activate_function == "PReLU":
self.activate_function = nn.PReLU()
else:
raise NotImplementedError(
f"Not implemented activation function {self.activate_function}"
)
raise NotImplementedError(f"Not implemented activation function {self.activate_function}")

self.output_activate_function = output_activate_function
self.output_size = output_size
Expand Down Expand Up @@ -200,9 +196,7 @@ def cumulative_laplace_norm(input):
entry_count = entry_count.expand_as(cumulative_sum) # [1, T] => [B, T]

cumulative_mean = cumulative_sum / entry_count # B, T
cumulative_mean = cumulative_mean.reshape(
batch_size * num_channels, 1, num_frames
)
cumulative_mean = cumulative_mean.reshape(batch_size * num_channels, 1, num_frames)

normed = input / (cumulative_mean + EPSILON)

Expand Down Expand Up @@ -399,9 +393,7 @@ def _freq_unfold(

elif upper_cutoff_freq == num_freqs:
# lower = lower_cutoff_freq - num_neighbor_freqs, upper = num_freqs
valid_input = input[
..., lower_cutoff_freq - num_neighbor_freqs : num_freqs, :
]
valid_input = input[..., lower_cutoff_freq - num_neighbor_freqs : num_freqs, :]

valid_input = functional.pad(
input=valid_input,
Expand All @@ -412,9 +404,7 @@ def _freq_unfold(
# lower = lower_cutoff_freq - num_neighbor_freqs, upper = upper_cutoff_freq + num_neighbor_freqs
valid_input = input[
...,
lower_cutoff_freq
- num_neighbor_freqs : upper_cutoff_freq
+ num_neighbor_freqs,
lower_cutoff_freq - num_neighbor_freqs : upper_cutoff_freq + num_neighbor_freqs,
:,
]

Expand Down Expand Up @@ -574,9 +564,7 @@ def forward(self, noisy_y):
assert ndim in (2, 3), "Input must be 2D (B, T) or 3D tensor (B, 1, T)"

if ndim == 3:
assert (
noisy_y.size(1) == 1
), "Input must be 2D (B, T) or 3D tensor (B, 1, T)"
assert noisy_y.size(1) == 1, "Input must be 2D (B, T) or 3D tensor (B, 1, T)"
noisy_y = noisy_y.squeeze(1)

noisy_mag, _, noisy_real, noisy_imag = self.stft(noisy_y)
Expand Down Expand Up @@ -604,12 +592,8 @@ def forward(self, noisy_y):
for df_coefs, df_order in zip(df_coefs_list, self.sb_df_orders):
# [B, C, F , T] = [B, C, F, ]
num_sub_freqs = df_coefs.shape[2]
complex_stft_in = complex_stft[
..., num_filtered : num_filtered + num_sub_freqs, :
]
enhanced_subband = deepfiltering(
complex_stft_in, df_coefs, frame_size=df_order
) # [B, 1, F, T] of complex
complex_stft_in = complex_stft[..., num_filtered : num_filtered + num_sub_freqs, :]
enhanced_subband = deepfiltering(complex_stft_in, df_coefs, frame_size=df_order) # [B, 1, F, T] of complex
enhanced_spec_list.append(enhanced_subband)
num_filtered += num_sub_freqs

Expand Down
126 changes: 126 additions & 0 deletions recipes/intel_ndns/spiking_fullsubnet/baseline_m.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
[meta]
save_dir = "exp"
description = "Train a model using Generative Adversarial Networks (GANs)"
seed = 20220815

[trainer]
path = "trainer_v2.Trainer"
[trainer.args]
debug = false
max_steps = 0
max_epochs = 200
max_grad_norm = 10
save_max_score = true
save_ckpt_interval = 1
max_patience = 20
plot_norm = true
validation_interval = 1
max_num_checkpoints = 20
scheduler_name = "constant_schedule_with_warmup"
warmup_steps = 0
warmup_ratio = 0.00
gradient_accumulation_steps = 1

[loss_function]
path = "torch.nn.MSELoss"
[loss_function.args]

[optimizer]
path = "torch.optim.AdamW"
[optimizer.args]
lr = 1e-3

[optimizer_g]
path = "torch.optim.AdamW"
[optimizer_g.args]
lr = 1e-3

[optimizer_d]
path = "torch.optim.AdamW"
[optimizer_d.args]
lr = 1e-3

[lr_scheduler_g]
path = "torch.optim.lr_scheduler.ExponentialLR"
[lr_scheduler_g.args]
gamma = 0.99

[lr_scheduler_d]
path = "torch.optim.lr_scheduler.ExponentialLR"
[lr_scheduler_d.args]
gamma = 0.99

[model]
path = "model.SpikingFullSubNet"
[model.args]
n_fft = 512
hop_length = 128
win_length = 512
fdrc = 0.5
fb_input_size = 64
fb_hidden_size = 256
fb_num_layers = 2
fb_proj_size = 64
fb_output_activate_function = false
sb_hidden_size = 128
sb_num_layers = 2
freq_cutoffs = [0, 32, 128, 256]
df_orders = [5, 3, 1]
center_freq_sizes = [4, 32, 64]
neighbor_freq_sizes = [15, 15, 15]
use_pre_layer_norm_fb = true
use_pre_layer_norm_sb = true
bn = true
shared_weights = true
sequence_model = "GSN"

[model_d]
path = "discriminator.Discriminator"
[model_d.args]

[acoustics]
n_fft = 512
hop_length = 128
win_length = 512
sr = 16000

[train_dataset]
path = "dataloader.DNSAudio"
[train_dataset.args]
root = "/datasets/datasets_fullband/training_set/"
limit = false
offset = 0

[train_dataset.dataloader]
batch_size = 64
num_workers = 8
drop_last = true
# pin_memory = true


[[validate_dataset]]
path = "dataloader.DNSAudio"
[validate_dataset.args]
root = "/datasets/datasets_fullband/validation_set/"
train = false
[validate_dataset.dataloader]
batch_size = 16
num_workers = 8

[[validate_dataset]]
path = "dataloader.DNSAudio"
[validate_dataset.args]
root = "/datasets/datasets_fullband/validation_set_20230730/"
train = false
[validate_dataset.dataloader]
batch_size = 16
num_workers = 8

[test_dataset]
path = "dataloader.DNSAudio"
[test_dataset.args]
root = "/datasets/IntelNeuromorphicDNSChallenge-latest/data/MicrosoftDNS_4_ICASSP/test_set_1/"
train = false
[test_dataset.dataloader]
batch_size = 24
num_workers = 8
6 changes: 2 additions & 4 deletions recipes/intel_ndns/spiking_fullsubnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,14 @@ def __init__(
else:
self.proj = nn.Identity()

if output_activate_function is None:
self.output_activate_function = nn.Identity()
elif output_activate_function == "tanh":
if output_activate_function == "tanh":
self.output_activate_function = nn.Tanh()
elif output_activate_function == "sigmoid":
self.output_activate_function = nn.Sigmoid()
elif output_activate_function == "relu":
self.output_activate_function = nn.ReLU()
else:
raise NotImplementedError(f"Output activate function {output_activate_function} not implemented.")
self.output_activate_function = nn.Identity()

self.hidden_size = hidden_size
self.num_layers = num_layers
Expand Down
79 changes: 16 additions & 63 deletions recipes/intel_ndns/spiking_fullsubnet/model_low_freq.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from functools import partial
from operator import is_

import torch
import torch.nn as nn
Expand Down Expand Up @@ -174,11 +175,16 @@ def cumulative_laplace_norm(input):
"""Normalize the input with the cumulative mean
Args:
input: [B, C, F, T]
input: [B, C, F, T] or [B, N, C, F, T]
Returns:
[B, C, F, T] or [B, N, C, F, T]
"""
is_5d = input.dim() == 5
if is_5d:
assert input.shape[2] == 1, "Only mono audio is supported."
input = input.squeeze(2) # [B, N, C, F, T] => [B, N, F, T]

batch_size, num_channels, num_freqs, num_frames = input.size()
input = input.reshape(batch_size * num_channels, num_freqs, num_frames)

Expand All @@ -200,7 +206,12 @@ def cumulative_laplace_norm(input):

normed = input / (cumulative_mean + EPSILON)

return normed.reshape(batch_size, num_channels, num_freqs, num_frames)
normed = normed.reshape(batch_size, num_channels, num_freqs, num_frames)

if is_5d:
normed = normed.unsqueeze(2)

return normed

@staticmethod
def offline_gaussian_norm(input):
Expand Down Expand Up @@ -233,64 +244,6 @@ def norm_wrapper(self, norm_type: str):
return norm


class SubBandSequenceWrapper(SequenceModel):
def __init__(self, df_order, *args, **kwargs):
super().__init__(*args, **kwargs)
self.df_order = df_order

def forward(self, subband_input):
"""Forward pass.
Args:
subband_input: the input of shape [B, N, C, F_subband, T]
Returns:
output: the output of shape [B, df_order, N * F_subband_out, T, 2]
"""

(
batch_size,
num_subband_units,
num_channels,
num_subband_freqs,
num_frames,
) = subband_input.shape
assert num_channels == 1

output = rearrange(subband_input, "b n c fs t -> (b n) (c fs) t")
output, all_layer_outputs = super().forward(output)
output = rearrange(
output,
"(b n) (c fc df) t -> b df (n fc) t c",
b=batch_size,
c=num_channels * 2,
df=self.df_order,
)

# e.g., [B, 3, 20, T, 2]

return output, all_layer_outputs

# for deep filter
# [B, df_order, F, T, C]

# output = subband_input.reshape(
# batch_size * num_subband_units, num_subband_freqs, num_frames
# )
# output = super().forward(output)

# # [B, N, C, 2, center, T]
# output = output.reshape(batch_size, num_subband_units, 2, -1, num_frames)

# # [B, 2, N, center, T]
# output = output.permute(0, 2, 1, 3, 4).contiguous()

# # [B, C, N * F_subband_out, T]
# output = output.reshape(batch_size, 2, -1, num_frames)

# return output


class SubbandModel(BaseModel):
def __init__(
self,
Expand Down Expand Up @@ -626,8 +579,8 @@ def forward(self, noisy_y):
from audiozen.metric import compute_synops

config = toml.load(
"/home/xianghao/proj/audiozen/recipes/intel_ndns/spike_fsb/baseline_s.toml"
# "/home/xianghao/proj/audiozen/recipes/intel_ndns/spike_fsb/baseline_m_cumulative_laplace_norm.toml"
# "/home/xianghao/proj/audiozen/recipes/intel_ndns/spike_fsb/baseline_s.toml"
"/home/xhao/proj/spiking-fullsubnet/recipes/intel_ndns/spiking_fullsubnet/baseline_m_cumulative_laplace_norm.toml"
# "/home/xianghao/proj/audiozen/recipes/intel_ndns/spike_fsb/baseline_l.toml"
)
model_args = config["model_g"]["args"]
Expand Down
Loading

0 comments on commit 6b30648

Please sign in to comment.