Skip to content

Commit

Permalink
Fixed spurious global function error
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Jun 13, 2024
1 parent ce61c91 commit fb63b4c
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 0 deletions.
5 changes: 5 additions & 0 deletions optimistix/_fixed_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ def _rewrite_fn(fixed_point, _, inputs):
return (f_val**ω - fixed_point**ω).ω


# Keep `optx.implicit_jvp` is happy.
if _rewrite_fn.__globals__["__name__"].startswith("jaxtyping"):
_rewrite_fn = _rewrite_fn.__wrapped__ # pyright: ignore[reportFunctionMemberAccess]


class _ToRootFn(eqx.Module, Generic[Y, Aux]):
fixed_point_fn: Fn[Y, Y, Aux]

Expand Down
5 changes: 5 additions & 0 deletions optimistix/_iterate.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,11 @@ def buffers(carry):
)


# Keep `optx.implicit_jvp` is happy.
if _iterate.__globals__["__name__"].startswith("jaxtyping"):
_iterate = _iterate.__wrapped__ # pyright: ignore[reportFunctionMemberAccess]


def iterative_solve(
fn: Fn[Y, Out, Aux],
# no type parameters, see https://github.com/microsoft/pyright/discussions/5599
Expand Down
5 changes: 5 additions & 0 deletions optimistix/_least_squares.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ def objective(_optimum):
return jax.grad(objective)(optimum)


# Keep `optx.implicit_jvp` is happy.
if _rewrite_fn.__globals__["__name__"].startswith("jaxtyping"):
_rewrite_fn = _rewrite_fn.__wrapped__ # pyright: ignore[reportFunctionMemberAccess]


class _ToMinimiseFn(eqx.Module, Generic[Y, Out, Aux]):
residual_fn: Fn[Y, Out, Aux]

Expand Down
5 changes: 5 additions & 0 deletions optimistix/_minimise.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ def min_no_aux(x):
return jax.grad(min_no_aux)(minimum)


# Keep `optx.implicit_jvp` is happy.
if _rewrite_fn.__globals__["__name__"].startswith("jaxtyping"):
_rewrite_fn = _rewrite_fn.__wrapped__ # pyright: ignore[reportFunctionMemberAccess]


@eqx.filter_jit
def minimise(
fn: MaybeAuxFn[Y, Scalar, Aux],
Expand Down
5 changes: 5 additions & 0 deletions optimistix/_root_find.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ def _rewrite_fn(root, _, inputs):
return f_val


# Keep `optx.implicit_jvp` is happy.
if _rewrite_fn.__globals__["__name__"].startswith("jaxtyping"):
_rewrite_fn = _rewrite_fn.__wrapped__ # pyright: ignore[reportFunctionMemberAccess]


def _to_minimise_fn(root_fn, norm, y, args):
root, aux = root_fn(y, args)
return norm(root), (root, aux)
Expand Down

0 comments on commit fb63b4c

Please sign in to comment.