diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b710cf7..2bdccfb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,7 +22,7 @@ repos: - id: ruff-format # formatter types_or: [ python, pyi, jupyter ] - repo: https://github.com/RobertCraigie/pyright-python - rev: v1.1.330 + rev: v1.1.331 hooks: - id: pyright additional_dependencies: ["equinox", "jax", "lineax", "pytest", "optax"] diff --git a/optimistix/_iterate.py b/optimistix/_iterate.py index 834bff3..53f1b2a 100644 --- a/optimistix/_iterate.py +++ b/optimistix/_iterate.py @@ -37,20 +37,9 @@ _Node = eqxi.doc_repr(Any, "Node") -def _is_jaxpr(x): - return isinstance(x, (jax.core.Jaxpr, jax.core.ClosedJaxpr)) - - -def _is_array_or_jaxpr(x): - return _is_jaxpr(x) or eqx.is_array(x) - - class AbstractIterativeSolver(eqx.Module, Generic[Y, Out, Aux, SolverState]): """Abstract base class for all iterative solvers.""" - # Essentially every solver has an rtol+atol+norm. So for now we're just hardcoding - # that every solver must have these variables, as they're needed when using a - # minimiser or least-squares solver on a root-finding problem. rtol: AbstractVar[float] atol: AbstractVar[float] norm: AbstractVar[Callable[[PyTree], Scalar]] @@ -255,11 +244,7 @@ def body_fun(carry): new_y, new_state, aux = solver.step(fn, y, args, options, state, tags) new_dynamic_state, new_static_state = eqx.partition(new_state, eqx.is_array) - new_static_state_no_jaxpr = eqx.filter( - new_static_state, _is_jaxpr, inverse=True - ) - static_state_no_jaxpr = eqx.filter(state, _is_array_or_jaxpr, inverse=True) - assert eqx.tree_equal(static_state_no_jaxpr, new_static_state_no_jaxpr) is True + assert eqx.tree_equal(static_state, new_static_state) is True return new_y, num_steps + 1, new_dynamic_state, aux def buffers(carry): diff --git a/optimistix/_misc.py b/optimistix/_misc.py index 13e7bb2..dec6130 100644 --- a/optimistix/_misc.py +++ b/optimistix/_misc.py @@ -231,18 +231,16 @@ def _true_fun(_dynamic): _operands = eqx.combine(_dynamic, static) _out = true_fun(*_operands) _dynamic_out, _static_out = eqx.partition(_out, eqx.is_array) - _static_out = wrap_jaxpr(_static_out) return _dynamic_out, eqxi.Static(_static_out) def _false_fun(_dynamic): _operands = eqx.combine(_dynamic, static) _out = false_fun(*_operands) _dynamic_out, _static_out = eqx.partition(_out, eqx.is_array) - _static_out = wrap_jaxpr(_static_out) return _dynamic_out, eqxi.Static(_static_out) dynamic_out, static_out = lax.cond(pred, _true_fun, _false_fun, dynamic) - return eqx.combine(dynamic_out, unwrap_jaxpr(static_out.value)) + return eqx.combine(dynamic_out, static_out.value) def verbose_print(*args: tuple[bool, str, Any]) -> None: diff --git a/optimistix/_solver/gauss_newton.py b/optimistix/_solver/gauss_newton.py index c9cf14b..eda0f43 100644 --- a/optimistix/_solver/gauss_newton.py +++ b/optimistix/_solver/gauss_newton.py @@ -188,14 +188,14 @@ class AbstractGaussNewton(AbstractLeastSquaresSolver[Y, Out, Aux, _GaussNewtonSt This includes methods such as [`optimistix.GaussNewton`][], [`optimistix.LevenbergMarquardt`][], and [`optimistix.Dogleg`][]. - Subclasses must provide the following abstract attributes, with the following types: - - - `rtol: float` - - `atol: float` - - `norm: Callable[[PyTree], Scalar]` - - `descent: AbstractDescent` - - `search: AbstractSearch` - - `verbose: frozenset[str] + Subclasses must provide the following attributes, with the following types: + + - `rtol`: `float` + - `atol`: `float` + - `norm`: `Callable[[PyTree], Scalar]` + - `descent`: `AbstractDescent` + - `search`: `AbstractSearch` + - `verbose`: `frozenset[str]` """ rtol: AbstractVar[float] @@ -243,6 +243,14 @@ def step( tags: frozenset[object], ) -> tuple[Y, _GaussNewtonState, Aux]: f_eval_info, aux_eval = _make_f_info(fn, state.y_eval, args, tags) + # We have a jaxpr in `f_info.jac`, which are compared by identity. Here we + # arrange to use the same one so that downstream equality checks (e.g. in the + # `filter_cond` below) + dynamic = eqx.filter(f_eval_info.jac, eqx.is_array) + static = eqx.filter(state.f_info.jac, eqx.is_array, inverse=True) + jac = eqx.combine(dynamic, static) + f_eval_info = eqx.tree_at(lambda f: f.jac, f_eval_info, jac) + step_size, accept, search_result, search_state = self.search.step( state.first_step, y,