From a7029a817e3b3a5209be538437a8db68376d7b27 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Tue, 28 Jan 2025 23:45:13 +0900 Subject: [PATCH] implement Jameson's suggestion --- Compiler/src/optimize.jl | 41 +++++++++++++++++------ Compiler/src/typeinfer.jl | 68 +++++++++++++++++++++++++-------------- 2 files changed, 75 insertions(+), 34 deletions(-) diff --git a/Compiler/src/optimize.jl b/Compiler/src/optimize.jl index c0302444e5fef..a0607e05f0c71 100644 --- a/Compiler/src/optimize.jl +++ b/Compiler/src/optimize.jl @@ -143,16 +143,37 @@ struct InliningState{Interp<:AbstractInterpreter} edges::Vector{Any} world::UInt interp::Interp + opt_cache::IdDict{MethodInstance,CodeInstance} end -function InliningState(sv::InferenceState, interp::AbstractInterpreter) - return InliningState(sv.edges, frame_world(sv), interp) +function InliningState(sv::InferenceState, interp::AbstractInterpreter, opt_cache::IdDict{MethodInstance,CodeInstance}) + return InliningState(sv.edges, frame_world(sv), interp, opt_cache) end -function InliningState(interp::AbstractInterpreter) - return InliningState(Any[], get_inference_world(interp), interp) +function InliningState(interp::AbstractInterpreter, opt_cache::IdDict{MethodInstance,CodeInstance}) + return InliningState(Any[], get_inference_world(interp), interp, opt_cache) +end + +struct OptimizerCache{CodeCache} + code_cache::CodeCache + opt_cache::IdDict{MethodInstance,CodeInstance} +end +function get((; code_cache, opt_cache)::OptimizerCache{WorldView{InternalCodeCache}}, mi::MethodInstance, default) + if haskey(opt_cache, mi) + codeinst = opt_cache[mi] + if (codeinst.min_world ≤ code_cache.worlds.min_world && + code_cache.worlds.max_world ≤ codeinst.max_world && + codeinst.owner === code_cache.cache.owner) + @assert isdefined(codeinst, :inferred) && codeinst.inferred === nothing + return codeinst + end + end + return get(code_cache, mi, default) end # get `code_cache(::AbstractInterpreter)` from `state::InliningState` -code_cache(state::InliningState) = WorldView(code_cache(state.interp), state.world) +function code_cache(state::InliningState) + cache = WorldView(code_cache(state.interp), state.world) + return OptimizerCache(cache, state.opt_cache) +end mutable struct OptimizationState{Interp<:AbstractInterpreter} linfo::MethodInstance @@ -168,13 +189,15 @@ mutable struct OptimizationState{Interp<:AbstractInterpreter} bb_vartables::Vector{Union{Nothing,VarTable}} insert_coverage::Bool end -function OptimizationState(sv::InferenceState, interp::AbstractInterpreter) - inlining = InliningState(sv, interp) +function OptimizationState(sv::InferenceState, interp::AbstractInterpreter, + opt_cache::IdDict{MethodInstance,CodeInstance}=IdDict{MethodInstance,CodeInstance}()) + inlining = InliningState(sv, interp, opt_cache) return OptimizationState(sv.linfo, sv.src, nothing, sv.stmt_info, sv.mod, sv.sptypes, sv.slottypes, inlining, sv.cfg, sv.unreachable, sv.bb_vartables, sv.insert_coverage) end -function OptimizationState(mi::MethodInstance, src::CodeInfo, interp::AbstractInterpreter) +function OptimizationState(mi::MethodInstance, src::CodeInfo, interp::AbstractInterpreter, + opt_cache::IdDict{MethodInstance,CodeInstance}=IdDict{MethodInstance,CodeInstance}()) # prepare src for running optimization passes if it isn't already nssavalues = src.ssavaluetypes if nssavalues isa Int @@ -194,7 +217,7 @@ function OptimizationState(mi::MethodInstance, src::CodeInfo, interp::AbstractIn mod = isa(def, Method) ? def.module : def # Allow using the global MI cache, but don't track edges. # This method is mostly used for unit testing the optimizer - inlining = InliningState(interp) + inlining = InliningState(interp, opt_cache) cfg = compute_basic_blocks(src.code) unreachable = BitSet() bb_vartables = Union{VarTable,Nothing}[] diff --git a/Compiler/src/typeinfer.jl b/Compiler/src/typeinfer.jl index 38b412f06d9be..0328f02355090 100644 --- a/Compiler/src/typeinfer.jl +++ b/Compiler/src/typeinfer.jl @@ -106,6 +106,7 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState) #@assert last(result.valid_worlds) <= get_world_counter() || isempty(caller.edges) if isdefined(result, :ci) ci = result.ci + mi = result.linfo # if we aren't cached, we don't need this edge # but our caller might, so let's just make it anyways if last(result.valid_worlds) >= get_world_counter() @@ -132,7 +133,7 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState) end di = inferred_result.debuginfo uncompressed = inferred_result - inferred_result = maybe_compress_codeinfo(interp, result.linfo, inferred_result) + inferred_result = maybe_compress_codeinfo(interp, mi, inferred_result) result.is_src_volatile = false elseif ci.owner === nothing # The global cache can only handle objects that codegen understands @@ -140,7 +141,7 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState) end end if !@isdefined di - di = DebugInfo(result.linfo) + di = DebugInfo(mi) end min_world, max_world = first(result.valid_worlds), last(result.valid_worlds) ipo_effects = encode_effects(result.ipo_effects) @@ -157,7 +158,6 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState) # This is necessary to get decent bootstrapping performance # when compiling the compiler to inject everything eagerly # where codegen can start finding and using it right away - mi = result.linfo if mi.def isa Method && isa_compileable_sig(mi) ccall(:jl_add_codeinst_to_jit, Cvoid, (Any, Any), ci, uncompressed) end @@ -167,6 +167,15 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState) return nothing end +function cache_result!(interp::AbstractInterpreter, frame::InferenceState) + result = frame.result + if isdefined(result, :ci) + if is_cached(frame) + code_cache(interp)[result.linfo] = result.ci + end + end +end + function finish!(interp::AbstractInterpreter, mi::MethodInstance, ci::CodeInstance, src::CodeInfo) user_edges = src.edges edges = user_edges isa SimpleVector ? user_edges : user_edges === nothing ? Core.svec() : Core.svec(user_edges...) @@ -200,11 +209,14 @@ function finish!(interp::AbstractInterpreter, mi::MethodInstance, ci::CodeInstan end function finish_nocycle(::AbstractInterpreter, frame::InferenceState) - finishinfer!(frame, frame.interp) + opt_cache = IdDict{MethodInstance,CodeInstance}() + finishinfer!(frame, frame.interp, opt_cache) opt = frame.result.src if opt isa OptimizationState # implies `may_optimize(caller.interp) === true` optimize(frame.interp, opt, frame.result) end + empty!(opt_cache) + cache_result!(frame.interp, frame) finish!(frame.interp, frame) if frame.cycleid != 0 frames = frame.callstack::Vector{AbsIntState} @@ -227,10 +239,11 @@ function finish_cycle(::AbstractInterpreter, frames::Vector{AbsIntState}, cyclei cycle_valid_worlds = intersect(cycle_valid_worlds, caller.world.valid_worlds) cycle_valid_effects = merge_effects(cycle_valid_effects, caller.ipo_effects) end + opt_cache = IdDict{MethodInstance,CodeInstance}() for frameid = cycleid:length(frames) caller = frames[frameid]::InferenceState adjust_cycle_frame!(caller, cycle_valid_worlds, cycle_valid_effects) - finishinfer!(caller, caller.interp) + finishinfer!(caller, caller.interp, opt_cache) end for frameid = cycleid:length(frames) caller = frames[frameid]::InferenceState @@ -239,6 +252,11 @@ function finish_cycle(::AbstractInterpreter, frames::Vector{AbsIntState}, cyclei optimize(caller.interp, opt, caller.result) end end + empty!(opt_cache) + for frameid = cycleid:length(frames) + caller = frames[frameid]::InferenceState + cache_result!(caller.interp, caller) + end for frameid = cycleid:length(frames) caller = frames[frameid]::InferenceState finish!(caller.interp, caller) @@ -285,22 +303,6 @@ function maybe_compress_codeinfo(interp::AbstractInterpreter, mi::MethodInstance end end -function cache_result!(interp::AbstractInterpreter, result::InferenceResult, ci::CodeInstance) - @assert isdefined(ci, :inferred) - # check if the existing linfo metadata is also sufficient to describe the current inference result - # to decide if it is worth caching this right now - mi = result.linfo - cache = WorldView(code_cache(interp), result.valid_worlds) - if haskey(cache, mi) - ci = cache[mi] - # n.b.: accurate edge representation might cause the CodeInstance for this to be constructed later - @assert isdefined(ci, :inferred) - return false - end - code_cache(interp)[mi] = ci - return true -end - function cycle_fix_limited(@nospecialize(typ), sv::InferenceState) if typ isa LimitedAccuracy if sv.parentid === 0 @@ -428,7 +430,8 @@ const empty_edges = Core.svec() # inference completed on `me` # update the MethodInstance -function finishinfer!(me::InferenceState, interp::AbstractInterpreter) +function finishinfer!(me::InferenceState, interp::AbstractInterpreter, + opt_cache::IdDict{MethodInstance, CodeInstance}) # prepare to run optimization passes on fulltree @assert isempty(me.ip) # inspect whether our inference had a limited result accuracy, @@ -481,7 +484,7 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter) # disable optimization if we've already obtained very accurate result !result_is_constabi(interp, result) if doopt - result.src = OptimizationState(me, interp) + result.src = OptimizationState(me, interp, opt_cache) else result.src = me.src # for reflection etc. end @@ -502,9 +505,11 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter) ci, rettype, exctype, rettype_const, const_flags, min_world, max_world, encode_effects(result.ipo_effects), result.analysis_results, di, edges) if is_cached(me) # CACHE_MODE_GLOBAL - cached_result = cache_result!(me.interp, result, ci) - if !cached_result + already_cached = is_already_cached(me.interp, result, ci) + if already_cached me.cache_mode = CACHE_MODE_VOLATILE + else + opt_cache[result.linfo] = ci end end end @@ -551,6 +556,19 @@ end return ResultForCache(rettype, exctype, rettype_const, const_flags) end +function is_already_cached(interp::AbstractInterpreter, result::InferenceResult, ci::CodeInstance) + # check if the existing linfo metadata is also sufficient to describe the current inference result + # to decide if it is worth caching this right now + mi = result.linfo + cache = WorldView(code_cache(interp), result.valid_worlds) + if haskey(cache, mi) + # n.b.: accurate edge representation might cause the CodeInstance for this to be constructed later + @assert isdefined(cache[mi], :inferred) + return true + end + return false +end + # record the backedges function store_backedges(caller::CodeInstance, edges::SimpleVector) isa(caller.def.def, Method) || return # don't add backedges to toplevel method instance