From 464234708e9902205e3a1f6c874c4eaec5d5a36a Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Wed, 15 Nov 2023 15:27:32 -0800 Subject: [PATCH] Better default behaviour for BN dtypes; better State errors. --- equinox/_misc.py | 8 ++++++++ equinox/nn/_batch_norm.py | 9 +++++++-- equinox/nn/_stateful.py | 13 ++++++++++--- 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/equinox/_misc.py b/equinox/_misc.py index e24883a9..14183278 100644 --- a/equinox/_misc.py +++ b/equinox/_misc.py @@ -1,3 +1,4 @@ +import jax import jax.core import jax.numpy as jnp from jaxtyping import Array @@ -10,3 +11,10 @@ def left_broadcast_to(arr: Array, shape: tuple[int, ...]) -> Array: def currently_jitting(): return isinstance(jnp.array(1) + 1, jax.core.Tracer) + + +def default_floating_dtype(): + if jax.config.jax_enable_x64: # pyright: ignore + return jnp.float64 + else: + return jnp.float32 diff --git a/equinox/nn/_batch_norm.py b/equinox/nn/_batch_norm.py index dca0353c..bf2ce21b 100644 --- a/equinox/nn/_batch_norm.py +++ b/equinox/nn/_batch_norm.py @@ -6,6 +6,7 @@ import jax.numpy as jnp from jaxtyping import Array, Bool, Float +from .._misc import default_floating_dtype from .._module import field from ._sequential import StatefulLayer from ._stateful import State, StateIndex @@ -62,7 +63,7 @@ def __init__( channelwise_affine: bool = True, momentum: float = 0.99, inference: bool = False, - dtype=jnp.float32, + dtype=None, **kwargs, ): """**Arguments:** @@ -81,7 +82,9 @@ def __init__( statistics are directly used for normalisation. This may be toggled with [`equinox.nn.inference_mode`][] or overridden during [`equinox.nn.BatchNorm.__call__`][]. - - `dtype`: The dtype of the input array. + - `dtype`: The dtype to use for the running statistics. Defaults to either + `jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in + 64-bit mode. """ super().__init__(**kwargs) @@ -93,6 +96,8 @@ def __init__( self.weight = None self.bias = None self.first_time_index = StateIndex(jnp.array(True)) + if dtype is None: + dtype = default_floating_dtype() init_buffers = ( jnp.empty((input_size,), dtype=dtype), jnp.empty((input_size,), dtype=dtype), diff --git a/equinox/nn/_stateful.py b/equinox/nn/_stateful.py index aa0f8253..7f7064ee 100644 --- a/equinox/nn/_stateful.py +++ b/equinox/nn/_stateful.py @@ -10,7 +10,7 @@ from .._module import field, Module from .._pretty_print import bracketed, named_objs, text, tree_pformat -from .._tree import tree_at +from .._tree import tree_at, tree_equal _Value = TypeVar("_Value") @@ -165,8 +165,15 @@ def set(self, item: StateIndex[_Value], value: _Value) -> "State": raise ValueError("Can only use `eqx.nn.StateIndex`s as state keys.") old_value = self._state[item.marker] # pyright: ignore value = jtu.tree_map(jnp.asarray, value) - if jax.eval_shape(lambda: old_value) != jax.eval_shape(lambda: value): - raise ValueError("Old and new values have different structures.") + old_struct = jax.eval_shape(lambda: old_value) + new_struct = jax.eval_shape(lambda: value) + if tree_equal(old_struct, new_struct) is not True: + old_repr = tree_pformat(old_struct, struct_as_array=True) + new_repr = tree_pformat(new_struct, struct_as_array=True) + raise ValueError( + "Old and new values have different structures/shapes/dtypes. The old " + f"value is {old_repr} and the new value is {new_repr}." + ) state = self._state.copy() # pyright: ignore state[item.marker] = value new_self = object.__new__(State)