Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/lockwo/equinox into filter-…
Browse files Browse the repository at this point in the history
…hessian
  • Loading branch information
lockwo committed Mar 4, 2024
2 parents 508806f + 1e60167 commit 83b376b
Show file tree
Hide file tree
Showing 147 changed files with 11,199 additions and 2,445 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
build:
strategy:
matrix:
python-version: [ 3.8 ]
python-version: [ 3.11 ]
os: [ ubuntu-latest ]
runs-on: ${{ matrix.os }}
steps:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:
run-test:
strategy:
matrix:
python-version: [ 3.8, 3.9 ]
python-version: [ 3.9, 3.11 ]
os: [ ubuntu-latest ]
fail-fast: false
runs-on: ${{ matrix.os }}
Expand Down
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ build/
dist/
site/
examples/data
examples/CIFAR
.all_objects.cache
.pymon
.idea
examples/MNIST
examples/MNIST
examples/multipart_serialised.eqx
.python-version
29 changes: 10 additions & 19 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,23 +1,14 @@
repos:
- repo: https://github.com/ambv/black
rev: 22.3.0
hooks:
- id: black
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.0.255'
hooks:
- id: ruff
args: ["--fix"]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.7
hooks:
- id: ruff # linter
types_or: [ python, pyi, jupyter ]
args: [ --fix ]
- id: ruff-format # formatter
types_or: [ python, pyi, jupyter ]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.302
rev: v1.1.315
hooks:
- id: pyright
additional_dependencies: [beartype, einops, jax, jaxtyping, pytest, tensorflow, tf2onnx, typing_extensions]
- repo: https://github.com/nbQA-dev/nbQA
rev: 1.6.3
hooks:
- id: nbqa-black
additional_dependencies: [ipython==8.12, black]
- id: nbqa-ruff
args: ["--ignore=I001"]
additional_dependencies: [ipython==8.12, ruff]
additional_dependencies: [beartype, einops, jax, jaxtyping, optax, pytest, tensorflow, tf2onnx, typing_extensions]
45 changes: 31 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
<h1 align='center'>Equinox</h1>

Equinox is a JAX library for parameterised functions (e.g. neural networks) offering:
Equinox is your one-stop [JAX](https://github.com/google/jax) library, for everything you need that isn't already in core JAX:

- a PyTorch-like API...
- ...that's fully compatible with *native* JAX transformations...
- ...with no new concepts you have to learn.
- neural networks (or more generally any model), with easy-to-use PyTorch-like syntax;
- filtered APIs for transformations;
- useful PyTree manipulation routines;
- advanced features like runtime errors;

If you're completely new to JAX, then start with this [CNN on MNIST example](https://docs.kidger.site/equinox/examples/mnist/).
If you're already familiar with JAX, then the main idea is to represent parameterised functions (such as neural networks) as PyTrees, so that they can pass across JIT/grad/etc. boundaries smoothly.
and best of all, Equinox isn't a framework: everything you write in Equinox is compatible with anything else in JAX or the ecosystem.

The elegance of Equinox is its selling point in a world that already has [Haiku](https://github.com/deepmind/dm-haiku), [Flax](https://github.com/google/flax) and so on.
If you're completely new to JAX, then start with this [CNN on MNIST example](https://docs.kidger.site/equinox/examples/mnist/).

_In other words, why should you care? Because Equinox is really simple to learn, and really simple to use._
_Coming from [Flax](https://github.com/google/flax) or [Haiku](https://github.com/deepmind/haiku)? The main difference is that Equinox (a) offers a lot of advanced features not found in these libraries, like PyTree manipulation or runtime errors; (b) has a simpler way of building models: they're just PyTrees, so they can pass across JIT/grad/etc. boundaries smoothly._

## Installation

```bash
pip install equinox
```

Requires Python 3.8+ and JAX 0.4.4+.
Requires Python 3.9+ and JAX 0.4.13+.

## Documentation

Expand Down Expand Up @@ -79,13 +79,30 @@ If you found this library to be useful in academic work, then please cite: ([arX

(Also consider starring the project on GitHub.)

## See also: other libraries in the JAX ecosystem

## See also
[jaxtyping](https://github.com/google/jaxtyping): type annotations for shape/dtype of arrays.

Numerical differential equation solvers: [Diffrax](https://github.com/patrick-kidger/diffrax).
[Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.

Type annotations and runtime checking for PyTrees and shape/dtype of JAX arrays: [jaxtyping](https://github.com/google/jaxtyping).
[Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.

Computer vision models: [Eqxvision](https://github.com/paganpasta/eqxvision).
[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.

SymPy<->JAX conversion; train symbolic expressions via gradient descent: [sympy2jax](https://github.com/google/sympy2jax).
[Lineax](https://github.com/google/lineax): linear solvers.

[BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.

[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).

[sympy2jax](https://github.com/google/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.

[Eqxvision](https://github.com/paganpasta/eqxvision): computer vision models.

[Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).

[PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)

## Disclaimer

Equinox is maintained by Patrick Kidger at Google X, but this is not an official Google product.
95 changes: 47 additions & 48 deletions docs/all-of-equinox.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,80 +2,76 @@

Equinox is a small and easy to understand library. So as the title suggests, this page tells you essentially everything you need to know to use Equinox.

## Parameterised functions as PyTrees
## 1. Models as PyTrees

As we saw on the [Getting Started](./index.md) page, Equinox represents parameterised functions as [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html).
!!! info "What's a PyTree?"

!!! example
[PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html) are what JAX calls nested collections of tuples, lists, and dicts. (And any custom-registered PyTree nodes.) The "leaves" of the tree can be anything at all: JAX/NumPy arrays, floats, functions, etc. Most JAX operations will accept either (a) arbitrary PyTrees; (b) PyTrees with just JAX/NumPy arrays as the leaves; (c) PyTrees without any JAX/NumPy arrays as the leaves.

A neural network is a function parameterised by its weights, biases, etc.
As we saw on the [Getting Started](./index.md) page, Equinox offers the ability to represents models as PyTrees. This is one of Equinox's main features.

But you can use Equinox to represent any kind of parameterised function! For example [Diffrax](http://github.com/patrick-kidger/diffrax) uses Equinox to represent numerical differential equation solvers.

And now you can JIT/grad/etc. with respect to your model. For example, using a few built-in layers by way of demonstration:
Once we've done so, we'll be able to JIT/grad/etc. with respect to the model. For example, using a few built-in layers by way of demonstration, here's a small neural network:

```python
import equinox as eqx
import jax

class MyModule(eqx.Module):
class NeuralNetwork(eqx.Module):
layers: list
extra_bias: jax.Array

def __init__(self, key):
key1, key2, key3 = jax.random.split(key, 3)
# These contain trainable parameters.
self.layers = [eqx.nn.Linear(2, 8, key=key1),
eqx.nn.Linear(8, 8, key=key2),
eqx.nn.Linear(8, 2, key=key3)]
# This is a trainable parameter.
# This is also a trainable parameter.
self.extra_bias = jax.numpy.ones(2)

def __call__(self, x):
for layer in self.layers[:-1]:
x = jax.nn.relu(layer(x))
return self.layers[-1](x) + self.extra_bias

@jax.jit
@jax.grad
@jax.jit # compile this function to make it run fast.
@jax.grad # differentiate all floating-point arrays in `model`.
def loss(model, x, y):
pred_y = jax.vmap(model)(x)
return jax.numpy.mean((y - pred_y) ** 2)
pred_y = jax.vmap(model)(x) # vectorise the model over a batch of data
return jax.numpy.mean((y - pred_y) ** 2) # L2 loss

x_key, y_key, model_key = jax.random.split(jax.random.PRNGKey(0), 3)
# Example data
x = jax.random.normal(x_key, (100, 2))
y = jax.random.normal(y_key, (100, 2))
model = MyModule(model_key)
model = NeuralNetwork(model_key)
# Compute gradients
grads = loss(model, x, y)
# Perform gradient descent
learning_rate = 0.1
model = jax.tree_util.tree_map(lambda m, g: m - learning_rate * g, model, grads)
new_model = jax.tree_util.tree_map(lambda m, g: m - learning_rate * g, model, grads)
```

## Filtering

In the previous example, all of the model attributes were `Module`s and JAX arrays. To be precise: the overall model was a PyTree of JAX arrays.

Equinox supports using arbitrary Python objects too. (That is, the model is a PyTree of arbitrary Python objects, which may or may not include JAX arrays.) Equinox offers the tools to handle these appropriately around transforms like `jax.jit` and `jax.grad`. (Which themselves only work with JAX arrays.)
In this example, `model = NeuralNetwork(...)` is the overall PyTree. Nested within that is `model.layers` and `model.extra_bias`. The former is also a PyTree, containing three `eqx.nn.Linear` layers at `model.layers[0]`, `model.layers[1]`, and `model.layers[2]`. Each of these are also PyTrees, containing their weights and biases, e.g. `model.layers[0].weight`.

!!! example
## 2. Filtering

The activation function in [`equinox.nn.MLP`][] isn't a JAX array -- it's a Python function.
In the previous example, all of the leaves were JAX arrays. This made things simple, because `jax.jit` and `jax.grad`-decorated functions require that all of their inputs are PyTrees of arrays.

!!! example
Equinox goes further, and supports using arbitrary Python objects for its leaves. For example, we might like to make our activation function part of the PyTree (rather than just hardcoding it as above). The activation function will just be some arbitrary Python function, and this isn't an array. Another common example is having a `bool`-ean flag in your model, which specifies whether to enable some extra piece of behaviour.

You might have a `bool`-ean flag in your model-as-a-PyTree, specifying whether to enable some extra piece of behaviour. You might want to treat that as a `static_argnum` to `jax.jit`.

If you want to do this, then Equinox offers *filtering*, as follows.
To support this, then Equinox offers *filtering*, as follows.

**Create a model**

Start off by creating a model just like normal, now with some arbitrary Python objects as part of its parameterisation. In this case, we have `jax.nn.relu`, which is a Python function.
Start off by creating a model just like normal, now with some arbitrary Python objects as part of its PyTree structure. In this case, we have `jax.nn.relu`, which is a Python function.

```python
import equinox as eqx
import functools as ft
import jax

class AnotherModule(eqx.Module):
class NeuralNetwork2(eqx.Module):
layers: list

def __init__(self, key):
Expand All @@ -91,14 +87,14 @@ class AnotherModule(eqx.Module):

x_key, y_key, model_key = jax.random.split(jax.random.PRNGKey(0), 3)
x, y = jax.random.normal(x_key, (100, 2)), jax.random.normal(y_key, (100, 2))
model = AnotherModule(model_key)
model = NeuralNetwork2(model_key)
```

**Option 1: use `eqx.{partition,combine}`**

```python
@ft.partial(jax.jit, static_argnums=1)
@jax.grad
@ft.partial(jax.jit, static_argnums=1) # `static` must be a PyTree of non-arrays.
@jax.grad # differentiates with respect to `params`, as it is the first argument
def loss(params, static, x, y):
model = eqx.combine(params, static)
pred_y = jax.vmap(model)(x)
Expand All @@ -108,7 +104,7 @@ params, static = eqx.partition(model, eqx.is_array)
loss(params, static, x, y)
```

Here, `params` and `static` are both instances of `AnotherModule`: `params` keeps just the leaves that are JAX arrays; `static` keeps everything else. Then `combine` merges the two PyTrees back together after crossing the `jax.jit` and `jax.grad` API boundaries.
Here, we split our model PyTree into two pieces. `params` and `static` are both instances of `NeuralNetwork2`. `params` keeps just the leaves that are arrays; `static` keeps everything else. Then `combine` merges the two PyTrees back together after crossing the `jax.jit` and `jax.grad` API boundaries.

The choice of `eqx.is_array` is a *filter function*: a boolean function specifying whether each leaf should go into `params` or into `static`. In this case very simply `eqx.is_array(x)` returns `True` for JAX and NumPy arrays, and `False` for everything else.

Expand All @@ -128,31 +124,34 @@ As a convenience, `eqx.filter_jit` and `eqx.filter_grad` wrap filtering and tran

If your models only use JAX arrays, then `eqx.filter_{jit,grad,...}` will do exactly the same as `jax.{jit,grad,...}`. So if you just want to keep things simple, it is safe to just always use `eqx.filter_{jit,grad,...}`.

Both approaches are equally valid. Some people prefer to explicitly see the `jax.{jit,grad,...}` operations without using a wrapper. Some people prefer a shorter syntax instead.
Both approaches are equally valid. Some people prefer the shorter syntax of the filtered transformations. Some people prefer to explicitly see the `jax.{jit,grad,...}` operations directly.

## Integrates smoothly with JAX
## 3. PyTree manipulation routines.

Equinox introduces a powerful yet straightforward way to build neural networks, without introducing lots of new notions or tieing you into a framework.
Equinox clearly places a heavy focus on PyTrees! As such, it's quite common to need to perform operations on PyTrees. Whilst many common operations are already provided by JAX (for example, `jax.tree_util.tree_map` will apply an operation to every leaf of a PyTree), Equinox additionally offers some extra features. For example, `eqx.tree_at` mutates a particular leaf or leaves of a PyTree.

Equinox is all just regular JAX -- PyTrees and transformations. Together, these two pieces allow us to specify complex models in JAX-friendly ways.
## 4. Advanced goodies.

## Summary
Finally, Equinox offers a number of more advanced goodies, like serialisation, debugging tools, and runtime errors. We won't discuss them here, but check out the API reference on the left.

Equinox includes four main things:
## 5. Summary

- For building models: `equinox.Module`.
- Prebuilt neural network layers: `equinox.nn.Linear`, `equinox.nn.Conv2d`, etc.
- Filtering, and filtered transformations: `equinox.partition`, `equinox.filter_jit` etc.
- Some utilities to help manipulate PyTrees: `equinox.tree_at` etc.
**Equinox integrates smoothly with JAX**

See also the API reference on the left.
Equinox introduces a powerful yet straightforward way to build neural networks, without introducing lots of new notions or tieing you into a framework. Indeed Equinox is a *library*, not a *framework* -- this means that anything you write in Equinox is fully compatible with anything else in the JAX ecosystem.

## Next steps
Equinox is all just regular JAX: PyTrees and transformations. Together, these two pieces allow us to specify complex models in JAX-friendly ways.

And that's it! That's pretty much everything you need to know about Equinox. Everything you've seen so far should be enough to get started with using the library. Also see the [Train RNN](./examples/train_rnn.ipynb) example for a fully worked example.
**API reference**

!!! faq "FAQ"
- For building models: [`equinox.Module`][].
- Prebuilt neural network layers: [`equinox.nn.Linear`][], [`equinox.nn.Conv2d`][], etc.
- Filtered transformations: [`equinox.filter_jit`][] etc.
- Tools for PyTree manipulation: [`equinox.partition`][], etc.
- Advanced goodies: serialisation, debugging tools, runtime errors, etc.

One common question: a lot of other libraries introduce custom `library.jit` etc. operations, specifically to work with `library.Module`. What makes the filtered transformations of Equinox different?
See the API reference on the left.

The answer is that filtered transformations and `eqx.Module` are not coupled together; they are independent tools. Filtered transformations work with any PyTree. And `eqx.Module`s just happens to be a PyTree.
**Next steps**

And that's it! That's pretty much everything you need to know about Equinox. Everything you've seen so far should be enough to get started with using the library. Also see the [Train RNN](./examples/train_rnn.ipynb) example for a fully worked example.
3 changes: 3 additions & 0 deletions docs/api/caches.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Clear caches

::: equinox.clear_caches
49 changes: 49 additions & 0 deletions docs/api/debug.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Debugging tools

Both Equinox and JAX provide a number of debugging tools.

## Common sources of NaNs

A common source of NaNs on the forward pass is calling `jnp.log` or `jnp.sqrt` on a negative number, or when dividing by zero. If you get a NaN whilst using these operations, check their inputs carefully (e.g. with `jax.debug.print`).

A common source of NaNs when backpropagating is when using one of the above operations with a `jnp.where`, for example `y = jnp.where(x > 0, jnp.log(x), 0)`. In this case the NaN is created on the forward pass, but is then masked by the `jnp.where`. Unfortunately, when backpropagating, the order of the `log` and the `where` is flipped -- and the NaN is no longer masked! The solution is to use the "double where" trick: bracket your computation by a `where` on both sides. For this example, `safe_x = jnp.where(x > 0, x, 1); y = jnp.where(x > 0, jnp.log(safe_x), 0)`. This ensures that the NaN is never created in the first place at all.

## Debugging runtime errors

If you are getting a runtime error from [`equinox.error_if`][], then you can control the on-error behaviour via the environment variable `EQX_ON_ERROR`. In particular, setting `EQX_ON_ERROR=breakpoint` will open a `jax.debug.breakpoint` where the error arises. See the [runtime errors](./errors.md) for more information and for other values of this environment variable.

If ran from `jax.jit`, then the [`equinox.error_if`][] error will be a long error message starting `INTERNAL: Generated function failed: CpuCallback error: RuntimeError: ...`. You may prefer to use `eqx.filter_jit`, which will remove some of the extra boilerplate from the error message.

## JAX tools

JAX itself provides the following tools:

- the `jax.debug.print` function, for printing results under JIT.
- the `jax.debug.breakpoint` function, for opening a debugger under JIT.
- the `JAX_DEBUG_NANS=1` environment variable, for halting the computation once a NaN is encountered. This works best for NaNs encountered on the forward pass and outside of loops. If your NaN occurs on the backward pass only, then try [`equinox.debug.backward_nan`][] below. If the NaN occurs inside of a loop, then consider pairing this with `JAX_DISABLE_JIT=1`. (Many loops are implicitly jit'd.)
- the `JAX_DISABLE_JIT=1` environment variable, for running the computation without JIT. This will be *much* slower, so this isn't always practical.
- the `JAX_TRACEBACK_FILTERING=off` environment variable, which means errors and debuggers will include JAX and Equinox internals. (Which by default are filtered out.)

## Equinox tools

::: equinox.debug.announce_transform

---

::: equinox.debug.backward_nan

---

::: equinox.debug.breakpoint_if

---

::: equinox.debug.store_dce

::: equinox.debug.inspect_dce

---

::: equinox.debug.assert_max_traces

::: equinox.debug.get_num_traces
7 changes: 7 additions & 0 deletions docs/api/enumerations.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Enumerations

::: equinox.Enumeration
selection:
members:
- where
- promote
Loading

0 comments on commit 83b376b

Please sign in to comment.