diff --git a/README.md b/README.md
index 0fc197e..59c9027 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,26 @@ Paramax
Parameterizations and constraints for JAX PyTrees
-----------------------------------------------------------------------
+Paramax allows applying custom constraints or behaviors to PyTree components,
+using unwrappable placeholders. This can be used for
+- Enforcing positivity (e.g., scale parameters)
+- Structured matrices (triangular, symmetric, etc.)
+- Applying tricks like weight normalization
+- Marking components as non-trainable
+
+Some benefits of the unwrappable pattern:
+- It allows parameterizations to be computed once for a model (e.g. at the top of the
+ loss function).
+- It is flexible, e.g. allowing custom parameterizations to be applied to PyTrees
+ from external libraries
+- It is concise
+
+If you found the package useful, please consider giving it a star on github, and if you
+create ``AbstractUnwrappable``s that may be of interest to others, a pull request would
+be much appreciated!
+
+## Documentation
+
Documentation available [here](https://danielward27.github.io/paramax/).
## Installation
@@ -13,32 +33,22 @@ pip install paramax
## Example
```python
->>> from paramax.wrappers import Parameterize, unwrap
+>>> import paramax
>>> import jax.numpy as jnp
->>> params = Parameterize(jnp.exp, jnp.zeros(3))
->>> unwrap(("abc", 1, params))
+>>> scale = paramax.Parameterize(jnp.exp, jnp.log(jnp.ones(3))) # Enforce positivity
+>>> paramax.unwrap(("abc", 1, scale))
('abc', 1, Array([1., 1., 1.], dtype=float32))
```
-## Why use Paramax?
-Paramax allows applying custom constraints or behaviors to PyTree components, such as:
-- Enforcing positivity (e.g., scale parameters)
-- Structured matrices (triangular, symmetric, etc.)
-- Applying tricks like weight normalization
-- Marking components as non-trainable
-
-Some benefits of the pattern we use:
-- It allows parameterizations to be computed once for a model (e.g. at the top of the loss function).
-- It is concise, flexible, and allows custom parameterizations to be used with PyTrees from external libraries.
-
-## Alternative patterns
-Using properties to access parameterized components is common but has drawbacks:
+## Alternative parameterization patterns
+Using properties to access parameterized model components is common but has drawbacks:
- Parameterizations are tied to class definition, limiting flexibility e.g. this
- cannot be used on PyTrees from external libraries.
-- It can become verbose with many parameters.
-- It often leads to repeatedly computing the parameterization.
-
+ cannot be used on PyTrees from external libraries
+- It can become verbose with many parameters
+- It often leads to repeatedly computing the parameterization
## Related
-We make use of the [Equinox](https://arxiv.org/abs/2111.00254) package, to register
-the PyTrees used in the package.
+- We make use of the [Equinox](https://arxiv.org/abs/2111.00254) package, to register
+the PyTrees used in the package
+- This package spawned out of a need for a simple method to apply parameter constraints
+ in the distributions package [flowjax](https://github.com/danielward27/flowjax)
diff --git a/docs/_static/icon.svg b/docs/_static/icon.svg
index 43e14b2..6805268 100644
--- a/docs/_static/icon.svg
+++ b/docs/_static/icon.svg
@@ -1 +1,3 @@
-
\ No newline at end of file
+
diff --git a/docs/index.rst b/docs/index.rst
index 5afe966..7bbbb30 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -1,7 +1,8 @@
Paramax
===========
-Paramax: a small package for applying parameterizations, constraints to JAX PyTrees.
+A small package for applying parameterizations and constraints to nodes in JAX
+PyTrees.
Installation
@@ -11,16 +12,18 @@ Installation
pip install paramax
-Simple example
+How it works
------------------
-The most common way to apply parameterizations is via
-:py:class:`~paramax.wrappers.Parameterize`. This class takes a callable and any
+- :py:class:`~paramax.wrappers.AbstractUnwrappable` objects act as placeholders in the
+ PyTree, defining the parameterizations.
+- :py:func:`~paramax.wrappers.unwrap` applies the parameterizations, replacing the
+ :py:class:`~paramax.wrappers.AbstractUnwrappable` objects.
+
+A simple example of an :py:class:`~paramax.wrappers.AbstractUnwrappable`
+is :py:class:`~paramax.wrappers.Parameterize`. This class takes a callable and any
positional or keyword arguments, which are stored and passed to the function when
unwrapping.
-When :py:func:`~paramax.wrappers.unwrap` is called on a PyTree containing a
-:py:class:`~paramax.wrappers.Parameterize` object, the stored function is applied
-using the stored arguments.
.. doctest::
@@ -33,9 +36,8 @@ using the stored arguments.
('abc', 1, Array([1., 1., 1.], dtype=float32))
-Many simple parameterizations can be handled with this class. As another example,
-we can parameterize a lower triangular matrix (such that it remains lower triangular
-during optimization) as follows
+Many simple parameterizations can be handled with this class, for example,
+we can parameterize a lower triangular matrix using
.. doctest::
@@ -45,7 +47,8 @@ during optimization) as follows
>>> tril = paramax.Parameterize(jnp.tril, tril)
-See :doc:`/api/wrappers` for more possibilities.
+See :doc:`/api/wrappers` for more :py:class:`~paramax.wrappers.AbstractUnwrappable`
+objects.
When to unwrap
-------------------
@@ -53,6 +56,7 @@ When to unwrap
methods requiring the parameterizations to have been applied.
- Unwrapping prior to a gradient computation used for optimization is usually a mistake!
+
.. toctree::
:caption: API
:maxdepth: 1
diff --git a/paramax/utils.py b/paramax/utils.py
index b8aed03..7030cbf 100644
--- a/paramax/utils.py
+++ b/paramax/utils.py
@@ -11,7 +11,7 @@ def inv_softplus(x: ArrayLike) -> Array:
x,
x < 0,
"Expected positive inputs to inv_softplus. If you are trying to use a negative "
- "scale parameter, consider constructing with positive scales and modifying the "
- "scale attribute post-construction, e.g., using eqx.tree_at.",
+ "scale parameter, you may be able to construct with positive scales, and "
+ "modify the scale attribute post-construction, e.g., using eqx.tree_at.",
)
return jnp.log(-jnp.expm1(-x)) + x
diff --git a/paramax/wrappers.py b/paramax/wrappers.py
index 9ba0336..6b6d7d1 100644
--- a/paramax/wrappers.py
+++ b/paramax/wrappers.py
@@ -1,22 +1,11 @@
""":class:`AbstractUnwrappable` objects and utilities.
-These are "placeholder" values for specifying custom behaviour for nodes in a pytree.
-Many of these facilitate similar functions to pytorch parameterizations. We use this
-for example to apply parameter constraints, masking of parameters etc. To apply the
-behaviour, we use :func:`unwrap`, which will replace any :class:`AbstractUnwrappable`
-nodes in a pytree with the unwrapped versions.
-
-.. note::
-
- If creating a custom unwrappable, remember that unwrapping will generally occur
- after initialization of the model. Because of this, we recommend ensuring that
- the ``unwrap`` method supports unwrapping if the model is constructed in a
- vectorized context, such as ``eqx.filter_vmap``, e.g. through broadcasting or
- vectorization.
+These are placeholder values for specifying custom behaviour for nodes in a pytree,
+applied using :func:`unwrap`.
"""
from abc import abstractmethod
-from collections.abc import Callable, Iterable
+from collections.abc import Callable
from typing import Any, Generic, TypeVar
import equinox as eqx
@@ -25,15 +14,27 @@
from jax import lax
from jax.nn import softplus
from jax.tree_util import tree_leaves
-from jaxtyping import Array, Int, PyTree, Scalar
+from jaxtyping import Array, PyTree
from paramax.utils import inv_softplus
T = TypeVar("T")
+class AbstractUnwrappable(eqx.Module, Generic[T]):
+ """An abstract class representing an unwrappable object.
+
+ Unwrappables replace PyTree nodes, applying custom behavior upon unwrapping.
+ """
+
+ @abstractmethod
+ def unwrap(self) -> T:
+ """Returns the unwrapped pytree, assuming no wrapped subnodes exist."""
+ pass
+
+
def unwrap(tree: PyTree):
- """Recursively unwraps all :class:`AbstractUnwrappable` nodes within a pytree.
+ """Map across a PyTree and unwrap all :class:`AbstractUnwrappable` nodes.
This leaves all other nodes unchanged. If nested, the innermost
``AbstractUnwrappable`` nodes are unwrapped first.
@@ -43,47 +44,40 @@ def unwrap(tree: PyTree):
.. doctest::
- >>> from paramax.wrappers import Parameterize, unwrap
+ >>> import paramax
>>> import jax.numpy as jnp
- >>> params = Parameterize(jnp.exp, jnp.zeros(3))
- >>> unwrap(("abc", 1, params))
+ >>> params = paramax.Parameterize(jnp.exp, jnp.zeros(3))
+ >>> paramax.unwrap(("abc", 1, params))
('abc', 1, Array([1., 1., 1.], dtype=float32))
"""
- def _map_fn(leaf):
- if isinstance(leaf, AbstractUnwrappable):
- # Flatten to ignore until all contained AbstractUnwrappables are unwrapped
- flat, tree_def = eqx.tree_flatten_one_level(leaf)
- tree = jax.tree_util.tree_unflatten(tree_def, unwrap(flat))
- return tree.unwrap()
- return leaf
-
- return jax.tree_util.tree_map(
- f=_map_fn,
- tree=tree,
- is_leaf=lambda x: isinstance(x, AbstractUnwrappable),
- )
-
+ def _unwrap(tree, *, include_self: bool):
+ def _map_fn(leaf):
+ if isinstance(leaf, AbstractUnwrappable):
+ # Unwrap subnodes, then itself
+ return _unwrap(leaf, include_self=False).unwrap()
+ return leaf
-class AbstractUnwrappable(eqx.Module, Generic[T]):
- """An abstract class representing an unwrappable object.
+ def is_leaf(x):
+ is_unwrappable = isinstance(x, AbstractUnwrappable)
+ included = include_self or x is not tree
+ return is_unwrappable and included
- Unwrappables replace PyTree nodes, applying custom behavior upon unwrapping.
- """
+ return jax.tree_util.tree_map(f=_map_fn, tree=tree, is_leaf=is_leaf)
- @abstractmethod
- def unwrap(self) -> T:
- """Returns the unwrapped pytree, assuming no wrapped subnodes exist."""
- pass
+ return _unwrap(tree, include_self=True)
class Parameterize(AbstractUnwrappable[T]):
"""Unwrap an object by calling fn with args and kwargs.
- All of fn, args and kwargs may contain trainable parameters. If the Parameterize is
- created within ``eqx.filter_vmap``, unwrapping is automatically vectorized
- correctly, as long as the vmapped constructor adds leading batch
- dimensions to all arrays (the default for ``eqx.filter_vmap``).
+ All of fn, args and kwargs may contain trainable parameters.
+
+ .. note::
+
+ Unwrapping typically occurs after model initialization. Therefore, if the
+ ``Parameterize`` object may be created in a vectorized context, we recommend
+ ensuring that ``fn`` still unwraps correctly, e.g. by supporting broadcasting.
Example:
.. doctest::
@@ -101,42 +95,16 @@ class Parameterize(AbstractUnwrappable[T]):
"""
fn: Callable[..., T]
- args: Iterable
+ args: tuple[Any, ...]
kwargs: dict[str, Any]
- _dummy: Int[Scalar, ""] # Used to track vectorized construction.
def __init__(self, fn: Callable, *args, **kwargs):
self.fn = fn
- self.args = args
+ self.args = tuple(args)
self.kwargs = kwargs
- self._dummy = jnp.empty((), int)
-
- def unwrap(self) -> T:
-
- def _unwrap_fn(self):
- return self.fn(*self.args, **self.kwargs)
-
- for dim in reversed(self._dummy.shape): # vectorize if constructed under vmap
- _unwrap_fn = eqx.filter_vmap(_unwrap_fn, axis_size=dim)
- return _unwrap_fn(self)
-
-
-class NonTrainable(AbstractUnwrappable[T]):
- """Applies stop gradient to all arraylike leaves before unwrapping.
-
- See also :func:`non_trainable`, which is probably a generally prefereable way to
- achieve similar behaviour, which wraps the arraylike leaves directly, rather than
- the tree. Useful to mark pytrees (arrays, submodules, etc) as frozen/non-trainable.
- We also filter out NonTrainable nodes when partitioning parameters for training,
- or when parameterizing bijections in coupling/masked autoregressive flows
- (transformers).
- """
-
- tree: T
def unwrap(self) -> T:
- differentiable, static = eqx.partition(self.tree, eqx.is_array_like)
- return eqx.combine(lax.stop_gradient(differentiable), static)
+ return self.fn(*self.args, **self.kwargs)
def non_trainable(tree: PyTree):
@@ -156,11 +124,8 @@ def non_trainable(tree: PyTree):
... )
- This is done in both :func:`~paramax.train.fit_to_data` and
- :func:`~paramax.train.fit_to_key_based_loss`.
-
- Wrapping the arrays rather than the entire tree is often preferable, allowing easier
- access to attributes compared to wrapping the entire tree.
+ Wrapping the arrays in a model rather than the entire tree is often preferable,
+ allowing easier access to attributes compared to wrapping the entire tree.
Args:
tree: The pytree.
@@ -176,6 +141,24 @@ def _map_fn(leaf):
)
+class NonTrainable(AbstractUnwrappable[T]):
+ """Applies stop gradient to all arraylike leaves before unwrapping.
+
+ See also :func:`non_trainable`, which is probably a generally prefereable way to
+ achieve similar behaviour, which wraps the arraylike leaves directly, rather than
+ the tree. Useful to mark pytrees (arrays, submodules, etc) as frozen/non-trainable.
+ Note that the underlying parameters may still be impacted by regularization,
+ so it is generally advised to use this as a suggestively named class
+ for filtering parameters.
+ """
+
+ tree: T
+
+ def unwrap(self) -> T:
+ differentiable, static = eqx.partition(self.tree, eqx.is_array_like)
+ return eqx.combine(lax.stop_gradient(differentiable), static)
+
+
class WeightNormalization(AbstractUnwrappable[Array]):
"""Applies weight normalization (https://arxiv.org/abs/1602.07868).
@@ -184,7 +167,7 @@ class WeightNormalization(AbstractUnwrappable[Array]):
"""
weight: Array | AbstractUnwrappable[Array]
- scale: Array | AbstractUnwrappable[Array] = eqx.field(init=False)
+ scale: Array | AbstractUnwrappable[Array]
def __init__(self, weight: Array | AbstractUnwrappable[Array]):
self.weight = weight
diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py
index bced911..5e88eaa 100644
--- a/tests/test_wrappers.py
+++ b/tests/test_wrappers.py
@@ -20,7 +20,7 @@ def test_Parameterize():
assert pytest.approx(jnp.eye(3)) == unwrap(diag)
-def test_nested_Parameterized():
+def test_nested_unwrap():
param = Parameterize(
jnp.square,
Parameterize(jnp.square, Parameterize(jnp.square, 2)),
@@ -56,7 +56,6 @@ def test_WeightNormalization():
test_cases = {
"NonTrainable": lambda key: NonTrainable(jr.normal(key, 10)),
"Parameterize-exp": lambda key: Parameterize(jnp.exp, jr.normal(key, 10)),
- "Parameterize-diag": lambda key: Parameterize(jnp.diag, jr.normal(key, 10)),
"WeightNormalization": lambda key: WeightNormalization(jr.normal(key, (10, 2))),
}