feat: VJP utility based on autodiff_thunk #2309

Draft: wants to merge 1 commit into main

Conversation

@gdalle
Contributor

gdalle commented Feb 16, 2025

codecov bot commented Feb 16, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 75.43%. Comparing base (037dfed) to head (6bf6299).
Report is 353 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2309      +/-   ##
==========================================
+ Coverage   67.50%   75.43%   +7.92%     
==========================================
  Files          31       56      +25     
  Lines       12668    16758    +4090     
==========================================
+ Hits         8552    12642    +4090     
  Misses       4116     4116              


forward, reverse = autodiff_thunk(rmode, FA, RA, typeof.(args)...)
tape, result, shadow_result = forward(f, args...)
if RA <: Active
    dinputs = only(reverse(f, args..., dresult, tape))
Member

Just wondering: why not add a vjp function? That seems similar to this, except that it wouldn't need separate versions for batches.

Contributor Author

Indeed I'm not a fan of the name seeded_autodiff_thunk either. vjp does convey the notion that input and output have to be vectors, so maybe pullback is more generic? What kind of signature do you have in mind?
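
For reference, the two candidate names describe the same operation. As a sketch in standard notation (the symbols below are not from the PR): for a differentiable function $f : \mathbb{R}^n \to \mathbb{R}^m$ with Jacobian $J_f(x)$, the vector-Jacobian product maps an output seed $\bar{y}$ to an input cotangent $\bar{x}$. The pullback is the adjoint of the differential $Df(x)$, which is defined without assuming that inputs and outputs are flat vectors.

```latex
% VJP: the seed \bar{y} lives in output space, the result \bar{x} in input space
\bar{x} = J_f(x)^{\top} \, \bar{y}
\qquad\Longleftrightarrow\qquad
\bar{x}^{\top} = \bar{y}^{\top} J_f(x)

% Pullback: adjoint of the differential, valid for any inner-product spaces
\langle \bar{y},\, Df(x)[v] \rangle = \langle Df(x)^{*}[\bar{y}],\, v \rangle
\quad \text{for all } v
```

With $m = 1$ and $\bar{y} = 1$, the VJP reduces to the gradient $\nabla f(x)$.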

Member

I'm personally fine with the name vjp, but at minimum go ahead and add the code, and we can iterate on names in parallel.

Contributor Author

Wait, what do you mean by "add the code"? I thought your request of "adding a vjp" was mostly about naming. What is missing here?

Member

We shouldn't have two separate functions here for the batched and non-batched cases.

Contributor Author

I also thought of doing it via dispatch but how do we handle the case where dresult itself is supposed to be an NTuple (as opposed to a batch)?

@vchuravy
Member

In some way, this feels equivalent to implementing autodiff but for non-scalar returns and fixing my old mistake of always passing in one(T) as the seed.

I generally think of autodiff(Reverse, ...) as a vjp (with the convention that the output is updated in place).

@gdalle
Contributor Author

gdalle commented Feb 17, 2025

We could also call it autodiff but then we'd need to figure out how the output seed is passed. It can't be inside a Duplicated or BatchDuplicated because we have no primal

@vchuravy
Member

> It can't be inside a Duplicated or BatchDuplicated because we have no primal

Seed and BatchSeed? Just throwing some ideas into the air.

@gdalle
Contributor Author

gdalle commented Feb 17, 2025

And how would you see the order of arguments?

autodiff(Reverse, f, seed, args...)
autodiff(Reverse, f, args...; seed=...)

@vchuravy
Member

The first variant, since that is already the convention used for the activity of the return.
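
The Seed / BatchSeed idea and the seed-first argument order discussed above can be sketched in plain Julia. This is only an illustration under stated assumptions: `Seed`, `BatchSeed`, `pullback`, `f` and `f_pullback` are hypothetical names defined here, not Enzyme API, and the VJP of the toy function is written by hand rather than produced by `autodiff_thunk`. The point is that wrapper types let dispatch tell a single tuple-valued seed apart from a batch of seeds, which a bare `NTuple` cannot do.

```julia
# Hypothetical wrappers named after the suggestion in this thread (not Enzyme API).
struct Seed{T}
    value::T                 # one output cotangent, of any type (may itself be a tuple)
end

struct BatchSeed{N,T}
    values::NTuple{N,T}      # N output cotangents of the same type
end

# Toy function whose output is an NTuple, so a single seed is also an NTuple.
f(x) = (x[1] + x[2], x[1] * x[2])

# Hand-written VJP of f: returns J(x)' * dy, with J(x) = [1 1; x[2] x[1]].
function f_pullback(x, dy::NTuple{2,<:Real})
    return [dy[1] + dy[2] * x[2], dy[1] + dy[2] * x[1]]
end

# Seed-first argument order, as in the first variant proposed above.
# Dispatch on the wrapper type selects single vs. batched behavior.
pullback(pb, seed::Seed, x) = pb(x, seed.value)
pullback(pb, seed::BatchSeed, x) = map(dy -> pb(x, dy), seed.values)

x = [2.0, 3.0]

# A single seed that happens to be a 2-tuple: one input cotangent comes back.
single = pullback(f_pullback, Seed((1.0, 0.0)), x)                      # [1.0, 1.0]

# A batch of two seeds: a tuple of two input cotangents comes back.
batch = pullback(f_pullback, BatchSeed(((1.0, 0.0), (0.0, 1.0))), x)   # ([1.0, 1.0], [3.0, 2.0])
```

One further consideration in favor of a positional seed: Julia does not dispatch on keyword arguments, so a `seed=...` keyword could not select between the single and batched methods this way.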


Successfully merging this pull request may close these issues.

Syntactic sugar for vjp
3 participants