Add weight normalization
boris committed Dec 19, 2023
1 parent 1e2d8c2 commit fdd9932
Showing 4 changed files with 150 additions and 0 deletions.
8 changes: 8 additions & 0 deletions docs/api/nn/normalisation.md
@@ -29,3 +29,11 @@
        members:
            - __init__
            - __call__

---

::: equinox.nn.WeightNorm
    selection:
        members:
            - __init__
            - __call__
1 change: 1 addition & 0 deletions equinox/nn/__init__.py
@@ -47,3 +47,4 @@
    State as State,
    StateIndex as StateIndex,
)
from ._weight_norm import WeightNorm as WeightNorm
100 changes: 100 additions & 0 deletions equinox/nn/_weight_norm.py
@@ -0,0 +1,100 @@
from typing import Generic, Optional, TypeVar, Union

import jax
import jax.numpy as jnp
from jaxtyping import Array, PRNGKeyArray

from .._module import field, Module
from .._tree import tree_at


_Layer = TypeVar("_Layer")


def _norm_except_axis(
    v: Array, pow: Optional[Union[int, str]] = None, axis: Optional[int] = 0
) -> Array:
    # Reduce over every axis except `axis`, keeping dims so that `axis` remains
    # a valid index on each pass; if `axis` is None, reduce to a scalar.
    for ax in range(len(v.shape)):
        if ax != axis:
            v = jnp.linalg.norm(v, ord=pow, axis=ax, keepdims=True)
    return v if axis is not None else v.reshape([])


class WeightNorm(Module, Generic[_Layer]):
    r"""
    Applies weight normalisation to a given parameter.

    Given a weight matrix $\mathbf{W}$, computes the following reparameterisation:

    $\mathbf{W} = g \frac{\mathbf{v}}{\lVert \mathbf{v} \rVert}$

    where $g$ is initially chosen to equal $\lVert \mathbf{v} \rVert$.

    ??? cite

        [Weight Normalisation](https://arxiv.org/abs/1602.07868)

        ```bibtex
        @article{DBLP:journals/corr/SalimansK16,
            author     = {Tim Salimans and Diederik P. Kingma},
            title      = {Weight Normalization: {A} Simple Reparameterization
                          to Accelerate Training of Deep Neural Networks},
            journal    = {CoRR},
            volume     = {abs/1602.07868},
            year       = {2016},
            url        = {http://arxiv.org/abs/1602.07868},
            eprinttype = {arXiv},
            eprint     = {1602.07868},
            timestamp  = {Mon, 13 Aug 2018 16:47:07 +0200},
            biburl     = {https://dblp.org/rec/journals/corr/SalimansK16.bib},
            bibsource  = {dblp computer science bibliography, https://dblp.org}
        }
        ```
    """

    layer: _Layer
    v: Array
    g: Array
    weight_name: str = field(static=True)
    axis: Optional[int] = field(static=True)

    def __init__(
        self,
        layer: _Layer,
        weight_name: str = "weight",
        axis: Optional[int] = 0,
    ):
        """**Arguments:**

        - `layer`: The layer to wrap. Usually a [`equinox.nn.Linear`][] or
            a convolutional layer (e.g. [`equinox.nn.Conv2d`][]).
        - `weight_name`: The name of the layer's parameter (a JAX array) to apply
            weight normalisation to.
        - `axis`: The norm is computed across every axis except this one.
            If `None`, compute across every axis.
        """
        self.layer = layer
        self.weight_name = weight_name
        self.axis = axis

        self.v = getattr(layer, weight_name)
        self.g = _norm_except_axis(self.v, axis=axis)

    @jax.named_scope("eqx.nn.WeightNorm")
    def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
        """**Arguments:**

        - `x`: A JAX array.
        - `key`: Ignored; provided for compatibility with the rest of the Equinox API.

        **Returns:**

        - The JAX array from calling `self.layer(x)` (with weight normalisation
            applied).
        """
        weight = self.v * self.g / _norm_except_axis(self.v, axis=self.axis)
        layer = tree_at(lambda l: getattr(l, self.weight_name), self.layer, weight)
        return layer(x)
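As a quick orientation for the layer added above, here is a minimal usage sketch. It only uses names visible in this diff (`eqx.nn.WeightNorm`, `eqx.nn.Linear`, `weight_name`, `axis`); the shapes and the `allclose` checks mirror the tests below, and the comments restate the reparameterisation from the docstring rather than describing any additional behaviour.

```python
import equinox as eqx
import jax.numpy as jnp
import jax.random as jrandom

key = jrandom.PRNGKey(0)
linear = eqx.nn.Linear(4, 4, key=key)

# Wrap the layer: v is taken from `linear.weight`, and g = ||v|| is computed
# over every axis except `axis=0`, so g has one entry per output feature.
wn = eqx.nn.WeightNorm(layer=linear, weight_name="weight", axis=0)

# At call time the effective weight is g * v / ||v||. Since g was initialised
# to ||v||, the wrapped layer initially matches the unwrapped one exactly.
x = jrandom.normal(jrandom.PRNGKey(1), (4,))
assert jnp.allclose(wn(x), linear(x))

# With axis=None the norm is a single scalar over the whole weight array;
# the identity at initialisation still holds.
wn_scalar = eqx.nn.WeightNorm(layer=linear, weight_name="weight", axis=None)
assert jnp.allclose(wn_scalar(x), linear(x))
```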
41 changes: 41 additions & 0 deletions tests/test_nn.py
@@ -957,6 +957,47 @@ def λ1():
    assert out.shape == (4, 6, 6, 6)


def test_weight_norm(getkey):
    # Linear
    linear = eqx.nn.Linear(4, 4, key=getkey())
    weight_norm_linear = eqx.nn.WeightNorm(layer=linear, weight_name="weight")

    x = jrandom.normal(getkey(), (4,))
    out_weight_norm = weight_norm_linear(x)
    out_linear = linear(x)

    assert jnp.allclose(out_weight_norm, out_linear)

    # Axis == None
    linear = eqx.nn.Linear(4, 4, key=getkey())
    weight_norm_linear = eqx.nn.WeightNorm(
        layer=linear, weight_name="weight", axis=None
    )

    x = jrandom.normal(getkey(), (4,))
    out_weight_norm = weight_norm_linear(x)
    out_linear = linear(x)

    assert jnp.allclose(out_weight_norm, out_linear)

    # Conv3d (ndim weight matrices > 2)
    conv = eqx.nn.Conv3d(2, 3, 3, key=getkey())
    weight_norm_conv = eqx.nn.WeightNorm(layer=conv, weight_name="weight")
    x = jrandom.normal(getkey(), (2, 3, 3, 3))
    out_weight_norm = weight_norm_conv(x)
    out_conv = conv(x)

    assert jnp.allclose(out_weight_norm, out_conv)

    # Grads get generated for reparametrized weights, not original
    grads = eqx.filter_grad(lambda model, x: jnp.mean(model(x)))(
        weight_norm_linear, jrandom.normal(getkey(), (4,))
    )
    assert jnp.count_nonzero(grads.layer.weight) == 0
    assert jnp.any(grads.v)
    assert jnp.any(grads.g)


def test_maxpool1d():
    x = jnp.arange(14).reshape(1, 14)
    max_pool = eqx.nn.MaxPool1d(2, 3)