From 009ba04de9784d1972efd87c9c65d49ac7349cbf Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Wed, 1 Jan 2025 01:57:45 +0900 Subject: [PATCH] fix regression --- Compiler/src/inferencestate.jl | 26 +++++++++++++++----------- Compiler/src/optimize.jl | 14 ++++++++++---- Compiler/src/ssair/irinterp.jl | 9 +++++---- Compiler/test/inference.jl | 6 ++++++ 4 files changed, 36 insertions(+), 19 deletions(-) diff --git a/Compiler/src/inferencestate.jl b/Compiler/src/inferencestate.jl index 5003a884ab1cd..f5d3d41f349e1 100644 --- a/Compiler/src/inferencestate.jl +++ b/Compiler/src/inferencestate.jl @@ -828,21 +828,25 @@ mutable struct IRInterpretationState new_call_inferred::Bool function IRInterpretationState(interp::AbstractInterpreter, - spec_info::SpecInfo, ir::IRCode, mi::MethodInstance, argtypes::Vector{Any}, + spec_info::SpecInfo, ir::IRCode, mi::MethodInstance, argtypes::Union{Nothing,Vector{Any}}, world::UInt, min_world::UInt, max_world::UInt) curridx = 1 - given_argtypes = Vector{Any}(undef, length(argtypes)) - for i = 1:length(given_argtypes) - given_argtypes[i] = widenslotwrapper(argtypes[i]) - end - if isa(mi.def, Method) - argtypes_refined = Bool[!⊑(optimizer_lattice(interp), ir.argtypes[i], given_argtypes[i]) - for i = 1:length(given_argtypes)] + if argtypes !== nothing + given_argtypes = Vector{Any}(undef, length(argtypes)) + for i = 1:length(given_argtypes) + given_argtypes[i] = widenslotwrapper(argtypes[i]) + end + if isa(mi.def, Method) + argtypes_refined = Bool[!⊑(optimizer_lattice(interp), ir.argtypes[i], given_argtypes[i]) + for i = 1:length(given_argtypes)] + else + argtypes_refined = Bool[false for _ = 1:length(given_argtypes)] + end + empty!(ir.argtypes) + append!(ir.argtypes, given_argtypes) else - argtypes_refined = Bool[false for i = 1:length(given_argtypes)] + argtypes_refined = Bool[false for _ = 1:length(ir.argtypes)] end - empty!(ir.argtypes) - append!(ir.argtypes, given_argtypes) tpdum = TwoPhaseDefUseMap(length(ir.stmts)) ssa_refined = BitSet() lazyreachability = LazyCFGReachability(ir) diff --git a/Compiler/src/optimize.jl b/Compiler/src/optimize.jl index d12da346ccdbb..ad8eb1608cd46 100644 --- a/Compiler/src/optimize.jl +++ b/Compiler/src/optimize.jl @@ -1067,20 +1067,26 @@ function optinf!(ir::IRCode, interp::AbstractInterpreter, sv::OptimizationState, spec_info = SpecInfo(ci) world = get_inference_world(interp) min_world, max_world = first(result.valid_worlds), last(result.valid_worlds) - irsv = IRInterpretationState(interp, spec_info, ir, result.linfo, ir.argtypes, + irsv = IRInterpretationState(interp, spec_info, ir, result.linfo, #=argtypes=#nothing, world, min_world, max_world) - rt, (nothrow, noub) = ir_abstract_constant_propagation(interp, irsv) + rt, (nothrow, noub) = ir_abstract_constant_propagation(interp, irsv; + # While `optinf!` itself performs reanalysis based on `IR_FLAG_REFINED`, since + # `IR_FLAG_REFINED` is also useful for subsequent `semi_concrete_eval`s that may + # occur on this `IRCode`, it is necessary to ensure that `optinf!` does not + # subtract `IR_FLAG_REFINED` (otherwise there might cases where the expected + # constant propagation information is not obtained through `irinterp`) + sub_ir_flag_refined=false) if irsv.new_call_inferred ir = ssa_inlining_pass!(ir, sv.inlining, ci.propagate_inbounds) ir = compact!(ir) - effects = result.effects + effects = result.ipo_effects if nothrow effects = Effects(effects; nothrow=true) end if noub effects = Effects(effects; noub=ALWAYS_TRUE) end - result.effects = effects + result.ipo_effects = effects result.exc_result = refine_exception_type(result.exc_result, effects) ⋤ = strictneqpartialorder(ipo_lattice(interp)) result.result = rt ⋤ result.result ? rt : result.result diff --git a/Compiler/src/ssair/irinterp.jl b/Compiler/src/ssair/irinterp.jl index 2b8050f778d85..df34fb711df06 100644 --- a/Compiler/src/ssair/irinterp.jl +++ b/Compiler/src/ssair/irinterp.jl @@ -229,8 +229,8 @@ function reprocess_instruction!(interp::AbstractInterpreter, inst::Instruction, end return true end - ⋤ = strictneqpartialorder(typeinf_lattice(interp)) - if rt ⋤ inst[:type] + ⊏ = strictpartialorder(typeinf_lattice(interp)) + if rt ⊏ inst[:type] inst[:type] = rt return true end @@ -319,6 +319,7 @@ function is_all_const_call(@nospecialize(stmt), interp::AbstractInterpreter, irs end function ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IRInterpretationState; + sub_ir_flag_refined::Bool = true, externally_refined::Union{Nothing,BitSet} = nothing) (; ir, tpdum, ssa_refined) = irsv @@ -341,7 +342,7 @@ function ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IRI any_refined = false if has_flag(flag, IR_FLAG_REFINED) any_refined = true - sub_flag!(inst, IR_FLAG_REFINED) + sub_ir_flag_refined && sub_flag!(inst, IR_FLAG_REFINED) elseif is_all_const_call(stmt, interp, irsv) # force reinference on calls with all constant arguments any_refined = true @@ -394,7 +395,7 @@ function ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IRI stmt = inst[:stmt] flag = inst[:flag] if has_flag(flag, IR_FLAG_REFINED) - sub_flag!(inst, IR_FLAG_REFINED) + sub_ir_flag_refined && sub_flag!(inst, IR_FLAG_REFINED) push!(stmt_ip, idx) end check_ret!(stmt, idx) diff --git a/Compiler/test/inference.jl b/Compiler/test/inference.jl index 04c675862e7a5..2ea08f0868538 100644 --- a/Compiler/test/inference.jl +++ b/Compiler/test/inference.jl @@ -6190,3 +6190,9 @@ let mi = only(methods(func_opt_inf, ())).specializations ci = mi.cache @test ci.rettype_const == sin(1.0) end +@test fully_eliminated((BitSet,)) do b + iterate((pairs((b,))))[1][1] +end +@test fully_eliminated((BitSet,)) do b + iterate((pairs((b,))))[2] +end