Skip to content

Commit

Permalink
add defer_within_autodiff to EnzymeInterpreter in order for `within…
Browse files Browse the repository at this point in the history
…_autodiff` to no return true during Reactant compilation.

When this flag is true, `interp.handler` is responsible for handling within_autodiff, or to toggle defer_within_autodiff to false somewhere down the call chain.
  • Loading branch information
jumerckx authored and vchuravy committed Jan 17, 2025
1 parent b95c737 commit aa6ef9c
Showing 1 changed file with 42 additions and 2 deletions.
44 changes: 42 additions & 2 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter
reverse_rules::Bool
inactive_rules::Bool
broadcast_rewrite::Bool

# When true, leave the check for within_autodiff to the handler.
defer_within_autodiff::Bool

handler::T
end

Expand Down Expand Up @@ -169,6 +173,7 @@ function EnzymeInterpreter(
reverse_rules::Bool,
inactive_rules::Bool,
broadcast_rewrite::Bool = true,
defer_within_autodiff::Bool = false,
handler = nothing
)
@assert world <= Base.get_world_counter()
Expand Down Expand Up @@ -229,6 +234,7 @@ function EnzymeInterpreter(
reverse_rules::Bool,
inactive_rules::Bool,
broadcast_rewrite::Bool,
defer_within_autodiff::Bool,
handler
)
end
Expand All @@ -240,8 +246,42 @@ EnzymeInterpreter(
mode::API.CDerivativeMode,
inactive_rules::Bool,
broadcast_rewrite::Bool = true,
defer_within_autodiff::Bool = false,
handler = nothing
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, handler)
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, defer_within_autodiff, handler)

function EnzymeInterpreter(interp::EnzymeInterpreter;
cache_or_token = (@static if HAS_INTEGRATED_CACHE
interp.token
else
interp.code_cache
end),
mt = interp.method_table,
local_cache = interp.local_cache,
world = interp.world,
inf_params = interp.inf_params,
opt_params = interp.opt_params,
forward_rules = interp.forward_rules,
reverse_rules = interp.reverse_rules,
inactive_rules = interp.inactive_rules,
broadcast_rewrite = interp.broadcast_rewrite,
defer_within_autodiff = interp.defer_within_autodiff,
handler = interp.handler)
return EnzymeInterpreter(
cache_or_token,
mt,
local_cache,
world,
inf_params,
opt_params,
forward_rules,
reverse_rules,
inactive_rules,
broadcast_rewrite,
defer_within_autodiff,
handler
)
end

Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp.inf_params
Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params
Expand Down Expand Up @@ -909,7 +949,7 @@ function abstract_call_known(

(; fargs, argtypes) = arginfo

if f === Enzyme.within_autodiff
if !(interp.defer_within_autodiff) && f === Enzyme.within_autodiff
if length(argtypes) != 1
@static if VERSION < v"1.11.0-"
return CallMeta(Union{}, Effects(), NoCallInfo())
Expand Down

0 comments on commit aa6ef9c

Please sign in to comment.