Skip to content

Commit

Permalink
Now emits a warning if you're about to silently footgun yourself by a…
Browse files Browse the repository at this point in the history
…ssigning a jax-transformed layer.
  • Loading branch information
patrick-kidger committed Feb 21, 2024
1 parent 168d33d commit a61ff33
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 0 deletions.
90 changes: 90 additions & 0 deletions equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Any, cast, Optional, Protocol, TYPE_CHECKING, TypeVar, Union
from typing_extensions import dataclass_transform, ParamSpec

import jax
import jax._src.traceback_util as traceback_util
import jax.tree_util as jtu
import numpy as np
Expand All @@ -22,6 +23,7 @@
from ._better_abstract import ABCMeta, dataclass
from ._caches import cache_clears
from ._doc_utils import doc_repr
from ._filters import is_array_like
from ._pretty_print import tree_pformat
from ._tree import tree_equal

Expand Down Expand Up @@ -506,6 +508,10 @@ def __call__(cls, *args, **kwargs):
if _is_force_abstract[cls]:
# Any other is-abstract checks will be handled in super().__call__.
raise TypeError("Cannot instantiate abstract `equinox.Module`.")
if _has_dataclass_init[cls]:
for x in jtu.tree_leaves((args, kwargs)):
_warn_jax_transformed_function(cls, x)
# else it's handled in __setattr__, but that isn't called here.
# [Step 1] Modules are immutable -- except during construction. So defreeze
# before init.
post_init = getattr(cls, "__post_init__", None)
Expand Down Expand Up @@ -635,6 +641,89 @@ class _Initable:
pass


_transform_types = {
type(transform(lambda x: x))
for transform in (
jax.jit,
jax.grad,
jax.vmap,
jax.value_and_grad,
jax.jacfwd,
jax.jacrev,
jax.hessian,
jax.custom_jvp,
jax.custom_vjp,
jax.checkpoint, # pyright: ignore
jax.pmap,
)
}


def _warn_jax_transformed_function(cls, x):
# not `isinstance`, just in case JAX every tries to override `__instancecheck__`.
if type(x) in _transform_types:

class _JaxTransformException(Exception):
pass

def _is_array_like(x):
if is_array_like(x):
raise _JaxTransformException

while True:
try:
x = x.__wrapped__
except AttributeError:
break
try:
jtu.tree_map(_is_array_like, x)
except _JaxTransformException:
warnings.warn(
f"""
Possibly assigning a JAX-transformed callable as an attribute on
{cls.__module__}.{cls.__qualname__}. This will not have any of its parameters updated.
For example, the following code is buggy:
```python
class MyModule(eqx.Module):
vmap_linear: Callable
def __init__(self, ...):
self.vmap_linear = jax.vmap(eqx.nn.Linear(...))
def __call__(self, ...):
... = self.vmap_linear(...)
```
This is because the callable returned from `jax.vmap` is *not* a PyTree. This means that
the parameters inside the `eqx.nn.Linear` layer will not receive gradient updates.
You can most easily fix this either by applying the wrapper at `__call__` time:
```python
class MyModule(eqx.Module):
linear: Callable
def __init__(self, ...):
self.linear = eqx.nn.Linear(...)
def __call__(self, ...):
... = jax.vmap(self.linear)(...)
```
or by using `eqx.filter_vmap` instead (which *does* return a PyTree):
```python
class MyModule(eqx.Module):
vmap_linear: Callable
def __init__(self, ...):
self.vmap_linear = eqx.filter_vmap(eqx.nn.Linear(...))
def __call__(self, ...):
... = self.vmap_linear(...)
```
"""
)
break


@ft.lru_cache(maxsize=128)
def _make_initable(cls: _ModuleMeta, init, post_init, wraps: bool) -> _ModuleMeta:
# Used as part of the key. Don't cache if these have changed.
Expand Down Expand Up @@ -688,6 +777,7 @@ def bar(self):
"""
)
else:
_warn_jax_transformed_function(type(self), value)
object.__setattr__(self, name, value)
else:
raise AttributeError(f"Cannot set attribute {name}")
Expand Down
32 changes: 32 additions & 0 deletions tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,3 +1085,35 @@ def g(self):
assert Foo.g is g # pyright: ignore
assert type(Foo.__dict__["f"]).__name__ == "_wrap_method"
assert type(Foo.__dict__["g"]).__name__ == "_wrap_method"


# See https://github.com/patrick-kidger/equinox/issues/206
def test_jax_transform_warn(getkey):
class A(eqx.Module):
linear: Callable

class B(eqx.Module):
linear: Callable

def __init__(self, linear):
self.linear = linear

for cls in (A, B):
for transform in (
jax.jit,
jax.grad,
jax.vmap,
jax.value_and_grad,
jax.jacfwd,
jax.jacrev,
jax.hessian,
jax.custom_jvp,
jax.custom_vjp,
jax.checkpoint, # pyright: ignore
jax.pmap,
):
with pytest.warns(
match="Possibly assigning a JAX-transformed callable as an attribute"
):
transformed = transform(eqx.nn.Linear(2, 2, key=getkey()))
cls(transformed)

0 comments on commit a61ff33

Please sign in to comment.