Skip to content

Commit

Permalink
Always initialize the capacity for SimpleNeuralField
Browse files Browse the repository at this point in the history
  • Loading branch information
famura committed Jul 17, 2024
1 parent db26722 commit 51d2b2f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 25 deletions.
46 changes: 22 additions & 24 deletions neuralfields/simple_neural_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,56 +341,54 @@ def __init__(
self._hidden_size, nonlin=activation_nonlin, bias=False, weight=True
)

# Potential dynamics.
# Potential dynamics' capacity.
self.potentials_dyn_fcn = potentials_dyn_fcn
self.capacity_learnable = capacity_learnable
if self.potentials_dyn_fcn in [pd_capacity_21, pd_capacity_21_abs, pd_capacity_32, pd_capacity_32_abs]:
if _is_iterable(activation_nonlin):
self._init_capacity(activation_nonlin[0], device)
self._capacity_opt_init = self._init_capacity_heuristic(activation_nonlin[0])
else:
self._init_capacity(activation_nonlin, device) # type: ignore[arg-type]
self._capacity_opt_init = self._init_capacity_heuristic(activation_nonlin) # type: ignore[arg-type]
else:
self._capacity_opt = None

# Initialize cubic decay and capacity if learnable.
if (self.potentials_dyn_fcn is pd_cubic) and self.kappa_learnable:
self._kappa_opt.data = self._kappa_opt_init
elif self.potentials_dyn_fcn in [pd_capacity_21, pd_capacity_21_abs, pd_capacity_32, pd_capacity_32_abs]:
self._capacity_opt.data = self._capacity_opt_init
# Even if the potential function does not include a capacity term, we initialize it to be compatible with
# custom functions.
self._capacity_opt_init = torch.tensor(1.0, dtype=torch.get_default_dtype())
self._capacity_opt = nn.Parameter(self._capacity_opt_init.to(device=device), requires_grad=capacity_learnable)

# Move the complete model to the given device.
self.to(device=device)

def _init_capacity(self, activation_nonlin: ActivationFunction, device: Union[str, torch.device]) -> None:
def _init_capacity_heuristic(self, activation_nonlin: ActivationFunction) -> torch.Tensor:
"""Initialize the value of the capacity parameter $C$ depending on the activation function.
Args:
activation_nonlin: Nonlinear activation function used.
Returns:
Heuristic initial value for the capacity parameter.
"""
if activation_nonlin is torch.sigmoid:
# sigmoid(7.) approx 0.999
self._capacity_opt_init = PotentialBased.transform_to_opt_space(
torch.tensor([7.0], device=device, dtype=torch.get_default_dtype())
return PotentialBased.transform_to_opt_space(
torch.tensor([7.0], dtype=torch.get_default_dtype())
)
elif activation_nonlin is torch.tanh:
# tanh(3.8) approx 0.999
self._capacity_opt_init = PotentialBased.transform_to_opt_space(
torch.tensor([3.8], device=device, dtype=torch.get_default_dtype())
return PotentialBased.transform_to_opt_space(
torch.tensor([3.8], dtype=torch.get_default_dtype())
)
else:
raise ValueError(
"For the potential dynamics including a capacity, only output nonlinearities of type "
"torch.sigmoid and torch.tanh are supported!"
)
self._capacity_opt = nn.Parameter(self._capacity_opt_init, requires_grad=self.capacity_learnable)
raise NotImplementedError(
"For the potential dynamics including a capacity, the initialization heuristic only supports "
"the activation functions `torch.sigmoid` and `torch.tanh`!"
)

def extra_repr(self) -> str:
return super().extra_repr() + f", capacity_learnable={self.capacity_learnable}"

@property
def capacity(self) -> Optional[torch.Tensor]:
"""Get the capacity parameter (exists for capacity-based dynamics functions), otherwise return `None`."""
return None if self._capacity_opt is None else PotentialBased.transform_to_img_space(self._capacity_opt)
def capacity(self) -> Union[torch.Tensor, nn.Parameter]:
"""Get the capacity parameter (only used for capacity-based dynamics functions)."""
return PotentialBased.transform_to_img_space(self._capacity_opt)

def potentials_dot(self, potentials: torch.Tensor, stimuli: torch.Tensor) -> torch.Tensor:
r"""Compute the derivative of the neurons' potentials per time step.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_simple_neural_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def test_neural_fields_trafos(kappa_init: float, tau_init: float):


def test_simple_neural_fields_fail():
with pytest.raises(ValueError):
with pytest.raises(NotImplementedError):
SimpleNeuralField(input_size=6, output_size=3, potentials_dyn_fcn=pd_capacity_21, activation_nonlin=torch.sqrt)

with pytest.raises(ValueError):
Expand Down

0 comments on commit 51d2b2f

Please sign in to comment.