Skip to content

Commit

Permalink
Fix AOT compilation for eqx.filter_jit with kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Dec 22, 2023
1 parent 1e2d8c2 commit b10eb85
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
2 changes: 1 addition & 1 deletion equinox/_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _preprocess(info, args, kwargs, return_static: bool = False):
dynamic_rest, static_rest = hashable_partition((rest_args, kwargs), is_array)
else:
dynamic_first = hashable_filter(first_arg, is_array)
dynamic_rest = hashable_filter(rest_args, is_array)
dynamic_rest = hashable_filter((rest_args, kwargs), is_array)
dynamic_donate = dict()
dynamic_nodonate = dict()
if donate_first:
Expand Down
13 changes: 13 additions & 0 deletions tests/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,3 +402,16 @@ def f(x, y):
lowered.as_text()
compiled = lowered.compile()
compiled(x, y)


# Issue 625
@pytest.mark.parametrize("donate", ("all", "all-except-first", "none"))
def test_aot_compilation_kwargs(donate):
def f(x, y, **kwargs):
return 2 * x + y

x, y = jnp.array(3), 4
lowered = eqx.filter_jit(f, donate=donate).lower(x, y, test=123)
lowered.as_text()
compiled = lowered.compile()
compiled(x, y, test=123)

0 comments on commit b10eb85

Please sign in to comment.