Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable static arrays #800

Merged
merged 5 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions equinox/_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,12 +344,13 @@ def _fn(*_flat_dynamic):
_in = combine(_dynamic, static_primals)
_out = fn(*_in, **kwargs)
_dynamic_out, _static_out = partition(_out, _is_jvp_tracer(_main))
return _dynamic_out, Static(_static_out)
_arr_out, _non_arr_out = partition(_static_out, is_array)
return _dynamic_out, _arr_out, Static(_non_arr_out)

primal_out, tangent_out = jax.jvp(_fn, flat_dynamic_primals, flat_tangents)
dynamic_primal_out, static_primal_out = primal_out
primal_out = combine(dynamic_primal_out, static_primal_out.value)
tangent_out, _ = tangent_out
dynamic_primal_out, arr_primal_out, static_primal_out = primal_out
primal_out = combine(dynamic_primal_out, arr_primal_out, static_primal_out.value)
tangent_out, _, _ = tangent_out

return primal_out, tangent_out

Expand Down
15 changes: 14 additions & 1 deletion equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from ._better_abstract import ABCMeta, dataclass
from ._caches import cache_clears
from ._doc_utils import doc_repr
from ._filters import is_array_like
from ._filters import is_array, is_array_like
from ._pretty_print import tree_pformat
from ._tree import tree_equal

Expand Down Expand Up @@ -584,6 +584,19 @@ def __call__(cls, *args, **kwargs):
f"The following fields were not initialised during __init__: "
f"{missing_names}"
)
# [Step 3.5] Prevent arrays from being marked as static
for field in dataclasses.fields(self):
if field.metadata.get("static", False):
if any(
jtu.tree_map(
is_array, jtu.tree_flatten(getattr(self, field.name))[0]
)
):
warnings.warn(
"A JAX array is being set as static! This can result "
"in unexpected behavior and is usually a mistake to do.",
stacklevel=2,
)
# Freeze.
object.__setattr__(self, "__class__", cls)
# [Step 4] Run any custom validators. (After freezing; as they run
Expand Down
13 changes: 8 additions & 5 deletions equinox/_vmap_pmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,25 +206,28 @@ def _fun_wrapper(_dynamic_args):
_out_axes = _resolve_axes(_out, _out_axes)
_none_axes = jtu.tree_map(_is_none, _out_axes, is_leaf=_is_none)
_nonvmapd, _vmapd = partition(_out, _none_axes, is_leaf=_is_none)
return _vmapd, Static((_nonvmapd, _out_axes))
_nonvmapd_arr, _nonvmapd_static = partition(_nonvmapd, is_array)
return _vmapd, _nonvmapd_arr, Static((_nonvmapd_static, _out_axes))

if len(jtu.tree_leaves(in_axes)) == 0 and self._axis_size is None:
vmapd, static = _fun_wrapper(dynamic_args)
vmapd, nonvmapd_arr, static = _fun_wrapper(dynamic_args)
if len(jtu.tree_leaves(vmapd)) != 0:
raise ValueError(
"Cannot resolve batch dimension. Non-`None` `out_axes` requires "
"either `in_axes` or `axis_size` to be not `None`."
)
else:
vmapd, static = jax.vmap(
vmapd, nonvmapd_arr, static = jax.vmap(
_fun_wrapper,
in_axes=(in_axes,),
out_axes=(0, None),
out_axes=(0, None, None),
axis_name=self._axis_name,
axis_size=self._axis_size,
**self._vmapkwargs,
)(dynamic_args)
nonvmapd, out_axes = static.value

nonvmapd_static, out_axes = static.value
nonvmapd = combine(nonvmapd_arr, nonvmapd_static)

assert jtu.tree_structure(vmapd) == jtu.tree_structure(out_axes)
vmapd = jtu.tree_map(_swapaxes, vmapd, out_axes)
Expand Down
28 changes: 28 additions & 0 deletions tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,3 +1141,31 @@ class Foo3(eqx.Module):
assert Foo1.__init__.__annotations__["x"] is Any
assert Foo2.__init__.__annotations__["x"] is int
assert Foo3.__init__.__annotations__["x"] is bool


def test_no_jax_array_static():
class Valid(eqx.Module):
a: tuple
b: jax.Array

class InvalidTuple(eqx.Module):
a: tuple = eqx.field(static=True)
b: jax.Array

class InvalidArr(eqx.Module):
a: tuple
b: jax.Array = eqx.field(static=True)

Valid((), jnp.ones(2))

with pytest.warns(
UserWarning,
match="A JAX array is being set as static!",
):
InvalidTuple((jnp.ones(10),), jnp.ones(10))

with pytest.warns(
UserWarning,
match="A JAX array is being set as static!",
):
InvalidArr((), jnp.ones(10))
Loading