Skip to content

Commit

Permalink
small fix
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Jan 7, 2025
1 parent e2cf1bf commit 2a3e684
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ def backward(ctx, grad_log_probs_main, grad_log_probs_sep):
ny_sep_interpolated_scaled = ny_sep_scaled * (1 - beta) + ny_main_scaled * (beta * scale_ratio)
else:
ny_sep_interpolated_scaled = ny_sep_scaled
return ny_main_interpolated_scaled, ny_sep_interpolated_scaled
return ny_main_interpolated_scaled, ny_sep_interpolated_scaled, None, None

log_probs_main, log_probs_sep = _InterpolateGradFunc.apply(log_probs_main, log_probs_sep, alpha, beta)
return log_probs_main, log_probs_sep
Expand Down

0 comments on commit 2a3e684

Please sign in to comment.