diff --git a/src/sugar.jl b/src/sugar.jl index 83c078c9df..6fa9effe37 100644 --- a/src/sugar.jl +++ b/src/sugar.jl @@ -1160,3 +1160,76 @@ grad return nothing end + +""" + seeded_autodiff_thunk( + rmode::ReverseModeSplit, + dresult, + f, + ReturnActivity, + annotated_args... + ) + +Call [`autodiff_thunk`](@ref), execute the forward pass, increment output tangent with `dresult`, then execute the reverse pass. + +Useful for computing pullbacks / VJPs for functions whose output is not a scalar. +""" +function seeded_autodiff_thunk( + rmode::ReverseModeSplit{ReturnPrimal}, + dresult, + f::FA, + ::Type{RA}, + args::Vararg{Annotation,N}, +) where {ReturnPrimal,FA<:Annotation,RA<:Annotation,N} + forward, reverse = autodiff_thunk(rmode, FA, RA, typeof.(args)...) + tape, result, shadow_result = forward(f, args...) + if RA <: Active + dinputs = only(reverse(f, args..., dresult, tape)) + else + shadow_result .+= dresult # TODO: generalize beyond arrays + dinputs = only(reverse(f, args..., tape)) + end + if ReturnPrimal + return (dinputs, result) + else + return (dinputs,) + end +end + +""" + batch_seeded_autodiff_thunk( + rmode::ReverseModeSplit, + dresults::NTuple, + f, + ReturnActivity, + annotated_args... + ) + +Call [`autodiff_thunk`](@ref), execute the forward pass, increment each output tangent with the corresponding element from `dresults`, then execute the reverse pass. + +Useful for computing pullbacks / VJPs for functions whose output is not a scalar. +""" +function batch_seeded_autodiff_thunk( + rmode::ReverseModeSplit{ReturnPrimal}, + dresults::NTuple{B}, + f::FA, + ::Type{RA}, + args::Vararg{Annotation,N}, +) where {ReturnPrimal,B,FA<:Annotation,RA<:Annotation,N} + rmode_rightwidth = ReverseSplitWidth(rmode, Val(B)) + forward, reverse = autodiff_thunk(rmode_rightwidth, FA, RA, typeof.(args)...) + tape, result, shadow_results = forward(f, args...) + if RA <: Active + dinputs = only(reverse(f, args..., dresults, tape)) + else + foreach(shadow_results, dresults) do d0, d + d0 .+= d # TODO: generalize beyond arrays + end + dinputs = only(reverse(f, args..., tape)) + end + if ReturnPrimal + return (dinputs, result) + else + return (dinputs,) + end +end diff --git a/test/sugar.jl b/test/sugar.jl index 340a54c569..6f1bdb138e 100644 --- a/test/sugar.jl +++ b/test/sugar.jl @@ -650,3 +650,75 @@ end # @show J_r_3(u, A, x) # @show J_f_3(u, A, x) end + +using Enzyme: seeded_autodiff_thunk, batch_seeded_autodiff_thunk + +@testset "seeded_autodiff_thunk" begin + + f(x::Vector{Float64}, y::Float64) = sum(abs2, x) * y + g(x::Vector{Float64}, y::Float64) = [f(x, y)] + + x = [1.0, 2.0, 3.0] + y = 4.0 + dx = similar(x) + dresult = 5.0 + dxs = (similar(x), similar(x)) + dresults = (5.0, 7.0) + + @testset "simple" begin + for mode in (ReverseSplitNoPrimal, ReverseSplitWithPrimal) + make_zero!(dx) + dinputs_and_maybe_result = seeded_autodiff_thunk(mode, dresult, Const(f), Active, Duplicated(x, dx), Active(y)) + dinputs = first(dinputs_and_maybe_result) + @test isnothing(dinputs[1]) + @test dinputs[2] == dresult * sum(abs2, x) + @test dx == dresult * 2x * y + if mode == ReverseSplitWithPrimal + @test last(dinputs_and_maybe_result) == f(x, y) + end + end + + for mode in (ReverseSplitNoPrimal, ReverseSplitWithPrimal) + make_zero!(dx) + dinputs_and_maybe_result = seeded_autodiff_thunk(mode, [dresult], Const(g), Duplicated, Duplicated(x, dx), Active(y)) + dinputs = first(dinputs_and_maybe_result) + @test isnothing(dinputs[1]) + @test dinputs[2] == dresult * sum(abs2, x) + @test dx == dresult * 2x * y + if mode == ReverseSplitWithPrimal + @test last(dinputs_and_maybe_result) == g(x, y) + end + end + end + + @testset "batch" begin + for mode in (ReverseSplitNoPrimal, ReverseSplitWithPrimal) + make_zero!(dxs) + dinputs_and_maybe_result = batch_seeded_autodiff_thunk(mode, dresults, Const(f), Active, BatchDuplicated(x, dxs), Active(y)) + dinputs = first(dinputs_and_maybe_result) + @test isnothing(dinputs[1]) + @test dinputs[2][1] == dresults[1] * sum(abs2, x) + @test dinputs[2][2] == dresults[2] * sum(abs2, x) + @test dxs[1] == dresults[1] * 2x * y + @test dxs[2] == dresults[2] * 2x * y + if mode == ReverseSplitWithPrimal + @test last(dinputs_and_maybe_result) == f(x, y) + end + end + + for mode in (ReverseSplitNoPrimal, ReverseSplitWithPrimal) + make_zero!(dxs) + dinputs_and_maybe_result = batch_seeded_autodiff_thunk(mode, ([dresults[1]], [dresults[2]]), Const(g), BatchDuplicated, BatchDuplicated(x, dxs), Active(y)) + dinputs = first(dinputs_and_maybe_result) + @test isnothing(dinputs[1]) + @test dinputs[2][1] == dresults[1] * sum(abs2, x) + @test dinputs[2][2] == dresults[2] * sum(abs2, x) + @test dxs[1] == dresults[1] * 2x * y + @test dxs[2] == dresults[2] * 2x * y + if mode == ReverseSplitWithPrimal + @test last(dinputs_and_maybe_result) == g(x, y) + end + end + end + +end