diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 3d5df1e266..84c26039b1 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -131,6 +131,10 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter reverse_rules::Bool inactive_rules::Bool broadcast_rewrite::Bool + + # When false, leave the check for within_autodiff to the handler. + within_autodiff_rewrite::Bool + handler::T end @@ -169,6 +173,7 @@ function EnzymeInterpreter( reverse_rules::Bool, inactive_rules::Bool, broadcast_rewrite::Bool = true, + within_autodiff_rewrite::Bool = true, handler = nothing ) @assert world <= Base.get_world_counter() @@ -229,6 +234,7 @@ function EnzymeInterpreter( reverse_rules::Bool, inactive_rules::Bool, broadcast_rewrite::Bool, + within_autodiff_rewrite::Bool, handler ) end @@ -240,8 +246,42 @@ EnzymeInterpreter( mode::API.CDerivativeMode, inactive_rules::Bool, broadcast_rewrite::Bool = true, + within_autodiff_rewrite::Bool = true, handler = nothing -) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, handler) +) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, within_autodiff_rewrite, handler) + +function EnzymeInterpreter(interp::EnzymeInterpreter; + cache_or_token = (@static if HAS_INTEGRATED_CACHE + interp.token + else + interp.code_cache + end), + mt = interp.method_table, + local_cache = interp.local_cache, + world = interp.world, + inf_params = interp.inf_params, + opt_params = interp.opt_params, + forward_rules = interp.forward_rules, + reverse_rules = interp.reverse_rules, + inactive_rules = interp.inactive_rules, + broadcast_rewrite = interp.broadcast_rewrite, + within_autodiff_rewrite = interp.within_autodiff_rewrite, + handler = interp.handler) + return EnzymeInterpreter( + cache_or_token, + mt, + local_cache, + world, + inf_params, + opt_params, + forward_rules, + reverse_rules, + inactive_rules, + broadcast_rewrite, + within_autodiff_rewrite, + handler + ) +end Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp.inf_params Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params @@ -909,7 +949,7 @@ function abstract_call_known( (; fargs, argtypes) = arginfo - if f === Enzyme.within_autodiff + if interp.within_autodiff_rewrite && f === Enzyme.within_autodiff if length(argtypes) != 1 @static if VERSION < v"1.11.0-" return CallMeta(Union{}, Effects(), NoCallInfo())