Skip to content

Commit

Permalink
Fixed eqx.internal.noinline at the top-level in g3
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Jan 9, 2024
1 parent 84cbd70 commit d571353
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions equinox/internal/_noinline.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def _noinline_mlir(ctx, *dynamic, treedef, static, flatten, **kwargs):


class _NoInlineWrapper(Module):
dynamic_index: Int[Array, ""]
dynamic_index: Int[Union[Array, np.ndarray], ""]
abstract_fn: Callable = field(static=True)
dynamic_fn: Any

Expand Down Expand Up @@ -502,6 +502,6 @@ def abstract_fn(__dynamic_fn, *args, **kwargs):
dynamic_index = len(_index_to_fn)
_fn_to_index[static_fn] = dynamic_index
_index_to_fn.append(static_fn)
dynamic_index = jnp.array(dynamic_index)
dynamic_index = np.array(dynamic_index)
noinline_fn = _NoInlineWrapper(dynamic_index, abstract_fn, dynamic_fn)
return module_update_wrapper(noinline_fn)

0 comments on commit d571353

Please sign in to comment.