Skip to content

Commit

Permalink
Better default behaviour for BN dtypes; better State errors.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Nov 27, 2023
1 parent 588a8c3 commit 4642347
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 5 deletions.
8 changes: 8 additions & 0 deletions equinox/_misc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import jax
import jax.core
import jax.numpy as jnp
from jaxtyping import Array
Expand All @@ -10,3 +11,10 @@ def left_broadcast_to(arr: Array, shape: tuple[int, ...]) -> Array:

def currently_jitting():
return isinstance(jnp.array(1) + 1, jax.core.Tracer)


def default_floating_dtype():
if jax.config.jax_enable_x64: # pyright: ignore
return jnp.float64
else:
return jnp.float32
9 changes: 7 additions & 2 deletions equinox/nn/_batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jax.numpy as jnp
from jaxtyping import Array, Bool, Float

from .._misc import default_floating_dtype
from .._module import field
from ._sequential import StatefulLayer
from ._stateful import State, StateIndex
Expand Down Expand Up @@ -62,7 +63,7 @@ def __init__(
channelwise_affine: bool = True,
momentum: float = 0.99,
inference: bool = False,
dtype=jnp.float32,
dtype=None,
**kwargs,
):
"""**Arguments:**
Expand All @@ -81,7 +82,9 @@ def __init__(
statistics are directly used for normalisation. This may be toggled with
[`equinox.nn.inference_mode`][] or overridden during
[`equinox.nn.BatchNorm.__call__`][].
- `dtype`: The dtype of the input array.
- `dtype`: The dtype to use for the running statistics. Defaults to either
`jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in
64-bit mode.
"""

super().__init__(**kwargs)
Expand All @@ -93,6 +96,8 @@ def __init__(
self.weight = None
self.bias = None
self.first_time_index = StateIndex(jnp.array(True))
if dtype is None:
dtype = default_floating_dtype()
init_buffers = (
jnp.empty((input_size,), dtype=dtype),
jnp.empty((input_size,), dtype=dtype),
Expand Down
13 changes: 10 additions & 3 deletions equinox/nn/_stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from .._module import field, Module
from .._pretty_print import bracketed, named_objs, text, tree_pformat
from .._tree import tree_at
from .._tree import tree_at, tree_equal


_Value = TypeVar("_Value")
Expand Down Expand Up @@ -165,8 +165,15 @@ def set(self, item: StateIndex[_Value], value: _Value) -> "State":
raise ValueError("Can only use `eqx.nn.StateIndex`s as state keys.")
old_value = self._state[item.marker] # pyright: ignore
value = jtu.tree_map(jnp.asarray, value)
if jax.eval_shape(lambda: old_value) != jax.eval_shape(lambda: value):
raise ValueError("Old and new values have different structures.")
old_struct = jax.eval_shape(lambda: old_value)
new_struct = jax.eval_shape(lambda: value)
if tree_equal(old_struct, new_struct) is not True:
old_repr = tree_pformat(old_struct, struct_as_array=True)
new_repr = tree_pformat(new_struct, struct_as_array=True)
raise ValueError(
"Old and new values have different structures/shapes/dtypes. The old "
f"value is {old_repr} and the new value is {new_repr}."
)
state = self._state.copy() # pyright: ignore
state[item.marker] = value
new_self = object.__new__(State)
Expand Down

0 comments on commit 4642347

Please sign in to comment.