feat: VJP utility based on autodiff_thunk
#2309
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2309      +/-   ##
==========================================
+ Coverage   67.50%   75.43%   +7.92%
==========================================
  Files          31       56      +25
  Lines       12668    16758    +4090
==========================================
+ Hits         8552    12642    +4090
  Misses       4116     4116
forward, reverse = autodiff_thunk(rmode, FA, RA, typeof.(args)...)
tape, result, shadow_result = forward(f, args...)
if RA <: Active
    dinputs = only(reverse(f, args..., dresult, tape))
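For readers following the thread, here is a minimal sketch of the split-mode pattern the excerpt above builds on, calling `autodiff_thunk` directly with a non-unit seed. The function `f`, the input `x`, and the value of `seed` are illustrative assumptions, not code from this PR; the call shapes follow the `autodiff_thunk` example in the Enzyme documentation as I understand it.

```julia
using Enzyme

# Illustrative scalar-valued function (an assumption, not from the PR).
f(x) = sum(abs2, x)

x = [1.0, 2.0, 3.0]
dx = zero(x)   # shadow buffer that receives the input cotangent
seed = 0.5     # cotangent of the output, i.e. the "dresult" in the excerpt

# Build the split forward/reverse thunks for a constant function,
# an Active scalar return, and one Duplicated argument.
forward, reverse = autodiff_thunk(
    ReverseSplitWithPrimal,
    Const{typeof(f)},
    Active,
    Duplicated{typeof(x)},
)

# Forward pass: records the tape and returns the primal result.
tape, result, shadow_result = forward(Const(f), Duplicated(x, dx))

# Reverse pass: for an Active return, the seed is passed just before the tape.
# The gradient is accumulated into dx, so dx should equal seed .* 2 .* x.
reverse(Const(f), Duplicated(x, dx), seed, tape)
```

With a seed of `1.0` this reduces to an ordinary gradient; the utility discussed in this PR is about letting the caller supply that seed.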
Just wondering: why not just add a `vjp`? That seems similar to this, except it doesn't need separate functions for batches.
Indeed, I'm not a fan of the name `seeded_autodiff_thunk` either. `vjp` does convey the notion that input and output have to be vectors, so maybe `pullback` is more generic? What kind of signature do you have in mind?
I'm personally fine with the name `vjp`, but at minimum go ahead and add the code, and we can iterate on names in parallel.
Wait, what do you mean by "add the code"? I thought your request of "adding a `vjp`" was mostly about naming. What is missing here?
We shouldn't have two separate functions here for the batched and non-batched cases.
I also thought of doing it via dispatch, but how do we handle the case where `dresult` itself is supposed to be an `NTuple` (as opposed to a batch)?
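To make the ambiguity concrete, here is a small self-contained sketch in plain Julia, without Enzyme. All names in it (`seed_kind_naive`, `seed_kind`, `SeedBatch`) are hypothetical and exist only for illustration. It contrasts dispatching directly on `NTuple` with dispatching on an explicit wrapper type, one possible way to tell a batch of seeds apart from a single tuple-valued seed.

```julia
# Naive approach: any NTuple seed is interpreted as a batch.
seed_kind_naive(dresult) = :single
seed_kind_naive(dresult::NTuple) = :batch

# Hypothetical wrapper that marks a batch of seeds explicitly.
struct SeedBatch{N,T}
    seeds::NTuple{N,T}
end

# With the wrapper, a bare tuple is always a single seed.
seed_kind(dresult) = :single
seed_kind(dresult::SeedBatch) = :batch

tuple_seed = (1.0, 2.0)  # could be the seed for a function returning a 2-tuple

naive = seed_kind_naive(tuple_seed)          # classified as a batch regardless of intent
plain = seed_kind(tuple_seed)                # classified as a single seed
batched = seed_kind(SeedBatch(tuple_seed))   # classified as a batch
```

The wrapper is similar in spirit to how Enzyme's `BatchDuplicated` distinguishes batched shadows from ordinary ones, though whether a seed wrapper fits this PR is a design question for the maintainers.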
In some way, this feels equivalent to implementing I think of
We could also call it
And how would you see the order of arguments?
autodiff(Reverse, f, seed, args...)
autodiff(Reverse, f, args...; seed=...)
The first variant, since that is already the convention used for the activity of the return.
Fixes #1853
Related: