-
Notifications
You must be signed in to change notification settings - Fork 16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
dev #28
Merged
Merged
dev #28
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
…lose to the solution
This is quite an important fix! The bit that matters here is that the `f_eval_info.jac` in `AbstractGaussNewton.step` now throws away its static (non-array) parts of its PyTree, and instead uses the equivalent static (non-array) parts of `state.f_info.jac`, i.e. as were computed in `AbstractGaussNewton.init`. Now at a logical level this shouldn't matter at all: the static pieces should be the same in both cases, as they're just the output of `_make_f_info` with similarly-structured inputs. However, `_make_f_info` calls `lx.FunctionLinearOperator` which calls `eqx.filter_closure_convert` which calls `jax.make_jaxpr` which returns a jaxpr... and so between the two calls to `_make_f_info`, we actually end up with two jaxprs. Both encode the same program, but are two different Python objects. Now jaxprs have `__eq__` defined according to identity, so these two (functionally identical) jaxprs do not compare as equal. Previously we worked around this inside `_iterate.py`: we carefully removed or wrapped any jaxprs before anything that would try to compare them for equality. This was a bit ugly, but it worked. However, it turns out that this still left a problem when manually stepping an Optimistix solver! (In a way akin to an Optax solver: something like ```python @eqx.filter_jit def make_step(...): ... = solver.step(...) for ... in ...: # Python level for-loop ... = make_step(...) ``` ) then in fact on every iteration of the Python loop, we would end up recompiling, as we always gets a new jaxpr at ``` state # state for the Gauss-Newton solver .f_info # as returned by _make_f_info .jac # the FunctionLinearOperator .fn # the closure-converted function .jaxpr # the jaxpr from the closure conversion ``` ! Now one fix is simply to demand that manually stepping a solver requires similar hackery as we had in `_iterate.py`. But maybe enough is enough, and we should try doing something better instead: that is, we do what this PR does, and just preserves the same jaxpr all the way through. For bonus points, this means that we can now remove our special jaxpr handling from `_iterate.py` (and from `filter_cond`, which also needed this for the same reason). Finally, you might be wondering: why do we need to trace two equivalent jaxprs at all? This seems inefficient -- can we arrange to trace it just once? The answer is "probably, but not in this PR". This seems to require that (a) Lineax offer a way to turn off closure conversion (done in patrick-kidger/lineax#71), but that (b) when using this, this still seems to trigger a similar issue in JAX, that the primal and tangent results from `jax.custom_jvp` match. So for now this is just something to try and tackle later -- once we do, we'll get slightly better compile times.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
{Newton,Chord}(cauchy_termination=False)
failing when started close to the solutionsol.state
including only the dynamic parteqxi.GetKey
andtree_allclose
implicit_jvp
assuming that it was only being used with iterate