dtype cast for constants in activation functions
chaoming0625 committed Apr 6, 2024
1 parent 2b58c49 commit 23480d0
Showing 1 changed file with 39 additions and 9 deletions.
48 changes: 39 additions & 9 deletions braintools/functional/_activations.py
@@ -6,6 +6,7 @@
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.typing import ArrayLike
+import braincore as bc

__all__ = [
"relu",
@@ -42,6 +43,18 @@
]


+def _get_dtype(x: ArrayLike):
+  if hasattr(x, 'dtype'):
+    return x.dtype
+  else:
+    if isinstance(x, float):
+      return bc.environ.dftype()
+    elif isinstance(x, int):
+      return bc.environ.dftype()
+    else:
+      raise ValueError(f'Unsupported type: {type(x)}')
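
A quick illustration of how the helper above resolves dtypes (my addition, not part of the commit; it assumes the calls run in this module and that braincore's default float type is configured as float32):

import jax.numpy as jnp

_get_dtype(jnp.ones(3, dtype=jnp.float16))   # float16, read from the array itself
_get_dtype(0.25)                             # falls back to bc.environ.dftype(), e.g. float32
_get_dtype(2)                                # ints also map to the default float type in this helper
_get_dtype('a')                              # raises ValueError: unsupported type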


def softmin(x, axis=-1):
r"""Applies the Softmin function to an n-dimensional input Tensor
rescaling them so that the elements of the n-dimensional output Tensor
@@ -93,7 +106,10 @@ def prelu(x, a=0.25):
parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
a separate :math:`a` is used for each input channel.
"""
-  return jnp.where(x >= 0., x, a * x)
+  dtype = _get_dtype(x)
+  return jnp.where(x >= jnp.asarray(0., dtype),
+                   x,
+                   jnp.asarray(a, dtype) * x)
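
A minimal sketch of what these casts buy (my illustration, not from the commit): if the slope arrives as a strongly typed 32- or 64-bit scalar, the product can otherwise be promoted above a low-precision input, while the cast keeps the output dtype equal to the input dtype.

import numpy as np
import jax.numpy as jnp

x = jnp.linspace(-1., 1., 5, dtype=jnp.float16)
a = np.float32(0.25)                                                # strongly typed scalar
y = jnp.where(x >= jnp.asarray(0., x.dtype), x, jnp.asarray(a, x.dtype) * x)
y.dtype                                                             # float16; without the cast it would typically widen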


def soft_shrink(x, lambd=0.5):
@@ -114,7 +130,11 @@ def soft_shrink(x, lambd=0.5):
- Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- Output: :math:`(*)`, same shape as the input.
"""
-  return jnp.where(x > lambd, x - lambd, jnp.where(x < -lambd, x + lambd, 0.))
+  dtype = _get_dtype(x)
+  lambd = jnp.asarray(lambd, dtype)
+  return jnp.where(x > lambd,
+                   x - lambd,
+                   jnp.where(x < -lambd, x + lambd, jnp.asarray(0., dtype)))
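
As a worked example of the soft-shrinkage rule implemented above (my addition): entries inside [-lambd, lambd] are zeroed, the rest are pulled toward zero by lambd, and with the casts the result stays in the input precision.

import jax.numpy as jnp

x = jnp.asarray([-1.0, -0.2, 0.3, 1.0], dtype=jnp.float16)
lambd = jnp.asarray(0.5, x.dtype)
out = jnp.where(x > lambd, x - lambd,
                jnp.where(x < -lambd, x + lambd, jnp.asarray(0., x.dtype)))
# out -> [-0.5, 0.0, 0.0, 0.5], dtype float16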


def mish(x):
@@ -135,7 +155,7 @@
return x * jnp.tanh(softplus(x))


-def rrelu(key, x, lower=0.125, upper=0.3333333333333333):
+def rrelu(x, lower=0.125, upper=0.3333333333333333):
r"""Applies the randomized leaky rectified liner unit function, element-wise,
as described in the paper:
@@ -166,9 +186,9 @@ def rrelu(key, x, lower=0.125, upper=0.3333333333333333):
.. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
https://arxiv.org/abs/1505.00853
"""
-  x = jnp.asarray(x)
-  a = jax.random.uniform(key, x.shape, x.dtype, lower, upper)
-  return jnp.where(x >= 0., x, a * x)
+  dtype = _get_dtype(x)
+  a = bc.random.uniform(lower, upper, size=jnp.shape(x), dtype=dtype)
+  return jnp.where(x >= jnp.asarray(0., dtype), x, jnp.asarray(a, dtype) * x)
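
Note the API change here: rrelu no longer takes an explicit PRNG key; the random slope is drawn from braincore's global random state via bc.random.uniform and cast to the input dtype. A usage sketch (the braintools import path below is an assumption, it is not shown in this diff):

import jax.numpy as jnp
from braintools.functional import rrelu   # assumed import path

x = jnp.linspace(-2., 2., 5, dtype=jnp.float16)
y = rrelu(x, lower=0.125, upper=1. / 3.)  # negative entries are scaled by a random slope between lower and upper
y.dtype                                   # float16; the slope is cast to the input dtype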


def hard_shrink(x, lambd=0.5):
@@ -192,7 +212,11 @@
- Output: :math:`(*)`, same shape as the input.
"""
-  return jnp.where(x > lambd, x, jnp.where(x < -lambd, x, 0.))
+  dtype = _get_dtype(x)
+  lambd = jnp.asarray(lambd, dtype)
+  return jnp.where(x > lambd,
+                   x,
+                   jnp.where(x < -lambd, x, jnp.asarray(0., dtype)))


def relu(x: ArrayLike) -> jax.Array:
@@ -229,7 +253,6 @@ def relu(x: ArrayLike) -> jax.Array:
return jax.nn.relu(x)



def squareplus(x: ArrayLike, b: ArrayLike = 4) -> jax.Array:
r"""Squareplus activation function.
@@ -244,7 +267,8 @@ def squareplus(x: ArrayLike, b: ArrayLike = 4) -> jax.Array:
x : input array
b : smoothness parameter
"""
-  return jax.nn.squareplus(x, b)
+  dtype = _get_dtype(x)
+  return jax.nn.squareplus(x, jnp.asarray(b, dtype))
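
For reference (my note, not from the commit): jax.nn.squareplus evaluates (x + sqrt(x**2 + b)) / 2, so casting b keeps the whole expression in the input precision when a strongly typed scalar is supplied. A minimal check:

import jax
import jax.numpy as jnp

x = jnp.ones((), dtype=jnp.float16)
jax.nn.squareplus(x, jnp.asarray(4., x.dtype))   # (1 + sqrt(5)) / 2 ~ 1.618, returned as float16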


def softplus(x: ArrayLike) -> jax.Array:
@@ -362,6 +386,8 @@ def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> jax.Array:
See also:
:func:`selu`
"""
+  dtype = _get_dtype(x)
+  alpha = jnp.asarray(alpha, dtype)
  return jax.nn.elu(x, alpha)


@@ -388,6 +414,8 @@ def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> jax.Array:
See also:
:func:`relu`
"""
+  dtype = _get_dtype(x)
+  negative_slope = jnp.asarray(negative_slope, dtype)
  return jax.nn.leaky_relu(x, negative_slope=negative_slope)


@@ -434,6 +462,8 @@ def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> jax.Array:
Returns:
An array.
"""
+  dtype = _get_dtype(x)
+  alpha = jnp.asarray(alpha, dtype)
  return jax.nn.celu(x, alpha)
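
elu, leaky_relu and celu all follow the same pattern: the scalar parameter is cast to the input dtype before calling into jax.nn. A small sketch of the effect under JAX's default type-promotion rules (my illustration; the exact widened dtype depends on configuration):

import numpy as np
import jax
import jax.numpy as jnp

x = -jnp.ones(3, dtype=jnp.float16)
alpha = np.float32(1.0)                            # strongly typed scalar
jax.nn.elu(x, alpha).dtype                         # typically widened to float32
jax.nn.elu(x, jnp.asarray(alpha, x.dtype)).dtype   # float16, matching the input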

