Skip to content

Commit

Permalink
make EA able to handle new_nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Dec 19, 2024
1 parent 70e1ca9 commit 5664ae5
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 10 deletions.
59 changes: 49 additions & 10 deletions Compiler/src/ssair/EscapeAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ import Base: ==, ∈, copy, delete!, getindex, isempty, setindex!
using Core: Builtin, IntrinsicFunction, SimpleVector, ifelse, sizeof
using Core.IR
using Base: # Base definitions
@__MODULE__, @assert, @eval, @goto, @inbounds, @inline, @label, @noinline,
@nospecialize, @specialize, BitSet, IdDict, IdSet, UnitRange, Vector,
@__MODULE__, @assert, @eval, @goto, @inbounds, @inline, @isdefined, @label, @noinline,
@nospecialize, @specialize, BitSet, IdDict, IdSet, Pair, UnitRange, Vector,
_bits_findnext, copy!, empty!, enumerate, fill!, first, get, hasintersect, haskey,
isassigned, isexpr, last, length, max, min, missing, only, println, push!, pushfirst!,
resize!, :, !, !==, <, <<, >, =>, , , , , , , &, *, +, -, |
isassigned, isexpr, keys, last, length, max, min, missing, only, println, push!,
pushfirst!, resize!, :, !, !==, <, <<, >, =>, , , , , , , &, *, +, -, |
using ..Compiler: # Compiler specific definitions
@show, Compiler, HandlerInfo, IRCode, IR_FLAG_NOTHROW, SimpleHandler,
@show, Compiler, HandlerInfo, IRCode, IR_FLAG_NOTHROW, NewNodeInfo, SimpleHandler,
argextype, block_for_inst, compute_trycatch, fieldcount_noerror, gethandler, has_flag,
intrinsic_nothrow, is_meta_expr_head, is_identity_free_argtype, isterminator,
singleton_type, try_compute_field, try_compute_fieldidx, widenconst
Expand Down Expand Up @@ -506,7 +506,7 @@ struct BlockEscapeState{Sealed#=::Bool=#}
nargs::Int
end
function BlockEscapeState(ir::IRCode, nargs::Int)
nstmts = length(ir.stmts)
nstmts = length(ir.stmts) + length(ir.new_nodes)
nelms = nargs + nstmts
escapes = EscapeTable()
aliasset = AliasSet(nelms)
Expand Down Expand Up @@ -694,19 +694,54 @@ Analyzes escape information in `ir`:
retrieves cached argument escape information
"""
function analyze_escapes(ir::IRCode, nargs::Int, get_escape_cache)
@assert isempty(ir.new_nodes.stmts) "compacted IRCode is assumed currently"
currbb = 1
bbs = ir.cfg.blocks
W = BitSet()
W.offset = 0 # for _bits_findnext
astate = AnalysisState(ir, nargs, get_escape_cache)
(; currstate) = astate
if isempty(ir.new_nodes)
new_nodes_pcs = nothing
else
new_nodes_pcs = IdDict{Int,Vector{Pair{Int,NewNodeInfo}}}()
for (i, nni) in enumerate(ir.new_nodes.info)
if haskey(new_nodes_pcs, nni.pos)
push!(new_nodes_pcs[nni.pos], i => nni)
else
new_nodes_pcs[nni.pos] = Pair{Int,NewNodeInfo}[i => nni]
end
end
end

while true
local nextbb::Int
bbstart, bbend = first(bbs[currbb].stmts), last(bbs[currbb].stmts)
for pc = bbstart:bbend
stmt = ir[SSAValue(pc)][:stmt]
if pc == bbend
local new_nodes_counter::Int = 0
if new_nodes_pcs === nothing || pc keys(new_nodes_pcs)
stmt = ir[SSAValue(pc)][:stmt]
isterminator = pc == bbend
else
new_node_infos = new_nodes_pcs[pc]
attach_before_idxs = Int[i for (i, nni) in new_node_infos if !nni.attach_after]
attach_after_idxs = Int[i for (i, nni) in new_node_infos if nni.attach_after]
na, nb = length(attach_after_idxs), length(attach_before_idxs)
n_nodes = new_nodes_counter = na + nb + 1 # +1 for this statement
@label analyze_new_node
curridx = n_nodes - new_nodes_counter + 1
if curridx nb
stmt = ir.new_nodes.stmts[attach_before_idxs[curridx]][:stmt]
elseif curridx == nb + 1
stmt = ir[SSAValue(pc)][:stmt]
else
@assert curridx n_nodes
stmt = ir.new_nodes.stmts[attach_after_idxs[curridx - nb - 1]][:stmt]
end
isterminator = curridx == n_nodes
new_nodes_counter -= 1
end

if isterminator
# if this is the last statement of the current block, handle the control-flow
if stmt isa GotoNode
succs = bbs[currbb].succs
Expand Down Expand Up @@ -770,6 +805,10 @@ function analyze_escapes(ir::IRCode, nargs::Int, get_escape_cache)
end
end
end

if new_nodes_counter > 0
@goto analyze_new_node
end
end

begin @label fall_through
Expand Down Expand Up @@ -799,7 +838,7 @@ function analyze_escapes(ir::IRCode, nargs::Int, get_escape_cache)
end

function initialize_state!(currstate::BlockEscapeState, ir::IRCode, nargs::Int)
nstmts = length(ir.stmts)
nstmts = length(ir.stmts) + length(ir.new_nodes)
nelms = nargs + nstmts
empty!(currstate.escapes)
resize!(currstate.aliasset.parents, nelms)
Expand Down
37 changes: 37 additions & 0 deletions Compiler/test/EscapeAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1896,4 +1896,41 @@ let result = code_escapes(get_value, (Int,Any))
end == 1
end

# `analyze_escapes` should be able to handle `IRCode` with new nodes
let code = Any[
# block 1
#=1=# Expr(:new, Base.RefValue{Bool}, Argument(2))
#=2=# Expr(:call, GlobalRef(Core, :getfield), SSAValue(1), 1)
#=3=# GotoIfNot(SSAValue(2), 5)
# block 2
#=4=# nothing
# block 3
#=5=# Expr(:call, GlobalRef(Core, :setfield!), SSAValue(1), 1, false)
#=6=# ReturnNode(nothing)
]
ir = make_ircode(code; slottypes=Any[Any,Bool])
ir.stmts[1][:type] = Base.RefValue{Bool}
Compiler.insert_node!(ir, SSAValue(4), Compiler.NewInstruction(Expr(:call, GlobalRef(Core, :setfield!), SSAValue(1), 1, false), Any), #=attach_after=#false)
Compiler.insert_node!(ir, SSAValue(4), Compiler.NewInstruction(GotoNode(3), Any), #=attach_after=#true)
ir[SSAValue(6)] = nothing # eliminate the ReturnNode
s = Compiler.insert_node!(ir, SSAValue(6), Compiler.NewInstruction(Expr(:call, GlobalRef(Core, :getfield), SSAValue(1), 1), Any), #=attach_after=#false)
Compiler.insert_node!(ir, SSAValue(6), Compiler.NewInstruction(ReturnNode(s), Any), #=attach_after=#true)
# now this `ir` would look like:
# 1 ─ %1 = %new(Base.RefValue{Bool}, _2)::Base.RefValue{Bool} │
# │ %2 = builtin Core.getfield(%1, 1)::Any │
# └── goto #3 if not %2 │
# 2 ─ builtin Core.setfield!(%1, 1, false)::Any │
# │ nothing::Any
# └── goto #3
# 3 ┄ builtin Core.setfield!(%1, 1, false)::Any │
# │ %9 = builtin Core.getfield(%1, 1)::Any │
# │ nothing::Any
# └── return %9
result = code_escapes(ir, 2)
idxs = findall(iscall((result.ir, getfield)), result.ir.stmts.stmt)
for idx = idxs
@test is_load_forwardable(result, idx)
end
end

end # module test_EA

0 comments on commit 5664ae5

Please sign in to comment.