From 5664ae5cf06ccd2d4e46d0f283836f29eb1e5d04 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Thu, 19 Dec 2024 05:12:28 +0900 Subject: [PATCH] make EA able to handle `new_nodes` --- Compiler/src/ssair/EscapeAnalysis.jl | 59 +++++++++++++++++++++++----- Compiler/test/EscapeAnalysis.jl | 37 +++++++++++++++++ 2 files changed, 86 insertions(+), 10 deletions(-) diff --git a/Compiler/src/ssair/EscapeAnalysis.jl b/Compiler/src/ssair/EscapeAnalysis.jl index 3fc95a44c21120..d7a59dde56a206 100644 --- a/Compiler/src/ssair/EscapeAnalysis.jl +++ b/Compiler/src/ssair/EscapeAnalysis.jl @@ -19,13 +19,13 @@ import Base: ==, ∈, copy, delete!, getindex, isempty, setindex! using Core: Builtin, IntrinsicFunction, SimpleVector, ifelse, sizeof using Core.IR using Base: # Base definitions - @__MODULE__, @assert, @eval, @goto, @inbounds, @inline, @label, @noinline, - @nospecialize, @specialize, BitSet, IdDict, IdSet, UnitRange, Vector, + @__MODULE__, @assert, @eval, @goto, @inbounds, @inline, @isdefined, @label, @noinline, + @nospecialize, @specialize, BitSet, IdDict, IdSet, Pair, UnitRange, Vector, _bits_findnext, copy!, empty!, enumerate, fill!, first, get, hasintersect, haskey, - isassigned, isexpr, last, length, max, min, missing, only, println, push!, pushfirst!, - resize!, :, !, !==, <, <<, >, =>, ≠, ≤, ≥, ∉, ⊆, ⊇, &, *, +, -, | + isassigned, isexpr, keys, last, length, max, min, missing, only, println, push!, + pushfirst!, resize!, :, !, !==, <, <<, >, =>, ≠, ≤, ≥, ∉, ⊆, ⊇, &, *, +, -, | using ..Compiler: # Compiler specific definitions - @show, Compiler, HandlerInfo, IRCode, IR_FLAG_NOTHROW, SimpleHandler, + @show, Compiler, HandlerInfo, IRCode, IR_FLAG_NOTHROW, NewNodeInfo, SimpleHandler, argextype, block_for_inst, compute_trycatch, fieldcount_noerror, gethandler, has_flag, intrinsic_nothrow, is_meta_expr_head, is_identity_free_argtype, isterminator, singleton_type, try_compute_field, try_compute_fieldidx, widenconst @@ -506,7 +506,7 @@ struct BlockEscapeState{Sealed#=::Bool=#} nargs::Int end function BlockEscapeState(ir::IRCode, nargs::Int) - nstmts = length(ir.stmts) + nstmts = length(ir.stmts) + length(ir.new_nodes) nelms = nargs + nstmts escapes = EscapeTable() aliasset = AliasSet(nelms) @@ -694,19 +694,54 @@ Analyzes escape information in `ir`: retrieves cached argument escape information """ function analyze_escapes(ir::IRCode, nargs::Int, get_escape_cache) - @assert isempty(ir.new_nodes.stmts) "compacted IRCode is assumed currently" currbb = 1 bbs = ir.cfg.blocks W = BitSet() W.offset = 0 # for _bits_findnext astate = AnalysisState(ir, nargs, get_escape_cache) (; currstate) = astate + if isempty(ir.new_nodes) + new_nodes_pcs = nothing + else + new_nodes_pcs = IdDict{Int,Vector{Pair{Int,NewNodeInfo}}}() + for (i, nni) in enumerate(ir.new_nodes.info) + if haskey(new_nodes_pcs, nni.pos) + push!(new_nodes_pcs[nni.pos], i => nni) + else + new_nodes_pcs[nni.pos] = Pair{Int,NewNodeInfo}[i => nni] + end + end + end + while true local nextbb::Int bbstart, bbend = first(bbs[currbb].stmts), last(bbs[currbb].stmts) for pc = bbstart:bbend - stmt = ir[SSAValue(pc)][:stmt] - if pc == bbend + local new_nodes_counter::Int = 0 + if new_nodes_pcs === nothing || pc ∉ keys(new_nodes_pcs) + stmt = ir[SSAValue(pc)][:stmt] + isterminator = pc == bbend + else + new_node_infos = new_nodes_pcs[pc] + attach_before_idxs = Int[i for (i, nni) in new_node_infos if !nni.attach_after] + attach_after_idxs = Int[i for (i, nni) in new_node_infos if nni.attach_after] + na, nb = length(attach_after_idxs), length(attach_before_idxs) + n_nodes = new_nodes_counter = na + nb + 1 # +1 for this statement + @label analyze_new_node + curridx = n_nodes - new_nodes_counter + 1 + if curridx ≤ nb + stmt = ir.new_nodes.stmts[attach_before_idxs[curridx]][:stmt] + elseif curridx == nb + 1 + stmt = ir[SSAValue(pc)][:stmt] + else + @assert curridx ≤ n_nodes + stmt = ir.new_nodes.stmts[attach_after_idxs[curridx - nb - 1]][:stmt] + end + isterminator = curridx == n_nodes + new_nodes_counter -= 1 + end + + if isterminator # if this is the last statement of the current block, handle the control-flow if stmt isa GotoNode succs = bbs[currbb].succs @@ -770,6 +805,10 @@ function analyze_escapes(ir::IRCode, nargs::Int, get_escape_cache) end end end + + if new_nodes_counter > 0 + @goto analyze_new_node + end end begin @label fall_through @@ -799,7 +838,7 @@ function analyze_escapes(ir::IRCode, nargs::Int, get_escape_cache) end function initialize_state!(currstate::BlockEscapeState, ir::IRCode, nargs::Int) - nstmts = length(ir.stmts) + nstmts = length(ir.stmts) + length(ir.new_nodes) nelms = nargs + nstmts empty!(currstate.escapes) resize!(currstate.aliasset.parents, nelms) diff --git a/Compiler/test/EscapeAnalysis.jl b/Compiler/test/EscapeAnalysis.jl index 6b686ee673138c..456e167ef4813e 100644 --- a/Compiler/test/EscapeAnalysis.jl +++ b/Compiler/test/EscapeAnalysis.jl @@ -1896,4 +1896,41 @@ let result = code_escapes(get_value, (Int,Any)) end == 1 end +# `analyze_escapes` should be able to handle `IRCode` with new nodes +let code = Any[ + # block 1 + #=1=# Expr(:new, Base.RefValue{Bool}, Argument(2)) + #=2=# Expr(:call, GlobalRef(Core, :getfield), SSAValue(1), 1) + #=3=# GotoIfNot(SSAValue(2), 5) + # block 2 + #=4=# nothing + # block 3 + #=5=# Expr(:call, GlobalRef(Core, :setfield!), SSAValue(1), 1, false) + #=6=# ReturnNode(nothing) + ] + ir = make_ircode(code; slottypes=Any[Any,Bool]) + ir.stmts[1][:type] = Base.RefValue{Bool} + Compiler.insert_node!(ir, SSAValue(4), Compiler.NewInstruction(Expr(:call, GlobalRef(Core, :setfield!), SSAValue(1), 1, false), Any), #=attach_after=#false) + Compiler.insert_node!(ir, SSAValue(4), Compiler.NewInstruction(GotoNode(3), Any), #=attach_after=#true) + ir[SSAValue(6)] = nothing # eliminate the ReturnNode + s = Compiler.insert_node!(ir, SSAValue(6), Compiler.NewInstruction(Expr(:call, GlobalRef(Core, :getfield), SSAValue(1), 1), Any), #=attach_after=#false) + Compiler.insert_node!(ir, SSAValue(6), Compiler.NewInstruction(ReturnNode(s), Any), #=attach_after=#true) + # now this `ir` would look like: + # 1 ─ %1 = %new(Base.RefValue{Bool}, _2)::Base.RefValue{Bool} │ + # │ %2 = builtin Core.getfield(%1, 1)::Any │ + # └── goto #3 if not %2 │ + # 2 ─ builtin Core.setfield!(%1, 1, false)::Any │ + # │ nothing::Any + # └── goto #3 + # 3 ┄ builtin Core.setfield!(%1, 1, false)::Any │ + # │ %9 = builtin Core.getfield(%1, 1)::Any │ + # │ nothing::Any + # └── return %9 + result = code_escapes(ir, 2) + idxs = findall(iscall((result.ir, getfield)), result.ir.stmts.stmt) + for idx = idxs + @test is_load_forwardable(result, idx) + end +end + end # module test_EA