-
Notifications
You must be signed in to change notification settings - Fork 47
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix jax deprecations #1346
fix jax deprecations #1346
Conversation
Codecov ReportAttention: Patch coverage is
❗ Your organization needs to install the Codecov GitHub app to enable full functionality. Additional details and impacted files@@ Coverage Diff @@
## develop #1346 +/- ##
===========================================
- Coverage 84.50% 84.43% -0.07%
===========================================
Files 157 157
Lines 12906 12881 -25
===========================================
- Hits 10906 10876 -30
- Misses 2000 2005 +5 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, not clear on the very details of the implementation though, but the tests checks out 👌🏼 thanks for the update
Co-authored-by: Paul Jonas Jost <70631928+PaulJonasJost@users.noreply.github.com>
…into fix_jax_callback
Now no longer requires passing of input function. Extended tests to demonstrate jax transformations of inputs and outputs. |
…into fix_jax_callback
host_callback
calls in jax, which are now deprecated, withpure_callback
calls.jax.jit
inJax.Objective
.jax.vmap
. This required quite a bit of refactoring as the base objective assumes inputs/outputs to be numpy arrays, which is incompatible with jax batch tracers.jax.grad
andjax.vmap
.