diff --git a/docs/api/nn/normalisation.md b/docs/api/nn/normalisation.md
index 80d27b44..0105f2bd 100644
--- a/docs/api/nn/normalisation.md
+++ b/docs/api/nn/normalisation.md
@@ -29,3 +29,11 @@
         members:
             - __init__
             - __call__
+
+---
+
+::: equinox.nn.WeightNorm
+    selection:
+        members:
+            - __init__
+            - __call__
diff --git a/equinox/nn/__init__.py b/equinox/nn/__init__.py
index a5fb3a3d..d094f0be 100644
--- a/equinox/nn/__init__.py
+++ b/equinox/nn/__init__.py
@@ -47,3 +47,4 @@
     State as State,
     StateIndex as StateIndex,
 )
+from ._weight_norm import WeightNorm as WeightNorm
diff --git a/equinox/nn/_weight_norm.py b/equinox/nn/_weight_norm.py
new file mode 100644
index 00000000..6ec324ec
--- /dev/null
+++ b/equinox/nn/_weight_norm.py
@@ -0,0 +1,100 @@
+from typing import Generic, Optional, TypeVar, Union
+
+import jax
+import jax.numpy as jnp
+from jaxtyping import Array, PRNGKeyArray
+
+from .._module import field, Module
+from .._tree import tree_at
+
+
+_Layer = TypeVar("_Layer")
+
+
+def _norm_except_axis(
+    v: Array, pow: Optional[Union[int, str]] = None, axis: Optional[int] = 0
+) -> Array:
+    for ax in range(len(v.shape)):
+        if ax != axis:
+            v = jnp.linalg.norm(v, ord=pow, axis=ax, keepdims=True)
+    return v if axis is not None else v.reshape([])
+
+
+class WeightNorm(Module, Generic[_Layer]):
+    r"""
+    Applies weight normalisation to a given parameter.
+
+    Given a weight matrix $\mathbf{W}$, computes the following reparametrization:
+
+    $\mathbf{W} = g \frac{\mathbf{v}}{\lVert \mathbf{v} \rVert}$
+
+    where $g$ is initially chosen to equal $\lVert \mathbf{v} \rVert$.
+
+    ??? cite
+        [Weight Normalisation](https://arxiv.org/abs/1602.07868)
+
+        ```bibtex
+        @article{DBLP:journals/corr/SalimansK16,
+            author    = {Tim Salimans and
+                         Diederik P. Kingma},
+            title     = {Weight Normalization: {A} Simple
+                         Reparameterization to Accelerate
+                         Training of Deep Neural Networks},
+            journal   = {CoRR},
+            volume    = {abs/1602.07868},
+            year      = {2016},
+            url       = {http://arxiv.org/abs/1602.07868},
+            eprinttype = {arXiv},
+            eprint    = {1602.07868},
+            timestamp = {Mon, 13 Aug 2018 16:47:07 +0200},
+            biburl    = {https://dblp.org/rec/journals/corr/SalimansK16.bib},
+            bibsource = {dblp computer science bibliography, https://dblp.org}
+        }
+        ```
+
+
+    """
+
+    layer: _Layer
+    v: Array
+    g: Array
+    weight_name: str = field(static=True)
+    axis: Optional[int] = field(static=True)
+
+    def __init__(
+        self,
+        layer: _Layer,
+        weight_name: str = "weight",
+        axis: Optional[int] = 0,
+    ):
+        """**Arguments:**
+
+        - `layer`: The layer to wrap. Usually a [`equinox.nn.Linear`][] or
+            a convolutional layer (e.g. [`equinox.nn.Conv2d`][]).
+        - `weight_name`: The name of the layer's parameter (a JAX array) to apply
+            weight normalisation to.
+        - `axis`: The norm is computed across every axis except this one.
+            If `None`, compute across every axis.
+        """
+        self.layer = layer
+        self.weight_name = weight_name
+        self.axis = axis
+
+        self.v = getattr(layer, weight_name)
+        self.g = _norm_except_axis(self.v, axis=axis)
+
+    @jax.named_scope("eqx.nn.WeightNorm")
+    def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
+        """**Arguments:**
+
+        - `x`: A JAX array.
+        - `key`: Ignored; provided for compatibility with the rest of the Equinox API.
+
+        **Returns:**
+
+        - The JAX array from calling `self.layer(x)` (with weight normalisation
+            applied).
+ """ + weight = self.v * self.g / _norm_except_axis(self.v, axis=self.axis) + layer = tree_at(lambda l: getattr(l, self.weight_name), self.layer, weight) + return layer(x) diff --git a/tests/test_nn.py b/tests/test_nn.py index 7a88720f..5289a104 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -957,6 +957,47 @@ def λ1(): assert out.shape == (4, 6, 6, 6) +def test_weight_norm(getkey): + # Linear + linear = eqx.nn.Linear(4, 4, key=getkey()) + weight_norm_linear = eqx.nn.WeightNorm(layer=linear, weight_name="weight") + + x = jrandom.normal(getkey(), (4,)) + out_weight_norm = weight_norm_linear(x) + out_linear = linear(x) + + assert jnp.allclose(out_weight_norm, out_linear) + + # Axis == None + linear = eqx.nn.Linear(4, 4, key=getkey()) + weight_norm_linear = eqx.nn.WeightNorm( + layer=linear, weight_name="weight", axis=None + ) + + x = jrandom.normal(getkey(), (4,)) + out_weight_norm = weight_norm_linear(x) + out_linear = linear(x) + + assert jnp.allclose(out_weight_norm, out_linear) + + # Conv3d (ndim weight matrices > 2) + conv = eqx.nn.Conv3d(2, 3, 3, key=getkey()) + weight_norm_conv = eqx.nn.WeightNorm(layer=conv, weight_name="weight") + x = jrandom.normal(getkey(), (2, 3, 3, 3)) + out_weight_norm = weight_norm_conv(x) + out_conv = conv(x) + + assert jnp.allclose(out_weight_norm, out_conv) + + # Grads get generated for reparametrized weights, not original + grads = eqx.filter_grad(lambda model, x: jnp.mean(model(x)))( + weight_norm_linear, jrandom.normal(getkey(), (4,)) + ) + assert jnp.count_nonzero(grads.layer.weight) == 0 + assert jnp.any(grads.v) + assert jnp.any(grads.g) + + def test_maxpool1d(): x = jnp.arange(14).reshape(1, 14) max_pool = eqx.nn.MaxPool1d(2, 3)