Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dev #28

Merged
merged 8 commits into from
Dec 27, 2023
Merged

dev #28

merged 8 commits into from
Dec 27, 2023

Conversation

patrick-kidger
Copy link
Owner

@patrick-kidger patrick-kidger commented Dec 11, 2023

  • Fixed {Newton,Chord}(cauchy_termination=False) failing when started close to the solution
  • Fixed sol.state including only the dynamic part
  • Fixed test function unwrapping with latest jaxtyping
  • Moved norms to Lineax
  • Updated to use eqxi.GetKey and tree_allclose
  • Switch to ruff-format and ruff for ipynb
  • Fixed implicit_jvp assuming that it was only being used with iterate
  • Now using the same jaxpr in the state.

This is quite an important fix!

The bit that matters here is that the `f_eval_info.jac` in
`AbstractGaussNewton.step` now throws away its static (non-array) parts
of its PyTree, and instead uses the equivalent static (non-array) parts
of `state.f_info.jac`, i.e. as were computed in
`AbstractGaussNewton.init`.

Now at a logical level this shouldn't matter at all: the static pieces
should be the same in both cases, as they're just the output of
`_make_f_info` with similarly-structured inputs.

However, `_make_f_info` calls `lx.FunctionLinearOperator` which calls
`eqx.filter_closure_convert` which calls `jax.make_jaxpr` which returns
a jaxpr... and so between the two calls to `_make_f_info`, we actually
end up with two jaxprs. Both encode the same program, but are two
different Python objects. Now jaxprs have `__eq__` defined according to
identity, so these two (functionally identical) jaxprs do not compare
as equal.

Previously we worked around this inside `_iterate.py`: we carefully
removed or wrapped any jaxprs before anything that would try to compare
them for equality. This was a bit ugly, but it worked.

However, it turns out that this still left a problem when manually
stepping an Optimistix solver! (In a way akin to an Optax solver:
something like
```python
@eqx.filter_jit
def make_step(...):
    ... = solver.step(...)

for ... in ...:  # Python level for-loop
    ... = make_step(...)
```
)
then in fact on every iteration of the Python loop, we would end up
recompiling, as we always gets a new jaxpr at
```
state      # state for the Gauss-Newton solver
  .f_info  # as returned by _make_f_info
  .jac     # the FunctionLinearOperator
  .fn      # the closure-converted function
  .jaxpr   # the jaxpr from the closure conversion
```
!

Now one fix is simply to demand that manually stepping a solver
requires similar hackery as we had in `_iterate.py`. But maybe enough
is enough, and we should try doing something better instead: that is,
we do what this PR does, and just preserves the same jaxpr all the way
through.

For bonus points, this means that we can now remove our special jaxpr
handling from `_iterate.py` (and from `filter_cond`, which also needed
this for the same reason).

Finally, you might be wondering: why do we need to trace two equivalent
jaxprs at all? This seems inefficient -- can we arrange to trace it
just once? The answer is "probably, but not in this PR". This seems to
require that (a) Lineax offer a way to turn off closure conversion
(done in patrick-kidger/lineax#71), but that (b) when
using this, this still seems to trigger a similar issue in JAX, that
the primal and tangent results from `jax.custom_jvp` match. So for now
this is just something to try and tackle later -- once we do, we'll get
slightly better compile times.
@patrick-kidger patrick-kidger merged commit 5528cc3 into main Dec 27, 2023
2 checks passed
@patrick-kidger patrick-kidger deleted the dev branch December 27, 2023 16:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant