Skip to content

Commit

Permalink
fix regression
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Jan 6, 2025
1 parent 333604b commit 009ba04
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 19 deletions.
26 changes: 15 additions & 11 deletions Compiler/src/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -828,21 +828,25 @@ mutable struct IRInterpretationState
new_call_inferred::Bool

function IRInterpretationState(interp::AbstractInterpreter,
spec_info::SpecInfo, ir::IRCode, mi::MethodInstance, argtypes::Vector{Any},
spec_info::SpecInfo, ir::IRCode, mi::MethodInstance, argtypes::Union{Nothing,Vector{Any}},
world::UInt, min_world::UInt, max_world::UInt)
curridx = 1
given_argtypes = Vector{Any}(undef, length(argtypes))
for i = 1:length(given_argtypes)
given_argtypes[i] = widenslotwrapper(argtypes[i])
end
if isa(mi.def, Method)
argtypes_refined = Bool[!(optimizer_lattice(interp), ir.argtypes[i], given_argtypes[i])
for i = 1:length(given_argtypes)]
if argtypes !== nothing
given_argtypes = Vector{Any}(undef, length(argtypes))
for i = 1:length(given_argtypes)
given_argtypes[i] = widenslotwrapper(argtypes[i])
end
if isa(mi.def, Method)
argtypes_refined = Bool[!(optimizer_lattice(interp), ir.argtypes[i], given_argtypes[i])
for i = 1:length(given_argtypes)]
else
argtypes_refined = Bool[false for _ = 1:length(given_argtypes)]
end
empty!(ir.argtypes)
append!(ir.argtypes, given_argtypes)
else
argtypes_refined = Bool[false for i = 1:length(given_argtypes)]
argtypes_refined = Bool[false for _ = 1:length(ir.argtypes)]
end
empty!(ir.argtypes)
append!(ir.argtypes, given_argtypes)
tpdum = TwoPhaseDefUseMap(length(ir.stmts))
ssa_refined = BitSet()
lazyreachability = LazyCFGReachability(ir)
Expand Down
14 changes: 10 additions & 4 deletions Compiler/src/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1067,20 +1067,26 @@ function optinf!(ir::IRCode, interp::AbstractInterpreter, sv::OptimizationState,
spec_info = SpecInfo(ci)
world = get_inference_world(interp)
min_world, max_world = first(result.valid_worlds), last(result.valid_worlds)
irsv = IRInterpretationState(interp, spec_info, ir, result.linfo, ir.argtypes,
irsv = IRInterpretationState(interp, spec_info, ir, result.linfo, #=argtypes=#nothing,
world, min_world, max_world)
rt, (nothrow, noub) = ir_abstract_constant_propagation(interp, irsv)
rt, (nothrow, noub) = ir_abstract_constant_propagation(interp, irsv;
# While `optinf!` itself performs reanalysis based on `IR_FLAG_REFINED`, since
# `IR_FLAG_REFINED` is also useful for subsequent `semi_concrete_eval`s that may
# occur on this `IRCode`, it is necessary to ensure that `optinf!` does not
# subtract `IR_FLAG_REFINED` (otherwise there might cases where the expected
# constant propagation information is not obtained through `irinterp`)
sub_ir_flag_refined=false)
if irsv.new_call_inferred
ir = ssa_inlining_pass!(ir, sv.inlining, ci.propagate_inbounds)
ir = compact!(ir)
effects = result.effects
effects = result.ipo_effects
if nothrow
effects = Effects(effects; nothrow=true)
end
if noub
effects = Effects(effects; noub=ALWAYS_TRUE)
end
result.effects = effects
result.ipo_effects = effects
result.exc_result = refine_exception_type(result.exc_result, effects)
= strictneqpartialorder(ipo_lattice(interp))
result.result = rt result.result ? rt : result.result
Expand Down
9 changes: 5 additions & 4 deletions Compiler/src/ssair/irinterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,8 @@ function reprocess_instruction!(interp::AbstractInterpreter, inst::Instruction,
end
return true
end
= strictneqpartialorder(typeinf_lattice(interp))
if rt inst[:type]
= strictpartialorder(typeinf_lattice(interp))
if rt inst[:type]
inst[:type] = rt
return true
end
Expand Down Expand Up @@ -319,6 +319,7 @@ function is_all_const_call(@nospecialize(stmt), interp::AbstractInterpreter, irs
end

function ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IRInterpretationState;
sub_ir_flag_refined::Bool = true,
externally_refined::Union{Nothing,BitSet} = nothing)
(; ir, tpdum, ssa_refined) = irsv

Expand All @@ -341,7 +342,7 @@ function ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IRI
any_refined = false
if has_flag(flag, IR_FLAG_REFINED)
any_refined = true
sub_flag!(inst, IR_FLAG_REFINED)
sub_ir_flag_refined && sub_flag!(inst, IR_FLAG_REFINED)
elseif is_all_const_call(stmt, interp, irsv)
# force reinference on calls with all constant arguments
any_refined = true
Expand Down Expand Up @@ -394,7 +395,7 @@ function ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IRI
stmt = inst[:stmt]
flag = inst[:flag]
if has_flag(flag, IR_FLAG_REFINED)
sub_flag!(inst, IR_FLAG_REFINED)
sub_ir_flag_refined && sub_flag!(inst, IR_FLAG_REFINED)
push!(stmt_ip, idx)
end
check_ret!(stmt, idx)
Expand Down
6 changes: 6 additions & 0 deletions Compiler/test/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6190,3 +6190,9 @@ let mi = only(methods(func_opt_inf, ())).specializations
ci = mi.cache
@test ci.rettype_const == sin(1.0)
end
@test fully_eliminated((BitSet,)) do b
iterate((pairs((b,))))[1][1]
end
@test fully_eliminated((BitSet,)) do b
iterate((pairs((b,))))[2]
end

0 comments on commit 009ba04

Please sign in to comment.