diff --git a/equinox/nn/_batch_norm.py b/equinox/nn/_batch_norm.py index 45802d10..0ac300aa 100644 --- a/equinox/nn/_batch_norm.py +++ b/equinox/nn/_batch_norm.py @@ -1,10 +1,11 @@ +import warnings from collections.abc import Hashable, Sequence -from typing import Optional, Union +from typing import Literal, Optional, Union import jax import jax.lax as lax import jax.numpy as jnp -from jaxtyping import Array, Bool, Float, PRNGKeyArray +from jaxtyping import Array, Float, Int, PRNGKeyArray from .._misc import default_floating_dtype from .._module import field @@ -40,28 +41,92 @@ class BatchNorm(StatefulLayer, strict=True): statistics updated. During inference then just the running statistics are used. Whether the model is in training or inference mode should be toggled using [`equinox.nn.inference_mode`][]. + + With `approach = "batch"`, during training the batch mean and variance are used + for normalization. During inference the exponential running mean and unbiased + variance are used for normalization, in accordance with the paper cited below. + Let `m` be the momentum: + + $\text{TrainStats}_t = \text{BatchStats}_t$ + + $\text{InferenceStats}_t = \frac{\left(1.0 - m\right)\sum_{i=0}^{t}m^{t-i} + \text{BatchStats}_i}{\text{max} \left(1.0 - m^{t+1}, \varepsilon \right)}$ + + With `approach = "ema"`, exponential running means and variances are kept. During + training the batch statistics are used to fill in the running statistics until + they are populated. In addition, a linear interpolation is used between the batch + and running statistics over the `warmup_period`. During inference the running + statistics are used for normalization: + + + + $\text{WarmupFrac}_t = \text{min} \left(1.0, \frac{t}{\text{WarmupPeriod}} \right)$ + + $\text{TrainStats}_t = (1.0 - \text{WarmupFrac}_t) * \text{BatchStats}_t + + \text{WarmupFrac}_t * \left(1.0 - m\right)\sum_{i=0}^{t}m^{t-i}\text{BatchStats}_i$ + + $\text{InferenceStats}_t = \frac{\left(1.0 - m\right)\sum_{i=0}^{t}m^{t-i} + \text{BatchStats}_i}{\text{max} \left(1.0 - m^{t+1}, \varepsilon \right)}$ + + + $\text{Note: } \frac{(1.0 - m)\sum_{i=0}^{t}m^{t-i}}{1.0 - m^{t+1}} = + \frac{(1.0 - m)\sum_{i=0}^{t}m^{i}}{1.0 - m^{t+1}}$ + $= \frac{(1.0 - m)\frac{1.0 - m^{t+1}}{1.0 - m}}{1.0 - m^{t+1}} = 1$ + + `approach = "ema_compatibility"` reproduces the original Equinox BatchNorm + behavior. It often results in training instabilities, and `approach = "batch"` + or `"ema"` is recommended instead. + + ???
cite + + [Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift](https://arxiv.org/abs/1502.03167) + + ```bibtex + @article{DBLP:journals/corr/IoffeS15, + author = {Sergey Ioffe and + Christian Szegedy}, + title = {Batch Normalization: Accelerating Deep Network Training + by Reducing Internal Covariate Shift}, + journal = {CoRR}, + volume = {abs/1502.03167}, + year = {2015}, + url = {http://arxiv.org/abs/1502.03167}, + eprinttype = {arXiv}, + eprint = {1502.03167}, + timestamp = {Mon, 13 Aug 2018 16:47:06 +0200}, + biburl = {https://dblp.org/rec/journals/corr/IoffeS15.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} + } + ``` + """ # noqa: E501 weight: Optional[Float[Array, "input_size"]] bias: Optional[Float[Array, "input_size"]] - first_time_index: StateIndex[Bool[Array, ""]] + count_index: StateIndex[Int[Array, ""]] state_index: StateIndex[ tuple[Float[Array, "input_size"], Float[Array, "input_size"]] ] + zero_frac_index: StateIndex[Float[Array, ""]] axis_name: Union[Hashable, Sequence[Hashable]] inference: bool input_size: int = field(static=True) + approach: Literal["batch", "ema", "ema_compatibility"] = field(static=True) eps: float = field(static=True) channelwise_affine: bool = field(static=True) momentum: float = field(static=True) + warmup_period: int = field(static=True) def __init__( self, input_size: int, axis_name: Union[Hashable, Sequence[Hashable]], + approach: Optional[Literal["batch", "ema", "ema_compatibility"]] = None, eps: float = 1e-5, channelwise_affine: bool = True, momentum: float = 0.99, + warmup_period: int = 1, inference: bool = False, dtype=None, ): @@ -71,11 +136,17 @@ def __init__( - `axis_name`: The name of the batch axis to compute statistics over, as passed to `axis_name` in `jax.vmap` or `jax.pmap`. Can also be a sequence (e.g. a tuple or a list) of names, to compute statistics over multiple named axes. + - `approach`: The approach to use for the running statistics. If `approach=None` + a warning will be raised and the approach will default to `"ema_compatibility"`. + During training, `"batch"` uses only the batch statistics, while `"ema"` and + `"ema_compatibility"` also use the running statistics. - `eps`: Value added to the denominator for numerical stability. - `channelwise_affine`: Whether the module has learnable channel-wise affine parameters. - `momentum`: The rate at which to update the running statistics. Should be a value between 0 and 1 exclusive. + - `warmup_period`: The interpolation period between batch and running + statistics. Only used when `approach=\"ema\"`. - `inference`: If `False` then the batch means and variances will be calculated and used to update the running statistics. If `True` then the running statistics are directly used for normalisation. This may be toggled with @@ -86,26 +157,46 @@ def __init__( 64-bit mode. """ + if approach is None: + warnings.warn( + "BatchNorm approach is None, defaulting to " + 'approach="ema_compatibility". This is not recommended as ' + 'it can lead to training instability. Use "batch" or ' + 'alternatively "ema" with an appropriately selected ' + "warmup period instead."
+ ) + approach = "ema_compatibility" + + valid_approaches = {"batch", "ema", "ema_compatibility"} + if approach not in valid_approaches: + raise ValueError(f"approach must be one of {valid_approaches}") + self.approach = approach + + if warmup_period < 1: + raise ValueError("warmup_period must be >= 1") + if channelwise_affine: self.weight = jnp.ones((input_size,)) self.bias = jnp.zeros((input_size,)) else: self.weight = None self.bias = None - self.first_time_index = StateIndex(jnp.array(True)) + self.count_index = StateIndex(jnp.array(0, dtype=jnp.int32)) if dtype is None: dtype = default_floating_dtype() init_buffers = ( - jnp.empty((input_size,), dtype=dtype), - jnp.empty((input_size,), dtype=dtype), + jnp.zeros((input_size,), dtype=dtype), + jnp.zeros((input_size,), dtype=dtype), ) self.state_index = StateIndex(init_buffers) + self.zero_frac_index = StateIndex(jnp.array(1.0, dtype=dtype)) self.inference = inference self.axis_name = axis_name self.input_size = input_size self.eps = eps self.channelwise_affine = channelwise_affine self.momentum = momentum + self.warmup_period = warmup_period @jax.named_scope("eqx.nn.BatchNorm") def __call__( @@ -143,7 +234,11 @@ def __call__( if inference is None: inference = self.inference if inference: + # renormalize running stats to account for the zeroed part + zero_frac = state.get(self.zero_frac_index) running_mean, running_var = state.get(self.state_index) + norm_mean = running_mean / jnp.maximum(1.0 - zero_frac, self.eps) + norm_var = running_var / jnp.maximum(1.0 - zero_frac, self.eps) else: def _stats(y): @@ -154,16 +249,50 @@ def _stats(y): var = jnp.maximum(0.0, var) return mean, var - first_time = state.get(self.first_time_index) - state = state.set(self.first_time_index, jnp.array(False)) - + momentum = self.momentum batch_mean, batch_var = jax.vmap(_stats)(x) + zero_frac = state.get(self.zero_frac_index) running_mean, running_var = state.get(self.state_index) - momentum = self.momentum - running_mean = (1 - momentum) * batch_mean + momentum * running_mean - running_var = (1 - momentum) * batch_var + momentum * running_var - running_mean = lax.select(first_time, batch_mean, running_mean) - running_var = lax.select(first_time, batch_var, running_var) + + if self.approach == "ema": + zero_frac = zero_frac * momentum + running_mean = (1 - momentum) * batch_mean + momentum * running_mean + running_var = (1 - momentum) * batch_var + momentum * running_var + warmup_count = state.get(self.count_index) + warmup_count = jnp.minimum(warmup_count + 1, self.warmup_period) + state = state.set(self.count_index, warmup_count) + + # fill in unpopulated part of running stats with batch stats + warmup_frac = warmup_count / self.warmup_period + norm_mean = zero_frac * batch_mean + running_mean + norm_var = zero_frac * batch_var + running_var + + # apply warmup interpolation between batch and running statistics + norm_mean = (1.0 - warmup_frac) * batch_mean + warmup_frac * norm_mean + norm_var = (1.0 - warmup_frac) * batch_var + warmup_frac * norm_var + + elif self.approach == "ema_compatibility": + running_mean = (1 - momentum) * batch_mean + momentum * running_mean + running_var = (1 - momentum) * batch_var + momentum * running_var + running_mean = lax.select(zero_frac == 1.0, batch_mean, running_mean) + running_var = lax.select(zero_frac == 1.0, batch_var, running_var) + norm_mean, norm_var = running_mean, running_var + zero_frac = 0.0 * zero_frac + + else: + zero_frac = zero_frac * momentum + running_mean = (1 - momentum) * batch_mean + momentum * 
running_mean + # calculate unbiased variance for saving + axis_size = jax.lax.psum(jnp.array(1.0), self.axis_name) + debias_coef = axis_size / jnp.maximum(axis_size - 1, self.eps) + running_var = ( + 1 - momentum + ) * debias_coef * batch_var + momentum * running_var + + # just use batch statistics when not in inference mode + norm_mean, norm_var = batch_mean, batch_var + + state = state.set(self.zero_frac_index, zero_frac) state = state.set(self.state_index, (running_mean, running_var)) def _norm(y, m, v, w, b): @@ -172,5 +301,5 @@ def _norm(y, m, v, w, b): out = out * w + b return out - out = jax.vmap(_norm)(x, running_mean, running_var, self.weight, self.bias) + out = jax.vmap(_norm)(x, norm_mean, norm_var, self.weight, self.bias) return out, state diff --git a/tests/test_nn.py b/tests/test_nn.py index babdd07d..fb3704b0 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -124,7 +124,7 @@ def test_sequential(getkey): [ eqx.nn.Linear(2, 4, key=getkey()), eqx.nn.Linear(4, 1, key=getkey()), - eqx.nn.BatchNorm(1, axis_name="batch"), + eqx.nn.BatchNorm(1, axis_name="batch", approach="batch"), eqx.nn.Linear(1, 3, key=getkey()), ] ) @@ -158,7 +158,7 @@ def make(): inner_seq = eqx.nn.Sequential( [ eqx.nn.Linear(2, 4, key=getkey()), - eqx.nn.BatchNorm(4, axis_name="batch") + eqx.nn.BatchNorm(4, axis_name="batch", approach="batch") if inner_stateful else eqx.nn.Identity(), eqx.nn.Linear(4, 3, key=getkey()), @@ -168,7 +168,7 @@ def make(): [ eqx.nn.Linear(5, 2, key=getkey()), inner_seq, - eqx.nn.BatchNorm(3, axis_name="batch") + eqx.nn.BatchNorm(3, axis_name="batch", approach="batch") if outer_stateful else eqx.nn.Identity(), eqx.nn.Linear(3, 6, key=getkey()), @@ -825,18 +825,40 @@ def test_batch_norm(getkey): x2 = jrandom.uniform(getkey(), (10, 5, 6)) x3 = jrandom.uniform(getkey(), (10, 5, 7, 8)) - # Test that it works with a single vmap'd axis_name + # Test that it warns with no approach - defaulting to ema_compatibility + with pytest.warns(UserWarning): + bn = eqx.nn.BatchNorm(5, "batch") + assert bn.approach == "ema_compatibility" + + with pytest.raises(ValueError): + bn = eqx.nn.BatchNorm(5, "batch", approach="ema", warmup_period=0) - bn = eqx.nn.BatchNorm(5, "batch") + # Test initialization + bn_momentum = 0.99 + bn = eqx.nn.BatchNorm( + 5, "batch", approach="ema", warmup_period=10, momentum=bn_momentum + ) state = eqx.nn.State(bn) vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None)) + running_mean, running_var = state.get(bn.state_index) + zero_frac = state.get(bn.zero_frac_index) + warmup_count = state.get(bn.count_index) + assert jnp.array_equal(running_mean, jnp.zeros(running_mean.shape)) + assert jnp.array_equal(running_var, jnp.zeros(running_var.shape)) + assert jnp.array_equal(zero_frac, jnp.array(1.0)) + assert jnp.array_equal(warmup_count, jnp.array(0)) - for x in (x1, x2, x3): + # Test that it works with a single vmap'd axis_name + for i, x in enumerate([x1, x2, x3]): out, state = vbn(x, state) assert out.shape == x.shape running_mean, running_var = state.get(bn.state_index) + zero_frac = state.get(bn.zero_frac_index) + warmup_count = state.get(bn.count_index) assert running_mean.shape == (5,) assert running_var.shape == (5,) + assert jnp.array_equal(warmup_count, jnp.array(i + 1)) + assert jnp.allclose(zero_frac, jnp.array(bn_momentum ** (i + 1))) # Test that it fails without any vmap'd axis_name @@ -861,7 +883,7 @@ def test_batch_norm(getkey): # Test that it handles multiple axis_names - vvbn = eqx.nn.BatchNorm(6, ("batch1", "batch2")) + vvbn =
eqx.nn.BatchNorm(6, ("batch1", "batch2"), approach="ema") vvstate = eqx.nn.State(vvbn) for axis_name in ("batch1", "batch2"): vvbn = jax.vmap( @@ -873,10 +895,21 @@ def test_batch_norm(getkey): assert running_mean.shape == (6,) assert running_var.shape == (6,) - # Test that it normalises - + # Test that approach=ema normalises x1alt = jrandom.normal(jrandom.PRNGKey(5678), (10, 5)) # avoid flakey test - bn = eqx.nn.BatchNorm(5, "batch", channelwise_affine=False) + bn = eqx.nn.BatchNorm(5, "batch", channelwise_affine=False, approach="ema") + state = eqx.nn.State(bn) + vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None)) + out, state = vbn(x1alt, state) + true_out = (x1alt - jnp.mean(x1alt, axis=0, keepdims=True)) / jnp.sqrt( + jnp.var(x1alt, axis=0, keepdims=True) + 1e-5 + ) + assert jnp.allclose(out, true_out) + + # Test that approach=batch normalises in training mode + bn = eqx.nn.BatchNorm( + 5, "batch", channelwise_affine=False, approach="batch", momentum=0.9 + ) state = eqx.nn.State(bn) vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None)) out, state = vbn(x1alt, state) @@ -884,23 +917,37 @@ def test_batch_norm(getkey): jnp.var(x1alt, axis=0, keepdims=True) + 1e-5 ) assert jnp.allclose(out, true_out) + # Test that approach=batch normalises in inference mode + bn_inf = eqx.nn.inference_mode(bn, value=True) + vbn_inf = jax.vmap(bn_inf, axis_name="batch", in_axes=(0, None), out_axes=(0, None)) + out, state = vbn_inf(x1alt, state) + debias_coef = x1alt.shape[0] / (x1alt.shape[0] - 1) + true_out = (x1alt - jnp.mean(x1alt, axis=0, keepdims=True)) / jnp.sqrt( + debias_coef * jnp.var(x1alt, axis=0, keepdims=True) + 1e-5 + ) + assert jnp.allclose(out, true_out) # Test that the statistics update during training out, state = vbn(x1, state) running_mean, running_var = state.get(bn.state_index) out, state = vbn(3 * x1 + 10, state) running_mean2, running_var2 = state.get(bn.state_index) + zero_frac2 = state.get(bn.zero_frac_index) + warmup_count2 = state.get(bn.count_index) assert not jnp.allclose(running_mean, running_mean2) assert not jnp.allclose(running_var, running_var2) # Test that the statistics don't update at inference - ibn = eqx.nn.inference_mode(bn, value=True) vibn = jax.vmap(ibn, axis_name="batch", in_axes=(0, None), out_axes=(0, None)) out, state = vibn(4 * x1 + 20, state) running_mean3, running_var3 = state.get(bn.state_index) + zero_frac3 = state.get(bn.zero_frac_index) + warmup_count3 = state.get(bn.count_index) assert jnp.array_equal(running_mean2, running_mean3) assert jnp.array_equal(running_var2, running_var3) + assert jnp.array_equal(zero_frac2, zero_frac3) + assert jnp.array_equal(warmup_count2, warmup_count3) # Test that we can differentiate through it diff --git a/tests/test_stateful.py b/tests/test_stateful.py index d7bc632a..de75319b 100644 --- a/tests/test_stateful.py +++ b/tests/test_stateful.py @@ -7,7 +7,7 @@ def test_delete_init_state(): - model = eqx.nn.BatchNorm(3, "batch") + model = eqx.nn.BatchNorm(3, "batch", approach="batch") eqx.nn.State(model) model2 = eqx.nn.delete_init_state(model) @@ -17,7 +17,7 @@ def test_delete_init_state(): leaves = [x for x in jtu.tree_leaves(model) if eqx.is_array(x)] leaves2 = [x for x in jtu.tree_leaves(model2) if eqx.is_array(x)] - assert len(leaves) == len(leaves2) + 3 + assert len(leaves) == len(leaves2) + 4 def test_double_state():
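A minimal usage sketch of the new `approach` argument, following the pattern used in the tests above; the input size, shapes, PRNG key, and momentum value are illustrative.

```python
import equinox as eqx
import jax
import jax.random as jrandom

# BatchNorm with the new "batch" approach: batch statistics during training,
# debiased running statistics during inference.
bn = eqx.nn.BatchNorm(5, axis_name="batch", approach="batch", momentum=0.9)
state = eqx.nn.State(bn)

# The layer must be vmap'd over the named batch axis; the state is threaded
# through as an extra input/output.
vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))

x = jrandom.normal(jrandom.PRNGKey(0), (10, 5))
out, state = vbn(x, state)  # training mode: normalises with batch statistics

# Switch to inference mode to use the accumulated running statistics.
bn_inf = eqx.nn.inference_mode(bn, value=True)
vbn_inf = jax.vmap(bn_inf, axis_name="batch", in_axes=(0, None), out_axes=(0, None))
out_inf, _ = vbn_inf(x, state)
```

As in the tests, the state must be threaded through every call, and the layer has to be wrapped in `jax.vmap` (or `jax.pmap`) with a matching `axis_name`.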
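The inference-time renormalisation above divides the running statistics by `max(1 - zero_frac, eps)`, which is the usual debiasing of a zero-initialised exponential moving average. A minimal sketch, assuming a constant per-channel batch statistic, showing that the debiased EMA recovers it:

```python
import jax.numpy as jnp

m = 0.99         # momentum
s = 3.0          # a constant per-channel batch statistic (illustrative)
steps = 10

running = 0.0    # running statistic, zero-initialised as in __init__
zero_frac = 1.0  # fraction of the EMA still attributable to the zero init

for _ in range(steps):
    running = (1 - m) * s + m * running  # EMA update
    zero_frac = zero_frac * m            # zero_frac becomes m**steps

# The raw EMA is still far from s (~0.29 here), but debiasing recovers ~3.0.
debiased = running / jnp.maximum(1.0 - zero_frac, 1e-5)
print(float(debiased))
```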