diff --git a/users/zeyer/experiments/exp2024_04_23_baselines/model_ext/ctc_sep_net.py b/users/zeyer/experiments/exp2024_04_23_baselines/model_ext/ctc_sep_net.py index 47f44ba05..0a9b0a936 100644 --- a/users/zeyer/experiments/exp2024_04_23_baselines/model_ext/ctc_sep_net.py +++ b/users/zeyer/experiments/exp2024_04_23_baselines/model_ext/ctc_sep_net.py @@ -793,7 +793,7 @@ def backward(ctx, grad_log_probs_main, grad_log_probs_sep): ny_sep_interpolated_scaled = ny_sep_scaled * (1 - beta) + ny_main_scaled * (beta * scale_ratio) else: ny_sep_interpolated_scaled = ny_sep_scaled - return ny_main_interpolated_scaled, ny_sep_interpolated_scaled + return ny_main_interpolated_scaled, ny_sep_interpolated_scaled, None, None log_probs_main, log_probs_sep = _InterpolateGradFunc.apply(log_probs_main, log_probs_sep, alpha, beta) return log_probs_main, log_probs_sep