-
-
Notifications
You must be signed in to change notification settings - Fork 150
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
boris
committed
Dec 19, 2023
1 parent
1e2d8c2
commit fdd9932
Showing
4 changed files
with
150 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,3 +29,11 @@ | |
members: | ||
- __init__ | ||
- __call__ | ||
|
||
--- | ||
|
||
::: equinox.nn.WeightNorm | ||
selection: | ||
members: | ||
- __init__ | ||
- __call__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,3 +47,4 @@ | |
State as State, | ||
StateIndex as StateIndex, | ||
) | ||
from ._weight_norm import WeightNorm as WeightNorm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
from typing import Generic, Optional, TypeVar, Union | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from jaxtyping import Array, PRNGKeyArray | ||
|
||
from .._module import field, Module | ||
from .._tree import tree_at | ||
|
||
|
||
# Type variable for the wrapped layer, so that `WeightNorm[_Layer]` preserves
# the concrete layer type for static type checkers.
_Layer = TypeVar("_Layer")
|
||
|
||
def _norm_except_axis( | ||
v: Array, pow: Optional[Union[int, str]] = None, axis: Optional[int] = 0 | ||
) -> Array: | ||
for ax in range(len(v.shape)): | ||
if ax != axis: | ||
v = jnp.linalg.norm(v, ord=pow, axis=ax, keepdims=True) | ||
return v if axis is not None else v.reshape([]) | ||
|
||
|
||
class WeightNorm(Module, Generic[_Layer]):
    r"""
    Applies weight normalisation to a given parameter.

    Given a weight matrix $\mathbf{W}$, computes the following reparametrization:

    $\mathbf{W} = g \frac{\mathbf{v}}{\lVert \mathbf{v} \rVert}$

    where $g$ is initially chosen to equal $\lVert \mathbf{v} \rVert$.

    ??? cite

        [Weight Normalisation](https://arxiv.org/abs/1602.07868)

        ```bibtex
        @article{DBLP:journals/corr/SalimansK16,
          author       = {Tim Salimans and
                          Diederik P. Kingma},
          title        = {Weight Normalisation: {A} Simple
                          Reparameterization to Accelerate
                          Training of Deep Neural Networks},
          journal      = {CoRR},
          volume       = {abs/1602.07868},
          year         = {2016},
          url          = {http://arxiv.org/abs/1602.07868},
          eprinttype   = {arXiv},
          eprint       = {1602.07868},
          timestamp    = {Mon, 13 Aug 2018 16:47:07 +0200},
          biburl       = {https://dblp.org/rec/journals/corr/SalimansK16.bib},
          bibsource    = {dblp computer science bibliography, https://dblp.org}
        }
        ```
    """

    # The wrapped layer; its weight attribute is replaced (out-of-place) with
    # the reparameterised weight on every call.
    layer: _Layer
    # Direction parameter: initialised to the wrapped layer's original weight.
    v: Array
    # Magnitude parameter: initialised to the norm of `v`.
    g: Array
    # Name of the attribute on `layer` that holds the weight being normalised.
    weight_name: str = field(static=True)
    # Axis excluded from the norm computation; `None` means norm over all axes.
    axis: Optional[int] = field(static=True)

    def __init__(
        self,
        layer: _Layer,
        weight_name: str = "weight",
        axis: Optional[int] = 0,
    ):
        """**Arguments:**

        - `layer`: The layer to wrap. Usually a [`equinox.nn.Linear`][] or
            a convolutional layer (e.g. [`equinox.nn.Conv2d`][]).
        - `weight_name`: The name of the layer's parameter (a JAX array) to apply
            weight normalisation to.
        - `axis`: The norm is computed across every axis except this one.
            If `None`, compute across every axis.
        """
        self.layer = layer
        self.weight_name = weight_name
        self.axis = axis

        # Split the wrapped weight into a direction (`v`, the raw weight) and a
        # magnitude (`g`, its norm), so that the two can be trained
        # independently. At initialisation `g * v / ||v|| == v`, i.e. the
        # wrapped layer's behaviour is unchanged.
        self.v = getattr(layer, weight_name)
        self.g = _norm_except_axis(self.v, axis=axis)

    @jax.named_scope("eqx.nn.WeightNorm")
    def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
        """**Arguments:**

        - `x`: A JAX Array.
        - `key`: Ignored; provided for compatibility with the rest of the Equinox API.

        **Returns:**

        - The JAX array from calling `self.layer(x)` (with weight normalisation
            applied).
        """
        # Reconstruct the weight as g * v / ||v||. The norm is recomputed from
        # the current value of `v` on every call, so gradients flow through the
        # normalisation.
        weight = self.v * self.g / _norm_except_axis(self.v, axis=self.axis)
        # `tree_at` builds an updated copy of the wrapped layer with the
        # reparameterised weight swapped in; `self.layer` itself is not mutated.
        layer = tree_at(lambda l: getattr(l, self.weight_name), self.layer, weight)
        return layer(x)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters