diff --git a/equinox/_ad.py b/equinox/_ad.py index f2f1db8f..3b9a2825 100644 --- a/equinox/_ad.py +++ b/equinox/_ad.py @@ -344,12 +344,13 @@ def _fn(*_flat_dynamic): _in = combine(_dynamic, static_primals) _out = fn(*_in, **kwargs) _dynamic_out, _static_out = partition(_out, _is_jvp_tracer(_main)) - return _dynamic_out, Static(_static_out) + _arr_out, _non_arr_out = partition(_static_out, is_array) + return _dynamic_out, _arr_out, Static(_non_arr_out) primal_out, tangent_out = jax.jvp(_fn, flat_dynamic_primals, flat_tangents) - dynamic_primal_out, static_primal_out = primal_out - primal_out = combine(dynamic_primal_out, static_primal_out.value) - tangent_out, _ = tangent_out + dynamic_primal_out, arr_primal_out, static_primal_out = primal_out + primal_out = combine(dynamic_primal_out, arr_primal_out, static_primal_out.value) + tangent_out, _, _ = tangent_out return primal_out, tangent_out diff --git a/equinox/_module.py b/equinox/_module.py index 7e0fe5f0..de5b7f1f 100644 --- a/equinox/_module.py +++ b/equinox/_module.py @@ -24,7 +24,7 @@ from ._better_abstract import ABCMeta, dataclass from ._caches import cache_clears from ._doc_utils import doc_repr -from ._filters import is_array_like +from ._filters import is_array, is_array_like from ._pretty_print import tree_pformat from ._tree import tree_equal @@ -584,6 +584,19 @@ def __call__(cls, *args, **kwargs): f"The following fields were not initialised during __init__: " f"{missing_names}" ) + # [Step 3.5] Prevent arrays from being marked as static + for field in dataclasses.fields(self): + if field.metadata.get("static", False): + if any( + jtu.tree_map( + is_array, jtu.tree_flatten(getattr(self, field.name))[0] + ) + ): + warnings.warn( + "A JAX array is being set as static! This can result " + "in unexpected behavior and is usually a mistake to do.", + stacklevel=2, + ) # Freeze. object.__setattr__(self, "__class__", cls) # [Step 4] Run any custom validators. (After freezing; as they run diff --git a/equinox/_vmap_pmap.py b/equinox/_vmap_pmap.py index 216619f7..6da7b0cc 100644 --- a/equinox/_vmap_pmap.py +++ b/equinox/_vmap_pmap.py @@ -206,25 +206,28 @@ def _fun_wrapper(_dynamic_args): _out_axes = _resolve_axes(_out, _out_axes) _none_axes = jtu.tree_map(_is_none, _out_axes, is_leaf=_is_none) _nonvmapd, _vmapd = partition(_out, _none_axes, is_leaf=_is_none) - return _vmapd, Static((_nonvmapd, _out_axes)) + _nonvmapd_arr, _nonvmapd_static = partition(_nonvmapd, is_array) + return _vmapd, _nonvmapd_arr, Static((_nonvmapd_static, _out_axes)) if len(jtu.tree_leaves(in_axes)) == 0 and self._axis_size is None: - vmapd, static = _fun_wrapper(dynamic_args) + vmapd, nonvmapd_arr, static = _fun_wrapper(dynamic_args) if len(jtu.tree_leaves(vmapd)) != 0: raise ValueError( "Cannot resolve batch dimension. Non-`None` `out_axes` requires " "either `in_axes` or `axis_size` to be not `None`." ) else: - vmapd, static = jax.vmap( + vmapd, nonvmapd_arr, static = jax.vmap( _fun_wrapper, in_axes=(in_axes,), - out_axes=(0, None), + out_axes=(0, None, None), axis_name=self._axis_name, axis_size=self._axis_size, **self._vmapkwargs, )(dynamic_args) - nonvmapd, out_axes = static.value + + nonvmapd_static, out_axes = static.value + nonvmapd = combine(nonvmapd_arr, nonvmapd_static) assert jtu.tree_structure(vmapd) == jtu.tree_structure(out_axes) vmapd = jtu.tree_map(_swapaxes, vmapd, out_axes) diff --git a/tests/test_module.py b/tests/test_module.py index ff8f5e01..76996e36 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -1141,3 +1141,31 @@ class Foo3(eqx.Module): assert Foo1.__init__.__annotations__["x"] is Any assert Foo2.__init__.__annotations__["x"] is int assert Foo3.__init__.__annotations__["x"] is bool + + +def test_no_jax_array_static(): + class Valid(eqx.Module): + a: tuple + b: jax.Array + + class InvalidTuple(eqx.Module): + a: tuple = eqx.field(static=True) + b: jax.Array + + class InvalidArr(eqx.Module): + a: tuple + b: jax.Array = eqx.field(static=True) + + Valid((), jnp.ones(2)) + + with pytest.warns( + UserWarning, + match="A JAX array is being set as static!", + ): + InvalidTuple((jnp.ones(10),), jnp.ones(10)) + + with pytest.warns( + UserWarning, + match="A JAX array is being set as static!", + ): + InvalidArr((), jnp.ones(10))