Skip to content

Commit

Permalink
Fix breaking changes introduced in JAX 0.4.36.
Browse files Browse the repository at this point in the history
See:

- jax-ml/jax#25289
- patrick-kidger/diffrax#532

The problem was that batching has now become a dynamic trace, and our batching rules were not set up to handle the case that every batch axis is `not_mapped`.
  • Loading branch information
patrick-kidger committed Dec 7, 2024
1 parent 336a347 commit f09bacd
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 19 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ jobs:
run-test:
strategy:
matrix:
python-version: [ 3.9, 3.11 ]
# must match the `language_version` in `.pre-commit-config.yaml`
python-version: [ 3.11 ]
os: [ ubuntu-latest ]
fail-fast: false
runs-on: ${{ matrix.os }}
Expand Down
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ repos:
rev: v1.1.379
hooks:
- id: pyright
# must match the Python version used in CI
language_version: python3.11
additional_dependencies:
[
beartype,
Expand Down
23 changes: 16 additions & 7 deletions equinox/internal/_loop/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,26 @@ def _select_if_vmap_batch(axis_size, axis_name, trace, inputs, batch_axes):
bp, bx, by = batch_axes
if bp is batching.not_mapped:
if bx is batching.not_mapped:
x = jnp.broadcast_to(x, (axis_size,) + x.shape)
else:
x = jnp.moveaxis(x, bx, 0)
if by is batching.not_mapped:
y = jnp.broadcast_to(y, (axis_size,) + y.shape)
if by is batching.not_mapped:
out_axis = None
else:
x = jnp.broadcast_to(x, (axis_size,) + x.shape)
y = jnp.moveaxis(y, by, 0)
out_axis = 0
else:
y = jnp.moveaxis(y, by, 0)
if by is batching.not_mapped:
x = jnp.moveaxis(x, bx, 0)
y = jnp.broadcast_to(y, (axis_size,) + y.shape)
out_axis = 0
else:
x = jnp.moveaxis(x, bx, 0)
y = jnp.moveaxis(y, by, 0)
out_axis = 0
out = _select_if_vmap(pred, x, y, makes_false_steps=False)
else:
out = jax.vmap(lax.select, in_axes=(bp, bx, by))(pred, x, y)
return out, 0
out_axis = 0
return out, out_axis


select_if_vmap_p = jax.core.Primitive("select_if_vmap")
Expand Down
24 changes: 14 additions & 10 deletions equinox/internal/_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,16 +306,20 @@ def create_vprim(name: str, impl, abstract_eval, jvp, transpose):

def batch_rule(axis_size, axis_name, trace_type, inputs, batch_axes, **params):
del trace_type
# delegates batching to `_vprim_p`
out = _vprim_p.bind(
*inputs,
prim=prim,
__axis_size=axis_size,
__axis_name=axis_name,
__batch_axes=batch_axes,
params=params,
)
batch_axes_out = jtu.tree_map(lambda _: 0, out)
if all(b is batching.not_mapped for b in jtu.tree_leaves(batch_axes)):
out = prim.bind(*inputs, **params)
batch_axes_out = jtu.tree_map(lambda _: batching.not_mapped, out)
else:
# delegates batching to `_vprim_p`
out = _vprim_p.bind(
*inputs,
prim=prim,
__axis_size=axis_size,
__axis_name=axis_name,
__batch_axes=batch_axes,
params=params,
)
batch_axes_out = jtu.tree_map(lambda _: 0, out)
return out, batch_axes_out

prim.def_impl(impl)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "equinox"
version = "0.11.9"
version = "0.11.10"
description = "Elegant easy-to-use neural networks in JAX."
readme = "README.md"
requires-python =">=3.9"
Expand Down

0 comments on commit f09bacd

Please sign in to comment.