docs and remove dummy for simplicity
danielward27 committed Oct 21, 2024
1 parent b3893ba commit 1bec801
Showing 6 changed files with 116 additions and 118 deletions.
54 changes: 32 additions & 22 deletions README.md
@@ -4,6 +4,26 @@ Paramax
Parameterizations and constraints for JAX PyTrees
-----------------------------------------------------------------------

Paramax allows applying custom constraints or behaviors to PyTree components,
using unwrappable placeholders. This can be used for
- Enforcing positivity (e.g., scale parameters)
- Structured matrices (triangular, symmetric, etc.)
- Applying tricks like weight normalization
- Marking components as non-trainable

Some benefits of the unwrappable pattern:
- It allows parameterizations to be computed once for a model (e.g. at the top of the
loss function).
- It is flexible, e.g. allowing custom parameterizations to be applied to PyTrees
from external libraries
- It is concise

If you find the package useful, please consider giving it a star on GitHub, and if you
create ``AbstractUnwrappable``s that may be of interest to others, a pull request would
be much appreciated!

## Documentation

Documentation available [here](https://danielward27.github.io/paramax/).

## Installation
@@ -13,32 +33,22 @@ pip install paramax

## Example
```python
>>> from paramax.wrappers import Parameterize, unwrap
>>> import paramax
>>> import jax.numpy as jnp
>>> params = Parameterize(jnp.exp, jnp.zeros(3))
>>> unwrap(("abc", 1, params))
>>> scale = paramax.Parameterize(jnp.exp, jnp.log(jnp.ones(3))) # Enforce positivity
>>> paramax.unwrap(("abc", 1, scale))
('abc', 1, Array([1., 1., 1.], dtype=float32))
```

## Why use Paramax?
Paramax allows applying custom constraints or behaviors to PyTree components, such as:
- Enforcing positivity (e.g., scale parameters)
- Structured matrices (triangular, symmetric, etc.)
- Applying tricks like weight normalization
- Marking components as non-trainable

Some benefits of the pattern we use:
- It allows parameterizations to be computed once for a model (e.g. at the top of the loss function).
- It is concise, flexible, and allows custom parameterizations to be used with PyTrees from external libraries.

## Alternative patterns
Using properties to access parameterized components is common but has drawbacks:
## Alternative parameterization patterns
Using properties to access parameterized model components is common but has drawbacks:
- Parameterizations are tied to class definition, limiting flexibility e.g. this
cannot be used on PyTrees from external libraries.
- It can become verbose with many parameters.
- It often leads to repeatedly computing the parameterization.

cannot be used on PyTrees from external libraries
- It can become verbose with many parameters
- It often leads to repeatedly computing the parameterization
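
As a rough illustration of the last drawback, the plain-Python sketch below counts how often a constraint is evaluated under each pattern. The names `PropertyModel`, `positive`, and `calls` are hypothetical and are not part of Paramax; `math.exp` stands in for a JAX transform, and the count applies to eager execution only.

```python
import math

calls = {"exp": 0}  # hypothetical counter, for illustration only


def positive(raw):
    """Map an unconstrained value to a positive one, counting each evaluation."""
    calls["exp"] += 1
    return math.exp(raw)


class PropertyModel:
    """Property pattern: the constraint is tied to the class definition."""

    def __init__(self, raw_scale):
        self.raw_scale = raw_scale

    @property
    def scale(self):
        return positive(self.raw_scale)  # recomputed on every access


def loss_with_property(model, xs):
    return sum((x / model.scale) ** 2 for x in xs)


def loss_unwrap_once(raw_scale, xs):
    scale = positive(raw_scale)  # constraint applied once, at the top
    return sum((x / scale) ** 2 for x in xs)


xs = [1.0, 2.0, 3.0]

calls["exp"] = 0
loss_a = loss_with_property(PropertyModel(0.0), xs)
property_calls = calls["exp"]

calls["exp"] = 0
loss_b = loss_unwrap_once(0.0, xs)
unwrap_calls = calls["exp"]
```

Both losses are equal, but the property version evaluates the constraint once per access, whereas applying it at the top of the function evaluates it once.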

## Related
We make use of the [Equinox](https://arxiv.org/abs/2111.00254) package, to register
the PyTrees used in the package.
- We make use of the [Equinox](https://arxiv.org/abs/2111.00254) package to register
the PyTrees used in the package
- This package grew out of a need for a simple method to apply parameter constraints
in the distributions package [flowjax](https://github.com/danielward27/flowjax)
4 changes: 3 additions & 1 deletion docs/_static/icon.svg
26 changes: 15 additions & 11 deletions docs/index.rst
@@ -1,7 +1,8 @@
Paramax
===========

Paramax: a small package for applying parameterizations, constraints to JAX PyTrees.
A small package for applying parameterizations and constraints to nodes in JAX
PyTrees.


Installation
@@ -11,16 +12,18 @@ Installation
pip install paramax
Simple example
How it works
------------------
The most common way to apply parameterizations is via
:py:class:`~paramax.wrappers.Parameterize`. This class takes a callable and any
- :py:class:`~paramax.wrappers.AbstractUnwrappable` objects act as placeholders in the
PyTree, defining the parameterizations.
- :py:func:`~paramax.wrappers.unwrap` applies the parameterizations, replacing the
:py:class:`~paramax.wrappers.AbstractUnwrappable` objects.

A simple example of an :py:class:`~paramax.wrappers.AbstractUnwrappable`
is :py:class:`~paramax.wrappers.Parameterize`. This class takes a callable and any
positional or keyword arguments, which are stored and passed to the function when
unwrapping.

When :py:func:`~paramax.wrappers.unwrap` is called on a PyTree containing a
:py:class:`~paramax.wrappers.Parameterize` object, the stored function is applied
using the stored arguments.

.. doctest::

@@ -33,9 +36,8 @@ using the stored arguments.
('abc', 1, Array([1., 1., 1.], dtype=float32))


Many simple parameterizations can be handled with this class. As another example,
we can parameterize a lower triangular matrix (such that it remains lower triangular
during optimization) as follows
Many simple parameterizations can be handled with this class, for example,
we can parameterize a lower triangular matrix using

.. doctest::

@@ -45,14 +47,16 @@ during optimization) as follows
>>> tril = paramax.Parameterize(jnp.tril, tril)


See :doc:`/api/wrappers` for more possibilities.
See :doc:`/api/wrappers` for more :py:class:`~paramax.wrappers.AbstractUnwrappable`
objects.

When to unwrap
-------------------
- Unwrap whenever necessary, typically at the top of loss functions, or of any other
function or method that requires the parameterizations to have been applied.
- Unwrapping prior to a gradient computation used for optimization is usually a mistake!
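
The second point can be illustrated without Paramax itself. The sketch below is a hand-rolled analogue: `constrain` is a hypothetical stand-in for unwrapping a model whose scale is stored as a log value, and the loss and step size are chosen only to make the effect visible. Differentiating through the constraint keeps the scale positive after an update, whereas constraining first and then updating the constrained values can leave the valid region.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])


def constrain(params):
    # Hypothetical stand-in for unwrapping: exp maps log_scale to a positive scale.
    return {"scale": jnp.exp(params["log_scale"])}


def loss_on_constrained(model):
    return jnp.mean((model["scale"] * x) ** 2)


def loss(params):
    # The constraint is applied inside the differentiated function.
    return loss_on_constrained(constrain(params))


params = {"log_scale": jnp.zeros(3)}
step_size = 1.0  # deliberately large, to exaggerate the effect

# Gradients with respect to the unconstrained parameters.
grads = jax.grad(loss)(params)
updated = jax.tree_util.tree_map(lambda p, g: p - step_size * g, params, grads)
good_scale = constrain(updated)["scale"]  # still positive after the update

# Usually a mistake: constrain first, then update the constrained values directly.
model = constrain(params)
model_grads = jax.grad(loss_on_constrained)(model)
bad_scale = model["scale"] - step_size * model_grads["scale"]  # can go negative
```

With these inputs every entry of `good_scale` remains positive, while some entries of `bad_scale` become negative.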


.. toctree::
:caption: API
:maxdepth: 1
4 changes: 2 additions & 2 deletions paramax/utils.py
@@ -11,7 +11,7 @@ def inv_softplus(x: ArrayLike) -> Array:
x,
x < 0,
"Expected positive inputs to inv_softplus. If you are trying to use a negative "
"scale parameter, consider constructing with positive scales and modifying the "
"scale attribute post-construction, e.g., using eqx.tree_at.",
"scale parameter, you may be able to construct with positive scales, and "
"modify the scale attribute post-construction, e.g., using eqx.tree_at.",
)
return jnp.log(-jnp.expm1(-x)) + x
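
The returned expression follows from solving x = log(1 + exp(y)) for y: y = log(exp(x) - 1) = x + log(1 - exp(-x)), where `-expm1(-x)` evaluates 1 - exp(-x) accurately for small x. The scalar sketch below uses only the standard library to check the round trip. It is not the package's code, the function names are reused purely for illustration, and it raises on non-positive input instead of using the array-based check shown above.

```python
import math


def softplus(y: float) -> float:
    """Scalar softplus, log(1 + exp(y))."""
    return math.log1p(math.exp(y))


def inv_softplus(x: float) -> float:
    """Scalar inverse of softplus for positive x."""
    if x <= 0:
        raise ValueError("Expected positive inputs to inv_softplus.")
    # x + log(1 - exp(-x)), written with expm1 for accuracy near zero
    return math.log(-math.expm1(-x)) + x


round_trips = [softplus(inv_softplus(x)) for x in (1e-3, 1.0, 20.0)]
```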
143 changes: 63 additions & 80 deletions paramax/wrappers.py
@@ -1,22 +1,11 @@
""":class:`AbstractUnwrappable` objects and utilities.
These are "placeholder" values for specifying custom behaviour for nodes in a pytree.
Many of these facilitate similar functions to pytorch parameterizations. We use this
for example to apply parameter constraints, masking of parameters etc. To apply the
behaviour, we use :func:`unwrap`, which will replace any :class:`AbstractUnwrappable`
nodes in a pytree with the unwrapped versions.
.. note::
If creating a custom unwrappable, remember that unwrapping will generally occur
after initialization of the model. Because of this, we recommend ensuring that
the ``unwrap`` method supports unwrapping if the model is constructed in a
vectorized context, such as ``eqx.filter_vmap``, e.g. through broadcasting or
vectorization.
These are placeholder values for specifying custom behaviour for nodes in a pytree,
applied using :func:`unwrap`.
"""

from abc import abstractmethod
from collections.abc import Callable, Iterable
from collections.abc import Callable
from typing import Any, Generic, TypeVar

import equinox as eqx
@@ -25,15 +14,27 @@
from jax import lax
from jax.nn import softplus
from jax.tree_util import tree_leaves
from jaxtyping import Array, Int, PyTree, Scalar
from jaxtyping import Array, PyTree

from paramax.utils import inv_softplus

T = TypeVar("T")


class AbstractUnwrappable(eqx.Module, Generic[T]):
"""An abstract class representing an unwrappable object.
Unwrappables replace PyTree nodes, applying custom behavior upon unwrapping.
"""

@abstractmethod
def unwrap(self) -> T:
"""Returns the unwrapped pytree, assuming no wrapped subnodes exist."""
pass


def unwrap(tree: PyTree):
"""Recursively unwraps all :class:`AbstractUnwrappable` nodes within a pytree.
"""Map across a PyTree and unwrap all :class:`AbstractUnwrappable` nodes.
This leaves all other nodes unchanged. If nested, the innermost
``AbstractUnwrappable`` nodes are unwrapped first.
@@ -43,47 +44,40 @@ def unwrap(tree: PyTree):
.. doctest::
>>> from paramax.wrappers import Parameterize, unwrap
>>> import paramax
>>> import jax.numpy as jnp
>>> params = Parameterize(jnp.exp, jnp.zeros(3))
>>> unwrap(("abc", 1, params))
>>> params = paramax.Parameterize(jnp.exp, jnp.zeros(3))
>>> paramax.unwrap(("abc", 1, params))
('abc', 1, Array([1., 1., 1.], dtype=float32))
"""

def _map_fn(leaf):
if isinstance(leaf, AbstractUnwrappable):
# Flatten to ignore until all contained AbstractUnwrappables are unwrapped
flat, tree_def = eqx.tree_flatten_one_level(leaf)
tree = jax.tree_util.tree_unflatten(tree_def, unwrap(flat))
return tree.unwrap()
return leaf

return jax.tree_util.tree_map(
f=_map_fn,
tree=tree,
is_leaf=lambda x: isinstance(x, AbstractUnwrappable),
)

def _unwrap(tree, *, include_self: bool):
def _map_fn(leaf):
if isinstance(leaf, AbstractUnwrappable):
# Unwrap subnodes, then itself
return _unwrap(leaf, include_self=False).unwrap()
return leaf

class AbstractUnwrappable(eqx.Module, Generic[T]):
"""An abstract class representing an unwrappable object.
def is_leaf(x):
is_unwrappable = isinstance(x, AbstractUnwrappable)
included = include_self or x is not tree
return is_unwrappable and included

Unwrappables replace PyTree nodes, applying custom behavior upon unwrapping.
"""
return jax.tree_util.tree_map(f=_map_fn, tree=tree, is_leaf=is_leaf)

@abstractmethod
def unwrap(self) -> T:
"""Returns the unwrapped pytree, assuming no wrapped subnodes exist."""
pass
return _unwrap(tree, include_self=True)
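
The innermost-first ordering can be sketched in plain Python, without JAX or Equinox. This is a simplified analogue and not the implementation above: `Wrapped` is a hypothetical stand-in for `Parameterize`, and only tuples, lists, and dicts are traversed.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Wrapped:
    """Hypothetical placeholder that stores a function and its arguments."""

    fn: Callable[..., Any]
    args: tuple

    def unwrap(self):
        # Assumes args no longer contain any wrapped nodes.
        return self.fn(*self.args)


def unwrap(tree):
    """Replace every Wrapped node, unwrapping nested nodes innermost first."""
    if isinstance(tree, Wrapped):
        # Unwrap the subnodes, then the node itself.
        inner = Wrapped(tree.fn, tuple(unwrap(a) for a in tree.args))
        return inner.unwrap()
    if isinstance(tree, tuple):
        return tuple(unwrap(t) for t in tree)
    if isinstance(tree, list):
        return [unwrap(t) for t in tree]
    if isinstance(tree, dict):
        return {k: unwrap(v) for k, v in tree.items()}
    return tree  # all other nodes are left unchanged


def square(x):
    return x * x


nested = Wrapped(square, (Wrapped(square, (Wrapped(square, (2,)),)),))
result = unwrap(("abc", 1, nested))
```

Here the innermost node is squared first, so 2 becomes 4, then 16, then 256, and the non-wrapped entries pass through untouched.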


class Parameterize(AbstractUnwrappable[T]):
"""Unwrap an object by calling fn with args and kwargs.
All of fn, args and kwargs may contain trainable parameters. If the Parameterize is
created within ``eqx.filter_vmap``, unwrapping is automatically vectorized
correctly, as long as the vmapped constructor adds leading batch
dimensions to all arrays (the default for ``eqx.filter_vmap``).
All of fn, args and kwargs may contain trainable parameters.
.. note::
Unwrapping typically occurs after model initialization. Therefore, if the
``Parameterize`` object may be created in a vectorized context, we recommend
ensuring that ``fn`` still unwraps correctly, e.g. by supporting broadcasting.
Example:
.. doctest::
@@ -101,42 +95,16 @@ class Parameterize(AbstractUnwrappable[T]):
"""

fn: Callable[..., T]
args: Iterable
args: tuple[Any, ...]
kwargs: dict[str, Any]
_dummy: Int[Scalar, ""] # Used to track vectorized construction.

def __init__(self, fn: Callable, *args, **kwargs):
self.fn = fn
self.args = args
self.args = tuple(args)
self.kwargs = kwargs
self._dummy = jnp.empty((), int)

def unwrap(self) -> T:

def _unwrap_fn(self):
return self.fn(*self.args, **self.kwargs)

for dim in reversed(self._dummy.shape): # vectorize if constructed under vmap
_unwrap_fn = eqx.filter_vmap(_unwrap_fn, axis_size=dim)
return _unwrap_fn(self)


class NonTrainable(AbstractUnwrappable[T]):
"""Applies stop gradient to all arraylike leaves before unwrapping.
See also :func:`non_trainable`, which is probably a generally preferable way to
achieve similar behaviour, which wraps the arraylike leaves directly, rather than
the tree. Useful to mark pytrees (arrays, submodules, etc) as frozen/non-trainable.
We also filter out NonTrainable nodes when partitioning parameters for training,
or when parameterizing bijections in coupling/masked autoregressive flows
(transformers).
"""

tree: T

def unwrap(self) -> T:
differentiable, static = eqx.partition(self.tree, eqx.is_array_like)
return eqx.combine(lax.stop_gradient(differentiable), static)
return self.fn(*self.args, **self.kwargs)
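
To make the broadcasting recommendation in the note concrete, the sketch below compares how three functions behave when their input carries a leading batch axis, as it would if the arguments were created under `vmap`. It uses plain JAX only, and `batched_diag` is a hypothetical helper name. Elementwise functions such as `jnp.exp` broadcast naturally, whereas `jnp.diag` changes meaning on 2D input and would need an explicitly vectorized version.

```python
import jax.numpy as jnp

# Stand-in for parameters created in a vectorized context: 4 models, 3 values each.
batched = jnp.ones((4, 3))

# Elementwise functions keep the batch axis intact.
exp_shape = jnp.exp(batched).shape

# jnp.diag extracts the diagonal of 2D input, so no diagonal matrices are built.
diag_shape = jnp.diag(batched).shape

# One possible fix: vectorize over leading axes with an explicit core signature.
batched_diag = jnp.vectorize(jnp.diag, signature="(n)->(n,n)")
batched_diag_shape = batched_diag(batched).shape
unbatched_diag_shape = batched_diag(jnp.ones(3)).shape
```

The vectorized version returns one diagonal matrix per batch element and still works on unbatched input.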


def non_trainable(tree: PyTree):
@@ -156,11 +124,8 @@ def non_trainable(tree: PyTree):
... )
This is done in both :func:`~paramax.train.fit_to_data` and
:func:`~paramax.train.fit_to_key_based_loss`.
Wrapping the arrays rather than the entire tree is often preferable, allowing easier
access to attributes compared to wrapping the entire tree.
Wrapping the arrays in a model rather than the entire tree is often preferable,
allowing easier access to attributes compared to wrapping the entire tree.
Args:
tree: The pytree.
@@ -176,6 +141,24 @@ def _map_fn(leaf):
)


class NonTrainable(AbstractUnwrappable[T]):
"""Applies stop gradient to all arraylike leaves before unwrapping.
See also :func:`non_trainable`, which is probably a generally preferable way to
achieve similar behaviour, which wraps the arraylike leaves directly, rather than
the tree. Useful to mark pytrees (arrays, submodules, etc) as frozen/non-trainable.
Note that the underlying parameters may still be impacted by regularization,
so it is generally advised to use this as a suggestively named class
for filtering parameters.
"""

tree: T

def unwrap(self) -> T:
differentiable, static = eqx.partition(self.tree, eqx.is_array_like)
return eqx.combine(lax.stop_gradient(differentiable), static)
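
The effect of `lax.stop_gradient`, and the caveat about regularization in the docstring, can be seen in a small plain-JAX sketch. The dictionary keys and the weight-decay update rule here are assumptions made for illustration and are not taken from the package.

```python
import jax
import jax.numpy as jnp
from jax import lax

params = {"frozen": jnp.ones(3), "free": 2.0 * jnp.ones(3)}


def loss(params):
    # Mimics unwrapping a frozen node: gradients do not flow into "frozen".
    frozen = lax.stop_gradient(params["frozen"])
    return jnp.sum(frozen * params["free"])


grads = jax.grad(loss)(params)  # zeros for "frozen", ones for "free"

# A weight-decay style update (assumed form) still changes the frozen values,
# because the decay term does not depend on the gradient.
weight_decay = 0.1
updates = jax.tree_util.tree_map(lambda g, p: g + weight_decay * p, grads, params)
```

Because the decay term is nonzero even when the gradient is zero, filtering such parameters out before the optimizer sees them is the more robust option, which is consistent with the docstring's advice.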


class WeightNormalization(AbstractUnwrappable[Array]):
"""Applies weight normalization (https://arxiv.org/abs/1602.07868).
@@ -184,7 +167,7 @@ class WeightNormalization(AbstractUnwrappable[Array]):
"""

weight: Array | AbstractUnwrappable[Array]
scale: Array | AbstractUnwrappable[Array] = eqx.field(init=False)
scale: Array | AbstractUnwrappable[Array]

def __init__(self, weight: Array | AbstractUnwrappable[Array]):
self.weight = weight
3 changes: 1 addition & 2 deletions tests/test_wrappers.py
@@ -20,7 +20,7 @@ def test_Parameterize():
assert pytest.approx(jnp.eye(3)) == unwrap(diag)


def test_nested_Parameterized():
def test_nested_unwrap():
param = Parameterize(
jnp.square,
Parameterize(jnp.square, Parameterize(jnp.square, 2)),
@@ -56,7 +56,6 @@ def test_WeightNormalization():
test_cases = {
"NonTrainable": lambda key: NonTrainable(jr.normal(key, 10)),
"Parameterize-exp": lambda key: Parameterize(jnp.exp, jr.normal(key, 10)),
"Parameterize-diag": lambda key: Parameterize(jnp.diag, jr.normal(key, 10)),
"WeightNormalization": lambda key: WeightNormalization(jr.normal(key, (10, 2))),
}

