diff --git a/equinox/internal/_noinline.py b/equinox/internal/_noinline.py index c3269206..0e130ec7 100644 --- a/equinox/internal/_noinline.py +++ b/equinox/internal/_noinline.py @@ -385,7 +385,7 @@ def _noinline_mlir(ctx, *dynamic, treedef, static, flatten, **kwargs): class _NoInlineWrapper(Module): - dynamic_index: Int[Array, ""] + dynamic_index: Int[Union[Array, np.ndarray], ""] abstract_fn: Callable = field(static=True) dynamic_fn: Any @@ -502,6 +502,6 @@ def abstract_fn(__dynamic_fn, *args, **kwargs): dynamic_index = len(_index_to_fn) _fn_to_index[static_fn] = dynamic_index _index_to_fn.append(static_fn) - dynamic_index = jnp.array(dynamic_index) + dynamic_index = np.array(dynamic_index) noinline_fn = _NoInlineWrapper(dynamic_index, abstract_fn, dynamic_fn) return module_update_wrapper(noinline_fn)