feat: VJP utility based on autodiff_thunk #2309

Draft · wants to merge 1 commit into base: main
73 changes: 73 additions & 0 deletions src/sugar.jl
@@ -1160,3 +1160,76 @@ grad
return nothing
end


"""
seeded_autodiff_thunk(
rmode::ReverseModeSplit,
dresult,
f,
ReturnActivity,
annotated_args...
)

Call [`autodiff_thunk`](@ref), execute the forward pass, increment the output tangent with `dresult`, then execute the reverse pass.

Useful for computing pullbacks / VJPs for functions whose output is not a scalar.
"""
function seeded_autodiff_thunk(
rmode::ReverseModeSplit{ReturnPrimal},
dresult,
f::FA,
::Type{RA},
args::Vararg{Annotation,N},
) where {ReturnPrimal,FA<:Annotation,RA<:Annotation,N}
forward, reverse = autodiff_thunk(rmode, FA, RA, typeof.(args)...)
tape, result, shadow_result = forward(f, args...)
if RA <: Active
dinputs = only(reverse(f, args..., dresult, tape))
Member:
Just wondering why not just add a vjp, that seems similar to this except doesn't need separate ones for batches

Contributor Author:
Indeed I'm not a fan of the name seeded_autodiff_thunk either. vjp does convey the notion that input and output have to be vectors, so maybe pullback is more generic? What kind of signature do you have in mind?

Member:
I'm personally fine with the name vjp but at minimum go ahead and add the code and we can iterate on names in parallel

Contributor Author:
wait what do you mean by "add the code"? I thought your request of "adding a vjp" was mostly about naming? what is missing here?

Member:
We shouldn't have two separate functions here for batched vs not

Contributor Author:
I also thought of doing it via dispatch but how do we handle the case where dresult itself is supposed to be an NTuple (as opposed to a batch)?

else
shadow_result .+= dresult # TODO: generalize beyond arrays
dinputs = only(reverse(f, args..., tape))
end
if ReturnPrimal
return (dinputs, result)
else
return (dinputs,)
end
end
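
For orientation, a minimal usage sketch of the non-batched helper, mirroring the tests added in test/sugar.jl below (variable names are illustrative):

using Enzyme
using Enzyme: seeded_autodiff_thunk   # as in the tests below

g(x, y) = [sum(abs2, x) * y]          # vector-valued function
x, y = [1.0, 2.0, 3.0], 4.0
dx = zero(x)                          # shadow of x, accumulated in place
seed = [5.0]                          # cotangent seeding the output of g
dinputs, primal = seeded_autodiff_thunk(
    ReverseSplitWithPrimal, seed, Const(g), Duplicated,
    Duplicated(x, dx), Active(y),
)
# afterwards dx == seed[1] * 2x * y, dinputs[2] == seed[1] * sum(abs2, x),
# and primal == g(x, y); with ReverseSplitNoPrimal only (dinputs,) is returned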

"""
batch_seeded_autodiff_thunk(
rmode::ReverseModeSplit,
dresults::NTuple,
f,
ReturnActivity,
annotated_args...
)

Call [`autodiff_thunk`](@ref), execute the forward pass, increment each output tangent with the corresponding element of `dresults`, then execute the reverse pass.

Useful for computing pullbacks / VJPs for functions whose output is not a scalar.
"""
function batch_seeded_autodiff_thunk(
rmode::ReverseModeSplit{ReturnPrimal},
dresults::NTuple{B},
f::FA,
::Type{RA},
args::Vararg{Annotation,N},
) where {ReturnPrimal,B,FA<:Annotation,RA<:Annotation,N}
rmode_rightwidth = ReverseSplitWidth(rmode, Val(B))
forward, reverse = autodiff_thunk(rmode_rightwidth, FA, RA, typeof.(args)...)
tape, result, shadow_results = forward(f, args...)
if RA <: Active
dinputs = only(reverse(f, args..., dresults, tape))
else
foreach(shadow_results, dresults) do d0, d
d0 .+= d # TODO: generalize beyond arrays
end
dinputs = only(reverse(f, args..., tape))
end
if ReturnPrimal
return (dinputs, result)
else
return (dinputs,)
end
end
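
The batched helper is called the same way, except that seeds and shadows come as one tuple entry per batch lane. A sketch continuing the example above, followed by a hypothetical illustration (not part of this diff) of the NTuple ambiguity raised in the review thread:

using Enzyme: batch_seeded_autodiff_thunk

dxs = (zero(x), zero(x))              # one shadow per batch lane
seeds = ([5.0], [7.0])                # one cotangent per batch lane
dinputs, primal = batch_seeded_autodiff_thunk(
    ReverseSplitWithPrimal, seeds, Const(g), BatchDuplicated,
    BatchDuplicated(x, dxs), Active(y),
)
# afterwards dxs[i] == seeds[i][1] * 2x * y and dinputs[2][i] == seeds[i][1] * sum(abs2, x)

# Hypothetical illustration of why dispatching on the seed type alone is ambiguous:
h(x) = (sum(abs2, x), prod(x))        # a single output that is itself a 2-tuple
single_seed = (1.0, 0.0)              # one seed for that tuple-valued output
batch_seeds = (1.0, 0.0)              # two scalar seeds, one per batch lane, for a scalar output
# Both are NTuple{2,Float64}, so a merged `vjp(mode, dresult, f, ...)` could not
# distinguish the two cases by dispatching on `dresult` alone.
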
72 changes: 72 additions & 0 deletions test/sugar.jl
@@ -650,3 +650,75 @@ end
# @show J_r_3(u, A, x)
# @show J_f_3(u, A, x)
end

using Enzyme: seeded_autodiff_thunk, batch_seeded_autodiff_thunk

@testset "seeded_autodiff_thunk" begin

f(x::Vector{Float64}, y::Float64) = sum(abs2, x) * y
g(x::Vector{Float64}, y::Float64) = [f(x, y)]

x = [1.0, 2.0, 3.0]
y = 4.0
dx = similar(x)
dresult = 5.0
dxs = (similar(x), similar(x))
dresults = (5.0, 7.0)

@testset "simple" begin
for mode in (ReverseSplitNoPrimal, ReverseSplitWithPrimal)
make_zero!(dx)
dinputs_and_maybe_result = seeded_autodiff_thunk(mode, dresult, Const(f), Active, Duplicated(x, dx), Active(y))
dinputs = first(dinputs_and_maybe_result)
@test isnothing(dinputs[1])
@test dinputs[2] == dresult * sum(abs2, x)
@test dx == dresult * 2x * y
if mode == ReverseSplitWithPrimal
@test last(dinputs_and_maybe_result) == f(x, y)
end
end

for mode in (ReverseSplitNoPrimal, ReverseSplitWithPrimal)
make_zero!(dx)
dinputs_and_maybe_result = seeded_autodiff_thunk(mode, [dresult], Const(g), Duplicated, Duplicated(x, dx), Active(y))
dinputs = first(dinputs_and_maybe_result)
@test isnothing(dinputs[1])
@test dinputs[2] == dresult * sum(abs2, x)
@test dx == dresult * 2x * y
if mode == ReverseSplitWithPrimal
@test last(dinputs_and_maybe_result) == g(x, y)
end
end
end

@testset "batch" begin
for mode in (ReverseSplitNoPrimal, ReverseSplitWithPrimal)
make_zero!(dxs)
dinputs_and_maybe_result = batch_seeded_autodiff_thunk(mode, dresults, Const(f), Active, BatchDuplicated(x, dxs), Active(y))
dinputs = first(dinputs_and_maybe_result)
@test isnothing(dinputs[1])
@test dinputs[2][1] == dresults[1] * sum(abs2, x)
@test dinputs[2][2] == dresults[2] * sum(abs2, x)
@test dxs[1] == dresults[1] * 2x * y
@test dxs[2] == dresults[2] * 2x * y
if mode == ReverseSplitWithPrimal
@test last(dinputs_and_maybe_result) == f(x, y)
end
end

for mode in (ReverseSplitNoPrimal, ReverseSplitWithPrimal)
make_zero!(dxs)
dinputs_and_maybe_result = batch_seeded_autodiff_thunk(mode, ([dresults[1]], [dresults[2]]), Const(g), BatchDuplicated, BatchDuplicated(x, dxs), Active(y))
dinputs = first(dinputs_and_maybe_result)
@test isnothing(dinputs[1])
@test dinputs[2][1] == dresults[1] * sum(abs2, x)
@test dinputs[2][2] == dresults[2] * sum(abs2, x)
@test dxs[1] == dresults[1] * 2x * y
@test dxs[2] == dresults[2] * 2x * y
if mode == ReverseSplitWithPrimal
@test last(dinputs_and_maybe_result) == g(x, y)
end
end
end

end