Skip to content

Commit

Permalink
implement Jameson's suggestion
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Jan 28, 2025
1 parent 0f0c5f1 commit a7029a8
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 34 deletions.
41 changes: 32 additions & 9 deletions Compiler/src/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,37 @@ struct InliningState{Interp<:AbstractInterpreter}
edges::Vector{Any}
world::UInt
interp::Interp
opt_cache::IdDict{MethodInstance,CodeInstance}
end
function InliningState(sv::InferenceState, interp::AbstractInterpreter)
return InliningState(sv.edges, frame_world(sv), interp)
function InliningState(sv::InferenceState, interp::AbstractInterpreter, opt_cache::IdDict{MethodInstance,CodeInstance})
return InliningState(sv.edges, frame_world(sv), interp, opt_cache)
end
function InliningState(interp::AbstractInterpreter)
return InliningState(Any[], get_inference_world(interp), interp)
function InliningState(interp::AbstractInterpreter, opt_cache::IdDict{MethodInstance,CodeInstance})
return InliningState(Any[], get_inference_world(interp), interp, opt_cache)
end

struct OptimizerCache{CodeCache}
code_cache::CodeCache
opt_cache::IdDict{MethodInstance,CodeInstance}
end
function get((; code_cache, opt_cache)::OptimizerCache{WorldView{InternalCodeCache}}, mi::MethodInstance, default)
if haskey(opt_cache, mi)
codeinst = opt_cache[mi]
if (codeinst.min_world code_cache.worlds.min_world &&
code_cache.worlds.max_world codeinst.max_world &&
codeinst.owner === code_cache.cache.owner)
@assert isdefined(codeinst, :inferred) && codeinst.inferred === nothing
return codeinst
end
end
return get(code_cache, mi, default)
end

# get `code_cache(::AbstractInterpreter)` from `state::InliningState`
code_cache(state::InliningState) = WorldView(code_cache(state.interp), state.world)
function code_cache(state::InliningState)
cache = WorldView(code_cache(state.interp), state.world)
return OptimizerCache(cache, state.opt_cache)
end

mutable struct OptimizationState{Interp<:AbstractInterpreter}
linfo::MethodInstance
Expand All @@ -168,13 +189,15 @@ mutable struct OptimizationState{Interp<:AbstractInterpreter}
bb_vartables::Vector{Union{Nothing,VarTable}}
insert_coverage::Bool
end
function OptimizationState(sv::InferenceState, interp::AbstractInterpreter)
inlining = InliningState(sv, interp)
function OptimizationState(sv::InferenceState, interp::AbstractInterpreter,
opt_cache::IdDict{MethodInstance,CodeInstance}=IdDict{MethodInstance,CodeInstance}())
inlining = InliningState(sv, interp, opt_cache)
return OptimizationState(sv.linfo, sv.src, nothing, sv.stmt_info, sv.mod,
sv.sptypes, sv.slottypes, inlining, sv.cfg,
sv.unreachable, sv.bb_vartables, sv.insert_coverage)
end
function OptimizationState(mi::MethodInstance, src::CodeInfo, interp::AbstractInterpreter)
function OptimizationState(mi::MethodInstance, src::CodeInfo, interp::AbstractInterpreter,
opt_cache::IdDict{MethodInstance,CodeInstance}=IdDict{MethodInstance,CodeInstance}())
# prepare src for running optimization passes if it isn't already
nssavalues = src.ssavaluetypes
if nssavalues isa Int
Expand All @@ -194,7 +217,7 @@ function OptimizationState(mi::MethodInstance, src::CodeInfo, interp::AbstractIn
mod = isa(def, Method) ? def.module : def
# Allow using the global MI cache, but don't track edges.
# This method is mostly used for unit testing the optimizer
inlining = InliningState(interp)
inlining = InliningState(interp, opt_cache)
cfg = compute_basic_blocks(src.code)
unreachable = BitSet()
bb_vartables = Union{VarTable,Nothing}[]
Expand Down
68 changes: 43 additions & 25 deletions Compiler/src/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState)
#@assert last(result.valid_worlds) <= get_world_counter() || isempty(caller.edges)
if isdefined(result, :ci)
ci = result.ci
mi = result.linfo
# if we aren't cached, we don't need this edge
# but our caller might, so let's just make it anyways
if last(result.valid_worlds) >= get_world_counter()
Expand All @@ -132,15 +133,15 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState)
end
di = inferred_result.debuginfo
uncompressed = inferred_result
inferred_result = maybe_compress_codeinfo(interp, result.linfo, inferred_result)
inferred_result = maybe_compress_codeinfo(interp, mi, inferred_result)
result.is_src_volatile = false
elseif ci.owner === nothing
# The global cache can only handle objects that codegen understands
inferred_result = nothing
end
end
if !@isdefined di
di = DebugInfo(result.linfo)
di = DebugInfo(mi)
end
min_world, max_world = first(result.valid_worlds), last(result.valid_worlds)
ipo_effects = encode_effects(result.ipo_effects)
Expand All @@ -157,7 +158,6 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState)
# This is necessary to get decent bootstrapping performance
# when compiling the compiler to inject everything eagerly
# where codegen can start finding and using it right away
mi = result.linfo
if mi.def isa Method && isa_compileable_sig(mi)
ccall(:jl_add_codeinst_to_jit, Cvoid, (Any, Any), ci, uncompressed)
end
Expand All @@ -167,6 +167,15 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState)
return nothing
end

function cache_result!(interp::AbstractInterpreter, frame::InferenceState)
result = frame.result
if isdefined(result, :ci)
if is_cached(frame)
code_cache(interp)[result.linfo] = result.ci
end
end
end

function finish!(interp::AbstractInterpreter, mi::MethodInstance, ci::CodeInstance, src::CodeInfo)
user_edges = src.edges
edges = user_edges isa SimpleVector ? user_edges : user_edges === nothing ? Core.svec() : Core.svec(user_edges...)
Expand Down Expand Up @@ -200,11 +209,14 @@ function finish!(interp::AbstractInterpreter, mi::MethodInstance, ci::CodeInstan
end

function finish_nocycle(::AbstractInterpreter, frame::InferenceState)
finishinfer!(frame, frame.interp)
opt_cache = IdDict{MethodInstance,CodeInstance}()
finishinfer!(frame, frame.interp, opt_cache)
opt = frame.result.src
if opt isa OptimizationState # implies `may_optimize(caller.interp) === true`
optimize(frame.interp, opt, frame.result)
end
empty!(opt_cache)
cache_result!(frame.interp, frame)
finish!(frame.interp, frame)
if frame.cycleid != 0
frames = frame.callstack::Vector{AbsIntState}
Expand All @@ -227,10 +239,11 @@ function finish_cycle(::AbstractInterpreter, frames::Vector{AbsIntState}, cyclei
cycle_valid_worlds = intersect(cycle_valid_worlds, caller.world.valid_worlds)
cycle_valid_effects = merge_effects(cycle_valid_effects, caller.ipo_effects)
end
opt_cache = IdDict{MethodInstance,CodeInstance}()
for frameid = cycleid:length(frames)
caller = frames[frameid]::InferenceState
adjust_cycle_frame!(caller, cycle_valid_worlds, cycle_valid_effects)
finishinfer!(caller, caller.interp)
finishinfer!(caller, caller.interp, opt_cache)
end
for frameid = cycleid:length(frames)
caller = frames[frameid]::InferenceState
Expand All @@ -239,6 +252,11 @@ function finish_cycle(::AbstractInterpreter, frames::Vector{AbsIntState}, cyclei
optimize(caller.interp, opt, caller.result)
end
end
empty!(opt_cache)
for frameid = cycleid:length(frames)
caller = frames[frameid]::InferenceState
cache_result!(caller.interp, caller)
end
for frameid = cycleid:length(frames)
caller = frames[frameid]::InferenceState
finish!(caller.interp, caller)
Expand Down Expand Up @@ -285,22 +303,6 @@ function maybe_compress_codeinfo(interp::AbstractInterpreter, mi::MethodInstance
end
end

function cache_result!(interp::AbstractInterpreter, result::InferenceResult, ci::CodeInstance)
@assert isdefined(ci, :inferred)
# check if the existing linfo metadata is also sufficient to describe the current inference result
# to decide if it is worth caching this right now
mi = result.linfo
cache = WorldView(code_cache(interp), result.valid_worlds)
if haskey(cache, mi)
ci = cache[mi]
# n.b.: accurate edge representation might cause the CodeInstance for this to be constructed later
@assert isdefined(ci, :inferred)
return false
end
code_cache(interp)[mi] = ci
return true
end

function cycle_fix_limited(@nospecialize(typ), sv::InferenceState)
if typ isa LimitedAccuracy
if sv.parentid === 0
Expand Down Expand Up @@ -428,7 +430,8 @@ const empty_edges = Core.svec()

# inference completed on `me`
# update the MethodInstance
function finishinfer!(me::InferenceState, interp::AbstractInterpreter)
function finishinfer!(me::InferenceState, interp::AbstractInterpreter,
opt_cache::IdDict{MethodInstance, CodeInstance})
# prepare to run optimization passes on fulltree
@assert isempty(me.ip)
# inspect whether our inference had a limited result accuracy,
Expand Down Expand Up @@ -481,7 +484,7 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter)
# disable optimization if we've already obtained very accurate result
!result_is_constabi(interp, result)
if doopt
result.src = OptimizationState(me, interp)
result.src = OptimizationState(me, interp, opt_cache)
else
result.src = me.src # for reflection etc.
end
Expand All @@ -502,9 +505,11 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter)
ci, rettype, exctype, rettype_const, const_flags, min_world, max_world,
encode_effects(result.ipo_effects), result.analysis_results, di, edges)
if is_cached(me) # CACHE_MODE_GLOBAL
cached_result = cache_result!(me.interp, result, ci)
if !cached_result
already_cached = is_already_cached(me.interp, result, ci)
if already_cached
me.cache_mode = CACHE_MODE_VOLATILE
else
opt_cache[result.linfo] = ci
end
end
end
Expand Down Expand Up @@ -551,6 +556,19 @@ end
return ResultForCache(rettype, exctype, rettype_const, const_flags)
end

function is_already_cached(interp::AbstractInterpreter, result::InferenceResult, ci::CodeInstance)
# check if the existing linfo metadata is also sufficient to describe the current inference result
# to decide if it is worth caching this right now
mi = result.linfo
cache = WorldView(code_cache(interp), result.valid_worlds)
if haskey(cache, mi)
# n.b.: accurate edge representation might cause the CodeInstance for this to be constructed later
@assert isdefined(cache[mi], :inferred)
return true
end
return false
end

# record the backedges
function store_backedges(caller::CodeInstance, edges::SimpleVector)
isa(caller.def.def, Method) || return # don't add backedges to toplevel method instance
Expand Down

0 comments on commit a7029a8

Please sign in to comment.