diff --git a/se3cnn/non_linearities/rescaled_act.py b/se3cnn/non_linearities/rescaled_act.py index 1cdce49..53b7ddc 100644 --- a/se3cnn/non_linearities/rescaled_act.py +++ b/se3cnn/non_linearities/rescaled_act.py @@ -15,8 +15,9 @@ def __call__(self, x): class ShiftedSoftplus: def __init__(self, beta): x = torch.randn(100000, dtype=torch.float64) - self.factor = torch.nn.functional.softplus(x, beta).pow(2).mean().rsqrt().item() self.shift = torch.nn.functional.softplus(torch.zeros(()), beta).item() + y = torch.nn.functional.softplus(x, beta).sub(self.shift) + self.factor = y.pow(2).mean().rsqrt().item() self.beta = beta def __call__(self, x):