From 1bec8015ae25c571c587f95d19b5a17badb90ea3 Mon Sep 17 00:00:00 2001 From: danielward27 Date: Mon, 21 Oct 2024 12:46:55 +0100 Subject: [PATCH] docs and remove dummy for simplicity --- README.md | 54 +++++++++------- docs/_static/icon.svg | 4 +- docs/index.rst | 26 ++++---- paramax/utils.py | 4 +- paramax/wrappers.py | 143 ++++++++++++++++++----------------------- tests/test_wrappers.py | 3 +- 6 files changed, 116 insertions(+), 118 deletions(-) diff --git a/README.md b/README.md index 0fc197e..59c9027 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,26 @@ Paramax Parameterizations and constraints for JAX PyTrees ----------------------------------------------------------------------- +Paramax allows applying custom constraints or behaviors to PyTree components, +using unwrappable placeholders. This can be used for +- Enforcing positivity (e.g., scale parameters) +- Structured matrices (triangular, symmetric, etc.) +- Applying tricks like weight normalization +- Marking components as non-trainable + +Some benefits of the unwrappable pattern: +- It allows parameterizations to be computed once for a model (e.g. at the top of the + loss function). +- It is flexible, e.g. allowing custom parameterizations to be applied to PyTrees + from external libraries +- It is concise + +If you found the package useful, please consider giving it a star on github, and if you +create ``AbstractUnwrappable``s that may be of interest to others, a pull request would +be much appreciated! + +## Documentation + Documentation available [here](https://danielward27.github.io/paramax/). ## Installation @@ -13,32 +33,22 @@ pip install paramax ## Example ```python ->>> from paramax.wrappers import Parameterize, unwrap +>>> import paramax >>> import jax.numpy as jnp ->>> params = Parameterize(jnp.exp, jnp.zeros(3)) ->>> unwrap(("abc", 1, params)) +>>> scale = paramax.Parameterize(jnp.exp, jnp.log(jnp.ones(3))) # Enforce positivity +>>> paramax.unwrap(("abc", 1, scale)) ('abc', 1, Array([1., 1., 1.], dtype=float32)) ``` -## Why use Paramax? -Paramax allows applying custom constraints or behaviors to PyTree components, such as: -- Enforcing positivity (e.g., scale parameters) -- Structured matrices (triangular, symmetric, etc.) -- Applying tricks like weight normalization -- Marking components as non-trainable - -Some benefits of the pattern we use: -- It allows parameterizations to be computed once for a model (e.g. at the top of the loss function). -- It is concise, flexible, and allows custom parameterizations to be used with PyTrees from external libraries. - -## Alternative patterns -Using properties to access parameterized components is common but has drawbacks: +## Alternative parameterization patterns +Using properties to access parameterized model components is common but has drawbacks: - Parameterizations are tied to class definition, limiting flexibility e.g. this - cannot be used on PyTrees from external libraries. -- It can become verbose with many parameters. -- It often leads to repeatedly computing the parameterization. - + cannot be used on PyTrees from external libraries +- It can become verbose with many parameters +- It often leads to repeatedly computing the parameterization ## Related -We make use of the [Equinox](https://arxiv.org/abs/2111.00254) package, to register -the PyTrees used in the package. +- We make use of the [Equinox](https://arxiv.org/abs/2111.00254) package, to register +the PyTrees used in the package +- This package spawned out of a need for a simple method to apply parameter constraints + in the distributions package [flowjax](https://github.com/danielward27/flowjax) diff --git a/docs/_static/icon.svg b/docs/_static/icon.svg index 43e14b2..6805268 100644 --- a/docs/_static/icon.svg +++ b/docs/_static/icon.svg @@ -1 +1,3 @@ - \ No newline at end of file + + + diff --git a/docs/index.rst b/docs/index.rst index 5afe966..7bbbb30 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,7 +1,8 @@ Paramax =========== -Paramax: a small package for applying parameterizations, constraints to JAX PyTrees. +A small package for applying parameterizations and constraints to nodes in JAX +PyTrees. Installation @@ -11,16 +12,18 @@ Installation pip install paramax -Simple example +How it works ------------------ -The most common way to apply parameterizations is via -:py:class:`~paramax.wrappers.Parameterize`. This class takes a callable and any +- :py:class:`~paramax.wrappers.AbstractUnwrappable` objects act as placeholders in the + PyTree, defining the parameterizations. +- :py:func:`~paramax.wrappers.unwrap` applies the parameterizations, replacing the + :py:class:`~paramax.wrappers.AbstractUnwrappable` objects. + +A simple example of an :py:class:`~paramax.wrappers.AbstractUnwrappable` +is :py:class:`~paramax.wrappers.Parameterize`. This class takes a callable and any positional or keyword arguments, which are stored and passed to the function when unwrapping. -When :py:func:`~paramax.wrappers.unwrap` is called on a PyTree containing a -:py:class:`~paramax.wrappers.Parameterize` object, the stored function is applied -using the stored arguments. .. doctest:: @@ -33,9 +36,8 @@ using the stored arguments. ('abc', 1, Array([1., 1., 1.], dtype=float32)) -Many simple parameterizations can be handled with this class. As another example, -we can parameterize a lower triangular matrix (such that it remains lower triangular -during optimization) as follows +Many simple parameterizations can be handled with this class, for example, +we can parameterize a lower triangular matrix using .. doctest:: @@ -45,7 +47,8 @@ during optimization) as follows >>> tril = paramax.Parameterize(jnp.tril, tril) -See :doc:`/api/wrappers` for more possibilities. +See :doc:`/api/wrappers` for more :py:class:`~paramax.wrappers.AbstractUnwrappable` +objects. When to unwrap ------------------- @@ -53,6 +56,7 @@ When to unwrap methods requiring the parameterizations to have been applied. - Unwrapping prior to a gradient computation used for optimization is usually a mistake! + .. toctree:: :caption: API :maxdepth: 1 diff --git a/paramax/utils.py b/paramax/utils.py index b8aed03..7030cbf 100644 --- a/paramax/utils.py +++ b/paramax/utils.py @@ -11,7 +11,7 @@ def inv_softplus(x: ArrayLike) -> Array: x, x < 0, "Expected positive inputs to inv_softplus. If you are trying to use a negative " - "scale parameter, consider constructing with positive scales and modifying the " - "scale attribute post-construction, e.g., using eqx.tree_at.", + "scale parameter, you may be able to construct with positive scales, and " + "modify the scale attribute post-construction, e.g., using eqx.tree_at.", ) return jnp.log(-jnp.expm1(-x)) + x diff --git a/paramax/wrappers.py b/paramax/wrappers.py index 9ba0336..6b6d7d1 100644 --- a/paramax/wrappers.py +++ b/paramax/wrappers.py @@ -1,22 +1,11 @@ """:class:`AbstractUnwrappable` objects and utilities. -These are "placeholder" values for specifying custom behaviour for nodes in a pytree. -Many of these facilitate similar functions to pytorch parameterizations. We use this -for example to apply parameter constraints, masking of parameters etc. To apply the -behaviour, we use :func:`unwrap`, which will replace any :class:`AbstractUnwrappable` -nodes in a pytree with the unwrapped versions. - -.. note:: - - If creating a custom unwrappable, remember that unwrapping will generally occur - after initialization of the model. Because of this, we recommend ensuring that - the ``unwrap`` method supports unwrapping if the model is constructed in a - vectorized context, such as ``eqx.filter_vmap``, e.g. through broadcasting or - vectorization. +These are placeholder values for specifying custom behaviour for nodes in a pytree, +applied using :func:`unwrap`. """ from abc import abstractmethod -from collections.abc import Callable, Iterable +from collections.abc import Callable from typing import Any, Generic, TypeVar import equinox as eqx @@ -25,15 +14,27 @@ from jax import lax from jax.nn import softplus from jax.tree_util import tree_leaves -from jaxtyping import Array, Int, PyTree, Scalar +from jaxtyping import Array, PyTree from paramax.utils import inv_softplus T = TypeVar("T") +class AbstractUnwrappable(eqx.Module, Generic[T]): + """An abstract class representing an unwrappable object. + + Unwrappables replace PyTree nodes, applying custom behavior upon unwrapping. + """ + + @abstractmethod + def unwrap(self) -> T: + """Returns the unwrapped pytree, assuming no wrapped subnodes exist.""" + pass + + def unwrap(tree: PyTree): - """Recursively unwraps all :class:`AbstractUnwrappable` nodes within a pytree. + """Map across a PyTree and unwrap all :class:`AbstractUnwrappable` nodes. This leaves all other nodes unchanged. If nested, the innermost ``AbstractUnwrappable`` nodes are unwrapped first. @@ -43,47 +44,40 @@ def unwrap(tree: PyTree): .. doctest:: - >>> from paramax.wrappers import Parameterize, unwrap + >>> import paramax >>> import jax.numpy as jnp - >>> params = Parameterize(jnp.exp, jnp.zeros(3)) - >>> unwrap(("abc", 1, params)) + >>> params = paramax.Parameterize(jnp.exp, jnp.zeros(3)) + >>> paramax.unwrap(("abc", 1, params)) ('abc', 1, Array([1., 1., 1.], dtype=float32)) """ - def _map_fn(leaf): - if isinstance(leaf, AbstractUnwrappable): - # Flatten to ignore until all contained AbstractUnwrappables are unwrapped - flat, tree_def = eqx.tree_flatten_one_level(leaf) - tree = jax.tree_util.tree_unflatten(tree_def, unwrap(flat)) - return tree.unwrap() - return leaf - - return jax.tree_util.tree_map( - f=_map_fn, - tree=tree, - is_leaf=lambda x: isinstance(x, AbstractUnwrappable), - ) - + def _unwrap(tree, *, include_self: bool): + def _map_fn(leaf): + if isinstance(leaf, AbstractUnwrappable): + # Unwrap subnodes, then itself + return _unwrap(leaf, include_self=False).unwrap() + return leaf -class AbstractUnwrappable(eqx.Module, Generic[T]): - """An abstract class representing an unwrappable object. + def is_leaf(x): + is_unwrappable = isinstance(x, AbstractUnwrappable) + included = include_self or x is not tree + return is_unwrappable and included - Unwrappables replace PyTree nodes, applying custom behavior upon unwrapping. - """ + return jax.tree_util.tree_map(f=_map_fn, tree=tree, is_leaf=is_leaf) - @abstractmethod - def unwrap(self) -> T: - """Returns the unwrapped pytree, assuming no wrapped subnodes exist.""" - pass + return _unwrap(tree, include_self=True) class Parameterize(AbstractUnwrappable[T]): """Unwrap an object by calling fn with args and kwargs. - All of fn, args and kwargs may contain trainable parameters. If the Parameterize is - created within ``eqx.filter_vmap``, unwrapping is automatically vectorized - correctly, as long as the vmapped constructor adds leading batch - dimensions to all arrays (the default for ``eqx.filter_vmap``). + All of fn, args and kwargs may contain trainable parameters. + + .. note:: + + Unwrapping typically occurs after model initialization. Therefore, if the + ``Parameterize`` object may be created in a vectorized context, we recommend + ensuring that ``fn`` still unwraps correctly, e.g. by supporting broadcasting. Example: .. doctest:: @@ -101,42 +95,16 @@ class Parameterize(AbstractUnwrappable[T]): """ fn: Callable[..., T] - args: Iterable + args: tuple[Any, ...] kwargs: dict[str, Any] - _dummy: Int[Scalar, ""] # Used to track vectorized construction. def __init__(self, fn: Callable, *args, **kwargs): self.fn = fn - self.args = args + self.args = tuple(args) self.kwargs = kwargs - self._dummy = jnp.empty((), int) - - def unwrap(self) -> T: - - def _unwrap_fn(self): - return self.fn(*self.args, **self.kwargs) - - for dim in reversed(self._dummy.shape): # vectorize if constructed under vmap - _unwrap_fn = eqx.filter_vmap(_unwrap_fn, axis_size=dim) - return _unwrap_fn(self) - - -class NonTrainable(AbstractUnwrappable[T]): - """Applies stop gradient to all arraylike leaves before unwrapping. - - See also :func:`non_trainable`, which is probably a generally prefereable way to - achieve similar behaviour, which wraps the arraylike leaves directly, rather than - the tree. Useful to mark pytrees (arrays, submodules, etc) as frozen/non-trainable. - We also filter out NonTrainable nodes when partitioning parameters for training, - or when parameterizing bijections in coupling/masked autoregressive flows - (transformers). - """ - - tree: T def unwrap(self) -> T: - differentiable, static = eqx.partition(self.tree, eqx.is_array_like) - return eqx.combine(lax.stop_gradient(differentiable), static) + return self.fn(*self.args, **self.kwargs) def non_trainable(tree: PyTree): @@ -156,11 +124,8 @@ def non_trainable(tree: PyTree): ... ) - This is done in both :func:`~paramax.train.fit_to_data` and - :func:`~paramax.train.fit_to_key_based_loss`. - - Wrapping the arrays rather than the entire tree is often preferable, allowing easier - access to attributes compared to wrapping the entire tree. + Wrapping the arrays in a model rather than the entire tree is often preferable, + allowing easier access to attributes compared to wrapping the entire tree. Args: tree: The pytree. @@ -176,6 +141,24 @@ def _map_fn(leaf): ) +class NonTrainable(AbstractUnwrappable[T]): + """Applies stop gradient to all arraylike leaves before unwrapping. + + See also :func:`non_trainable`, which is probably a generally prefereable way to + achieve similar behaviour, which wraps the arraylike leaves directly, rather than + the tree. Useful to mark pytrees (arrays, submodules, etc) as frozen/non-trainable. + Note that the underlying parameters may still be impacted by regularization, + so it is generally advised to use this as a suggestively named class + for filtering parameters. + """ + + tree: T + + def unwrap(self) -> T: + differentiable, static = eqx.partition(self.tree, eqx.is_array_like) + return eqx.combine(lax.stop_gradient(differentiable), static) + + class WeightNormalization(AbstractUnwrappable[Array]): """Applies weight normalization (https://arxiv.org/abs/1602.07868). @@ -184,7 +167,7 @@ class WeightNormalization(AbstractUnwrappable[Array]): """ weight: Array | AbstractUnwrappable[Array] - scale: Array | AbstractUnwrappable[Array] = eqx.field(init=False) + scale: Array | AbstractUnwrappable[Array] def __init__(self, weight: Array | AbstractUnwrappable[Array]): self.weight = weight diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index bced911..5e88eaa 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -20,7 +20,7 @@ def test_Parameterize(): assert pytest.approx(jnp.eye(3)) == unwrap(diag) -def test_nested_Parameterized(): +def test_nested_unwrap(): param = Parameterize( jnp.square, Parameterize(jnp.square, Parameterize(jnp.square, 2)), @@ -56,7 +56,6 @@ def test_WeightNormalization(): test_cases = { "NonTrainable": lambda key: NonTrainable(jr.normal(key, 10)), "Parameterize-exp": lambda key: Parameterize(jnp.exp, jr.normal(key, 10)), - "Parameterize-diag": lambda key: Parameterize(jnp.diag, jr.normal(key, 10)), "WeightNormalization": lambda key: WeightNormalization(jr.normal(key, (10, 2))), }