diff --git a/deepmd/op/_tabulate_grad.py b/deepmd/op/_tabulate_grad.py index 9076ee3213..7cc7532a39 100644 --- a/deepmd/op/_tabulate_grad.py +++ b/deepmd/op/_tabulate_grad.py @@ -55,7 +55,7 @@ def _tabulate_fusion_se_atten_grad_cc(op, dy): op.outputs[0], is_sorted=op.get_attr("is_sorted"), ) - return [None, None, dy_dx, dy_df, None] + return [None, None, dy_dx, dy_df, dy_dtwo] @ops.RegisterGradient("TabulateFusionSeAttenGrad")