diff --git a/optimistix/_fixed_point.py b/optimistix/_fixed_point.py index 7ff789a..9daa657 100644 --- a/optimistix/_fixed_point.py +++ b/optimistix/_fixed_point.py @@ -29,6 +29,11 @@ def _rewrite_fn(fixed_point, _, inputs): return (f_val**ω - fixed_point**ω).ω +# Keep `optx.implicit_jvp` is happy. +if _rewrite_fn.__globals__["__name__"].startswith("jaxtyping"): + _rewrite_fn = _rewrite_fn.__wrapped__ # pyright: ignore[reportFunctionMemberAccess] + + class _ToRootFn(eqx.Module, Generic[Y, Aux]): fixed_point_fn: Fn[Y, Y, Aux] diff --git a/optimistix/_iterate.py b/optimistix/_iterate.py index 832f689..c1b06cf 100644 --- a/optimistix/_iterate.py +++ b/optimistix/_iterate.py @@ -261,6 +261,11 @@ def buffers(carry): ) +# Keep `optx.implicit_jvp` is happy. +if _iterate.__globals__["__name__"].startswith("jaxtyping"): + _iterate = _iterate.__wrapped__ # pyright: ignore[reportFunctionMemberAccess] + + def iterative_solve( fn: Fn[Y, Out, Aux], # no type parameters, see https://github.com/microsoft/pyright/discussions/5599 diff --git a/optimistix/_least_squares.py b/optimistix/_least_squares.py index cec307a..bdad3e4 100644 --- a/optimistix/_least_squares.py +++ b/optimistix/_least_squares.py @@ -30,6 +30,11 @@ def objective(_optimum): return jax.grad(objective)(optimum) +# Keep `optx.implicit_jvp` is happy. +if _rewrite_fn.__globals__["__name__"].startswith("jaxtyping"): + _rewrite_fn = _rewrite_fn.__wrapped__ # pyright: ignore[reportFunctionMemberAccess] + + class _ToMinimiseFn(eqx.Module, Generic[Y, Out, Aux]): residual_fn: Fn[Y, Out, Aux] diff --git a/optimistix/_minimise.py b/optimistix/_minimise.py index f66faa0..31b67e5 100644 --- a/optimistix/_minimise.py +++ b/optimistix/_minimise.py @@ -30,6 +30,11 @@ def min_no_aux(x): return jax.grad(min_no_aux)(minimum) +# Keep `optx.implicit_jvp` is happy. +if _rewrite_fn.__globals__["__name__"].startswith("jaxtyping"): + _rewrite_fn = _rewrite_fn.__wrapped__ # pyright: ignore[reportFunctionMemberAccess] + + @eqx.filter_jit def minimise( fn: MaybeAuxFn[Y, Scalar, Aux], diff --git a/optimistix/_root_find.py b/optimistix/_root_find.py index 8bcbdb7..fdadeec 100644 --- a/optimistix/_root_find.py +++ b/optimistix/_root_find.py @@ -28,6 +28,11 @@ def _rewrite_fn(root, _, inputs): return f_val +# Keep `optx.implicit_jvp` is happy. +if _rewrite_fn.__globals__["__name__"].startswith("jaxtyping"): + _rewrite_fn = _rewrite_fn.__wrapped__ # pyright: ignore[reportFunctionMemberAccess] + + def _to_minimise_fn(root_fn, norm, y, args): root, aux = root_fn(y, args) return norm(root), (root, aux)