From 231b1961194470a7ed4a4154917207a0f9ab5fb3 Mon Sep 17 00:00:00 2001
From: Shuhei Kadowaki
Date: Wed, 27 Nov 2024 01:00:15 +0900
Subject: [PATCH] perform inference using optimizer-derived type information

In certain cases, the optimizer can introduce new type information.
This is particularly evident in SROA, where load forwarding can reveal
type information that was not visible during abstract interpretation.
In such cases, re-running abstract interpretation using this new type
information can be highly valuable; currently, however, this only
occurs when semi-concrete interpretation happens to be triggered.

This commit introduces a new "post-optimization inference" phase at the
end of the optimizer pipeline. When the optimizer derives new type
information, this phase performs abstract interpretation on the
optimized IR to optimize it further.
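
As an illustration (a hypothetical example mirroring the tests added in
this commit; the function name is made up), SROA can forward the store
to the `Core.Box` below into the subsequent load, refining the type of
`box.contents` from `Any` to `Vector{Any}`. The new phase then re-runs
inference on the refined `length` call, so the whole function infers as
`Int`:

    function demo(argtypes::Vector{Any})
        box = Core.Box()
        box.contents = argtypes      # SROA forwards this store to the load below
        return length(box.contents)  # now seen as `Vector{Any}`, so this infers as `Int`
    end

The phase is registered as the "optinf" pass, so its output should also
be inspectable with e.g.
`Base.code_ircode(demo, (Vector{Any},); optimize_until="optinf")`.
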
---
 Compiler/src/inferencestate.jl       |  3 +-
 Compiler/src/optimize.jl             | 64 ++++++++++++++++++++++------
 Compiler/src/ssair/EscapeAnalysis.jl |  3 +-
 Compiler/src/ssair/ir.jl             |  3 +-
 Compiler/src/ssair/irinterp.jl       |  8 +++-
 Compiler/src/ssair/passes.jl         | 36 +++++++++++-----
 Compiler/src/typeinfer.jl            |  3 +-
 Compiler/test/inference.jl           | 17 +++++++-
 8 files changed, 108 insertions(+), 29 deletions(-)

diff --git a/Compiler/src/inferencestate.jl b/Compiler/src/inferencestate.jl
index 9eb929b725fbf..8c69fdd54cf5d 100644
--- a/Compiler/src/inferencestate.jl
+++ b/Compiler/src/inferencestate.jl
@@ -804,6 +804,7 @@ mutable struct IRInterpretationState
     callstack #::Vector{AbsIntState}
     frameid::Int
     parentid::Int
+    new_call_inferred::Bool
 
     function IRInterpretationState(interp::AbstractInterpreter,
         spec_info::SpecInfo, ir::IRCode, mi::MethodInstance, argtypes::Vector{Any},
@@ -829,7 +830,7 @@ mutable struct IRInterpretationState
         edges = Any[]
         callstack = AbsIntState[]
         return new(spec_info, ir, mi, WorldWithRange(world, valid_worlds), curridx, argtypes_refined, ir.sptypes, tpdum,
-            ssa_refined, lazyreachability, tasks, edges, callstack, 0, 0)
+            ssa_refined, lazyreachability, tasks, edges, callstack, 0, 0, #=new_call_inferred=#false)
     end
 end
 
diff --git a/Compiler/src/optimize.jl b/Compiler/src/optimize.jl
index d2dfd26bfa00d..0dfe732844aed 100644
--- a/Compiler/src/optimize.jl
+++ b/Compiler/src/optimize.jl
@@ -992,7 +992,7 @@ end
 
 # run the optimization work
 function optimize(interp::AbstractInterpreter, opt::OptimizationState, caller::InferenceResult)
-    @timeit "optimizer" ir = run_passes_ipo_safe(opt.src, opt)
+    @timeit "optimizer" ir = run_passes_ipo_safe(interp, opt, caller)
     ipo_dataflow_analysis!(interp, opt, ir, caller)
     return finish(interp, opt, ir, caller)
 end
@@ -1012,11 +1012,9 @@ matchpass(optimize_until::Int, stage, _) = optimize_until == stage
 matchpass(optimize_until::String, _, name) = optimize_until == name
 matchpass(::Nothing, _, _) = false
 
-function run_passes_ipo_safe(
-    ci::CodeInfo,
-    sv::OptimizationState,
-    optimize_until = nothing, # run all passes by default
-)
+function run_passes_ipo_safe(interp::AbstractInterpreter, sv::OptimizationState, result::InferenceResult;
+                             optimize_until = nothing) # run all passes by default
+    ci = sv.src
     __stage__ = 0 # used by @pass
     # NOTE: The pass name MUST be unique for `optimize_until::AbstractString` to work
     @pass "convert"   ir = convert_to_ircode(ci, sv)
@@ -1024,15 +1022,15 @@ function run_passes_ipo_safe(
     # TODO: Domsorting can produce an updated domtree - no need to recompute here
     @pass "compact 1" ir = compact!(ir)
     @pass "Inlining"  ir = ssa_inlining_pass!(ir, sv.inlining, ci.propagate_inbounds)
-    # @timeit "verify 2" verify_ir(ir)
     @pass "compact 2" ir = compact!(ir)
     @pass "SROA"      ir = sroa_pass!(ir, sv.inlining)
-    @pass "ADCE"      (ir, made_changes) = adce_pass!(ir, sv.inlining)
-    if made_changes
-        @pass "compact 3" ir = compact!(ir, true)
-    end
+    @pass "ADCE"      (ir, changed) = adce_pass!(ir, sv.inlining)
+    @pass "compact 3" changed && (
+        ir = compact!(ir, true))
+    @pass "optinf"    optinf_worthwhile(ir) && (
+        ir = optinf!(ir, interp, sv, result))
     if is_asserts()
-        @timeit "verify 3" begin
+        @timeit "verify" begin
             verify_ir(ir, true, false, optimizer_lattice(sv.inlining.interp), sv.linfo)
             verify_linetable(ir.debuginfo, length(ir.stmts))
         end
@@ -1041,6 +1039,48 @@ end
     return ir
 end
 
+# If the optimizer derives new type information (as implied by `IR_FLAG_REFINED`),
+# and this new type information is available for the arguments of a call expression,
+# further optimizations may be possible by performing irinterp on the optimized IR.
+function optinf_worthwhile(ir::IRCode)
+    @assert isempty(ir.new_nodes) "expected compacted IRCode"
+    for i = 1:length(ir.stmts)
+        if has_flag(ir[SSAValue(i)], IR_FLAG_REFINED)
+            stmt = ir[SSAValue(i)][:stmt]
+            if isexpr(stmt, :call)
+                return true
+            end
+        end
+    end
+    return false
+end
+
+function optinf!(ir::IRCode, interp::AbstractInterpreter, sv::OptimizationState, result::InferenceResult)
+    ci = sv.src
+    spec_info = SpecInfo(ci)
+    world = get_inference_world(interp)
+    min_world, max_world = first(result.valid_worlds), last(result.valid_worlds)
+    irsv = IRInterpretationState(interp, spec_info, ir, result.linfo, ir.argtypes,
+                                 world, min_world, max_world)
+    rt, (nothrow, noub) = ir_abstract_constant_propagation(interp, irsv)
+    if irsv.new_call_inferred
+        ir = ssa_inlining_pass!(ir, sv.inlining, ci.propagate_inbounds)
+        ir = compact!(ir)
+        effects = result.effects
+        if nothrow
+            effects = Effects(effects; nothrow=true)
+        end
+        if noub
+            effects = Effects(effects; noub=ALWAYS_TRUE)
+        end
+        result.effects = effects
+        result.exc_result = refine_exception_type(result.exc_result, effects)
+        ⋤ = strictneqpartialorder(ipo_lattice(interp))
+        result.result = rt ⋤ result.result ? rt : result.result
+    end
+    return ir
+end
+
 function strip_trailing_junk!(code::Vector{Any}, ssavaluetypes::Vector{Any}, ssaflags::Vector, debuginfo::DebugInfoStream, cfg::CFG, info::Vector{CallInfo})
     # Remove `nothing`s at the end, we don't handle them well
     # (we expect the last instruction to be a terminator)
diff --git a/Compiler/src/ssair/EscapeAnalysis.jl b/Compiler/src/ssair/EscapeAnalysis.jl
index 47a7840628bb5..f006efa40824c 100644
--- a/Compiler/src/ssair/EscapeAnalysis.jl
+++ b/Compiler/src/ssair/EscapeAnalysis.jl
@@ -25,10 +25,11 @@ using Base: # Base definitions
     unwrap_unionall,
     !, !=, !==, &, *, +, -, :, <, <<, =>, >, |, ∈, ∉, ∩, ∪, ≠, ≤, ≥, ⊆, hasintersect
 using ..Compiler: # Compiler specific definitions
+    Compiler, @show, ⊑,
     AbstractLattice, Bottom, IRCode, IR_FLAG_NOTHROW, InferenceResult, SimpleInferenceLattice,
     argextype, fieldcount_noerror, hasintersect, has_flag, intrinsic_nothrow,
     is_meta_expr_head, is_identity_free_argtype, isexpr, println, setfield!_nothrow,
-    singleton_type, try_compute_field, try_compute_fieldidx, widenconst, ⊑, Compiler
+    singleton_type, try_compute_field, try_compute_fieldidx, widenconst
 
 function include(x::String)
     if !isdefined(Base, :end_base_include)
diff --git a/Compiler/src/ssair/ir.jl b/Compiler/src/ssair/ir.jl
index 9103dba04fa54..efc3ae412c503 100644
--- a/Compiler/src/ssair/ir.jl
+++ b/Compiler/src/ssair/ir.jl
@@ -1709,7 +1709,8 @@ function reprocess_phi_node!(𝕃ₒ::AbstractLattice, compact::IncrementalCompa
 
         # There's only one predecessor left - just replace it
         v = phi.values[1]
-        if !⊑(𝕃ₒ, compact[compact.ssa_rename[old_idx]][:type], argextype(v, compact))
+        ⋤ = strictneqpartialorder(𝕃ₒ)
+        if argextype(v, compact) ⋤ compact[compact.ssa_rename[old_idx]][:type]
             v = Refined(v)
         end
         compact.ssa_rename[old_idx] = v
diff --git a/Compiler/src/ssair/irinterp.jl b/Compiler/src/ssair/irinterp.jl
index a4969e81828cc..2b8050f778d85 100644
--- a/Compiler/src/ssair/irinterp.jl
+++ b/Compiler/src/ssair/irinterp.jl
@@ -58,6 +58,7 @@ function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, sstate::St
     call = abstract_call(interp, arginfo, si, irsv)::Future
     Future{Any}(call, interp, irsv) do call, interp, irsv
         irsv.ir.stmts[irsv.curridx][:info] = call.info
+        irsv.new_call_inferred |= true
         nothing
     end
     return call
@@ -204,7 +205,8 @@ function reprocess_instruction!(interp::AbstractInterpreter, inst::Instruction,
         # Handled at the very end
         return false
     elseif isa(stmt, PiNode)
-        rt = tmeet(typeinf_lattice(interp), argextype(stmt.val, ir), widenconst(stmt.typ))
+        ⊓ = meet(typeinf_lattice(interp))
+        rt = argextype(stmt.val, ir) ⊓ widenconst(stmt.typ)
     elseif stmt === nothing
         return false
     elseif isa(stmt, GlobalRef)
@@ -226,7 +228,9 @@ function reprocess_instruction!(interp::AbstractInterpreter, inst::Instruction,
             inst[:stmt] = quoted(rt.val)
         end
         return true
-    elseif !⊑(typeinf_lattice(interp), inst[:type], rt)
+    end
+    ⋤ = strictneqpartialorder(typeinf_lattice(interp))
+    if rt ⋤ inst[:type]
         inst[:type] = rt
         return true
     end
diff --git a/Compiler/src/ssair/passes.jl b/Compiler/src/ssair/passes.jl
index ff333b9b0a129..436f29b93c43c 100644
--- a/Compiler/src/ssair/passes.jl
+++ b/Compiler/src/ssair/passes.jl
@@ -989,9 +989,10 @@ function lift_keyvalue_get!(compact::IncrementalCompact, idx::Int, stmt::Expr,
     lifted_leaves === nothing && return
 
     result_t = Union{}
+    ⊔ = join(𝕃ₒ)
     for v in values(lifted_leaves)
         v === nothing && return
-        result_t = tmerge(𝕃ₒ, result_t, argextype(v.val, compact))
+        result_t = result_t ⊔ argextype(v.val, compact)
     end
 
     (lifted_val, nest) = perform_lifting!(compact,
@@ -1001,8 +1002,12 @@ function lift_keyvalue_get!(compact::IncrementalCompact, idx::Int, stmt::Expr,
     compact[idx] = lifted_val === nothing ? nothing : Expr(:call, GlobalRef(Core, :tuple), lifted_val.val)
     finish_phi_nest!(compact, nest)
     if lifted_val !== nothing
-        if !⊑(𝕃ₒ, compact[SSAValue(idx)][:type], tuple_tfunc(𝕃ₒ, Any[result_t]))
-            add_flag!(compact[SSAValue(idx)], IR_FLAG_REFINED)
+        stmttype = tuple_tfunc(𝕃ₒ, Any[result_t])
+        inst = compact[SSAValue(idx)]
+        ⋤ = strictneqpartialorder(𝕃ₒ)
+        if stmttype ⋤ inst[:type]
+            inst[:type] = stmttype
+            add_flag!(inst, IR_FLAG_REFINED)
         end
     end
 
@@ -1440,19 +1445,23 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
         lifted_leaves, any_undef = lifted_result
 
         result_t = Union{}
+        ⊔ = join(𝕃ₒ)
         for v in values(lifted_leaves)
             v === nothing && continue
-            result_t = tmerge(𝕃ₒ, result_t, argextype(v.val, compact))
+            result_t = result_t ⊔ argextype(v.val, compact)
         end
 
         (lifted_val, nest) = perform_lifting!(compact,
             visited_philikes, field, result_t, lifted_leaves, val, lazydomtree)
 
         should_delete_node = false
-        line = compact[SSAValue(idx)][:line]
-        if lifted_val !== nothing && !⊑(𝕃ₒ, compact[SSAValue(idx)][:type], result_t)
+        inst = compact[SSAValue(idx)]
+        line = inst[:line]
+        ⋤ = strictneqpartialorder(𝕃ₒ)
+        if lifted_val !== nothing && result_t ⋤ inst[:type]
             compact[idx] = lifted_val === nothing ? nothing : lifted_val.val
-            add_flag!(compact[SSAValue(idx)], IR_FLAG_REFINED)
+            inst[:type] = result_t
+            add_flag!(inst, IR_FLAG_REFINED)
         elseif lifted_val === nothing || isa(lifted_val.val, AnySSAValue)
             # Save some work in a later compaction, by inserting this into the renamer now,
             # but only do this if we didn't set the REFINED flag, to save work for irinterp
@@ -1855,9 +1864,15 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int,Tuple{SPCSet,SSADefUse}}
         for use in du.uses
             if use.kind === :getfield
                 inst = ir[SSAValue(use.idx)]
-                inst[:stmt] = compute_value_for_use(ir, domtree, allblocks,
+                newvalue = compute_value_for_use(ir, domtree, allblocks,
                     du, phinodes, fidx, use.idx)
-                add_flag!(inst, IR_FLAG_REFINED)
+                inst[:stmt] = newvalue
+                newvaluetyp = argextype(newvalue, ir)
+                ⋤ = strictneqpartialorder(𝕃ₒ)
+                if newvaluetyp ⋤ inst[:type]
+                    inst[:type] = newvaluetyp
+                    add_flag!(inst, IR_FLAG_REFINED)
+                end
             elseif use.kind === :isdefined
                 continue # already rewritten if possible
             elseif use.kind === :nopreserve
@@ -1878,11 +1893,12 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int,Tuple{SPCSet,SSADefUse}}
         for b in phiblocks
             n = ir[phinodes[b]][:stmt]::PhiNode
             result_t = Bottom
+            ⊔ = join(𝕃ₒ)
             for p in ir.cfg.blocks[b].preds
                 push!(n.edges, p)
                 v = compute_value_for_block(ir, domtree, allblocks, du, phinodes, fidx, p)
                 push!(n.values, v)
-                result_t = tmerge(𝕃ₒ, result_t, argextype(v, ir))
+                result_t = result_t ⊔ argextype(v, ir)
             end
             ir[phinodes[b]][:type] = result_t
         end
diff --git a/Compiler/src/typeinfer.jl b/Compiler/src/typeinfer.jl
index 83ec0271ea474..76830dda4cb78 100644
--- a/Compiler/src/typeinfer.jl
+++ b/Compiler/src/typeinfer.jl
@@ -999,7 +999,7 @@ function typeinf_ircode(interp::AbstractInterpreter, mi::MethodInstance,
     end
     (; result) = frame
     opt = OptimizationState(frame, interp)
-    ir = run_passes_ipo_safe(opt.src, opt, optimize_until)
+    ir = run_passes_ipo_safe(interp, opt, result; optimize_until)
     rt = widenconst(ignorelimited(result.result))
     return ir, rt
 end
@@ -1024,6 +1024,7 @@ function typeinf_frame(interp::AbstractInterpreter, mi::MethodInstance, run_opti
             opt = OptimizationState(frame, interp)
             optimize(interp, opt, frame.result)
             src = ir_to_codeinf!(opt)
+            src.rettype = widenconst(result.result)
         end
         result.src = frame.src = src
     end
diff --git a/Compiler/test/inference.jl b/Compiler/test/inference.jl
index b3099897faf51..ad6ead8a0dac0 100644
--- a/Compiler/test/inference.jl
+++ b/Compiler/test/inference.jl
@@ -3495,7 +3495,7 @@ f31974(n::Int) = f31974(1:n)
 @test code_typed(f31974, Tuple{Int}) !== nothing
 
 f_overly_abstract_complex() = Complex(Ref{Number}(1)[])
-@test Base.return_types(f_overly_abstract_complex, Tuple{}) == [Complex]
+@test Base.infer_return_type(f_overly_abstract_complex, Tuple{}) == Complex{Int}
 
 # Issue 26724
 const IntRange = AbstractUnitRange{<:Integer}
@@ -6126,3 +6126,18 @@ function func_swapglobal!_must_throw(x)
 end
 @test Base.infer_return_type(func_swapglobal!_must_throw, (Int,); interp=SwapGlobalInterp()) === Union{}
 @test !Compiler.is_effect_free(Base.infer_effects(func_swapglobal!_must_throw, (Int,); interp=SwapGlobalInterp()) )
+
+# opt inf
+@test Base.infer_return_type((Vector{Any},)) do argtypes
+    box = Core.Box()
+    box.contents = argtypes
+    return length(box.contents)
+end == Int
+@test Base.infer_return_type((Vector{Any},)) do argtypes
+    local argtypesi
+    function cls()
+        argtypesi = @noinline copy(argtypes)
+        return length(argtypesi)
+    end
+    return @inline cls()
+end == Int