Skip to content

Commit

Permalink
Ensure closures capture the types of typed variables
Browse files Browse the repository at this point in the history
  • Loading branch information
c42f committed Feb 7, 2025
1 parent 4c3494f commit c7c7bde
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 93 deletions.
13 changes: 5 additions & 8 deletions src/closure_conversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,29 +90,26 @@ end
# global and for converting the return value of a function call to the declared
# return type.
function convert_for_type_decl(ctx, srcref, ex, type, do_typeassert)
# Require that the caller make `type` "simple", for now (can generalize
# later if necessary)
kt = kind(type)
@assert (kt == K"Identifier" || kt == K"BindingId" || is_literal(kt))
# Use a slot to permit union-splitting this in inference
tmp = new_local_binding(ctx, srcref, "tmp", is_always_defined=true)

@ast ctx srcref [K"block"
type_tmp := type
# [K"=" type_ssa renumber_assigned_ssavalues(type)]
[K"=" tmp ex]
[K"if"
[K"call" "isa"::K"core" tmp type]
[K"call" "isa"::K"core" tmp type_tmp]
"nothing"::K"core"
[K"="
tmp
if do_typeassert
[K"call"
"typeassert"::K"core"
[K"call" "convert"::K"top" type tmp]
type
[K"call" "convert"::K"top" type_tmp tmp]
type_tmp
]
else
[K"call" "convert"::K"top" type tmp]
[K"call" "convert"::K"top" type_tmp tmp]
end
]
]
Expand Down
45 changes: 27 additions & 18 deletions src/scope_analysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,14 @@ function _resolve_scopes(ctx, ex::SyntaxTree)
throw(LoweringError(ex, "type declarations for global variables must be at top level, not inside a function"))
end
end
id = ex_out[1]
if kind(id) != K"Placeholder"
binfo = lookup_binding(ctx, id)
if !isnothing(binfo.type)
throw(LoweringError(ex, "multiple type declarations found for `$(binfo.name)`"))
end
update_binding!(ctx, id; type=ex_out[2])
end
ex_out
elseif k == K"always_defined"
id = lookup_var(ctx, NameKey(ex[1]))
Expand Down Expand Up @@ -624,14 +632,22 @@ function analyze_variables!(ctx, ex)
k = kind(ex)
if k == K"BindingId"
if has_lambda_binding(ctx, ex)
# FIXME: Move this after closure conversion so that we don't need
# TODO: Move this after closure conversion so that we don't need
# to model the closure conversion transformations here.
update_lambda_binding!(ctx, ex, is_read=true)
else
binfo = lookup_binding(ctx, ex.var_id)
if !binfo.is_ssa && binfo.kind != :global
# The type of typed locals is invisible in the previous pass,
# but is filled in here.
init_lambda_binding(ctx.lambda_bindings, ex.var_id, is_captured=true, is_read=true)
update_binding!(ctx, ex, is_captured=true)
end
end
elseif is_leaf(ex) || is_quoted(ex)
return
elseif k == K"local" || k == K"global"
# Uses of bindings which don't count as uses.
# Presence of BindingId within local/global is ignored.
return
elseif k == K"="
lhs = ex[1]
Expand All @@ -640,6 +656,12 @@ function analyze_variables!(ctx, ex)
if has_lambda_binding(ctx, lhs)
update_lambda_binding!(ctx, lhs, is_assigned=true)
end
lhs_binfo = lookup_binding(ctx, lhs)
if !isnothing(lhs_binfo.type)
# Assignments introduce a variable's type later during closure
# conversion, but we must model that explicitly here.
analyze_variables!(ctx, lhs_binfo.type)
end
end
analyze_variables!(ctx, ex[2])
elseif k == K"function_decl"
Expand All @@ -655,17 +677,6 @@ function analyze_variables!(ctx, ex)
if kind(ex[1]) != K"BindingId" || lookup_binding(ctx, ex[1]).kind !== :local
analyze_variables!(ctx, ex[1])
end
elseif k == K"decl"
@chk numchildren(ex) == 2
id = ex[1]
if kind(id) != K"Placeholder"
binfo = lookup_binding(ctx, id)
if !isnothing(binfo.type)
throw(LoweringError(ex, "multiple type declarations found for `$(binfo.name)`"))
end
update_binding!(ctx, id; type=ex[2])
end
analyze_variables!(ctx, ex[2])
elseif k == K"const"
id = ex[1]
if lookup_binding(ctx, id).kind == :local
Expand All @@ -677,7 +688,7 @@ function analyze_variables!(ctx, ex)
if kind(name) == K"BindingId"
id = name.var_id
if has_lambda_binding(ctx, id)
# FIXME: Move this after closure conversion so that we don't need
# TODO: Move this after closure conversion so that we don't need
# to model the closure conversion transformations.
update_lambda_binding!(ctx, id, is_called=true)
end
Expand Down Expand Up @@ -710,9 +721,10 @@ function analyze_variables!(ctx, ex)
end
ctx2 = VariableAnalysisContext(ctx.graph, ctx.bindings, ctx.mod, lambda_bindings,
ctx.method_def_stack, ctx.closure_bindings)
# Add any captured bindings to the enclosing lambda, if necessary.
foreach(e->analyze_variables!(ctx2, e), ex[3:end]) # body & return type
for (id,lbinfo) in pairs(lambda_bindings.bindings)
if lbinfo.is_captured
# Add any captured bindings to the enclosing lambda, if necessary.
outer_lbinfo = lookup_lambda_binding(ctx.lambda_bindings, id)
if isnothing(outer_lbinfo)
# Inner lambda captures a variable. If it's not yet present
Expand All @@ -723,9 +735,6 @@ function analyze_variables!(ctx, ex)
end
end
end

# TODO: Types of any assigned captured vars will also be used and might be captured.
foreach(e->analyze_variables!(ctx2, e), ex[3:end])
else
foreach(e->analyze_variables!(ctx, e), children(ex))
end
Expand Down
26 changes: 12 additions & 14 deletions test/assignments_ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,21 +81,19 @@ end
1 (newvar slot₁/x)
2 TestMod.f
3 (call %₂)
4 (= slot₂/tmp %₃)
5 slot₂/tmp
6 TestMod.T
7 (call core.isa %%)
4 TestMod.T
5 (= slot₂/tmp %₃)
6 slot₂/tmp
7 (call core.isa %%)
8 (gotoifnot %₇ label₁₀)
9 (goto label₁₅)
10 TestMod.T
11 slot₂/tmp
12 (call top.convert %₁₀ %₁₁)
13 TestMod.T
14 (= slot₂/tmp (call core.typeassert %₁₂ %₁₃))
15 slot₂/tmp
16 (= slot₁/x %₁₅)
17 slot₁/x
18 (return %₁₇)
9 (goto label₁₃)
10 slot₂/tmp
11 (call top.convert %%₁₀)
12 (= slot₂/tmp (call core.typeassert %₁₁ %₄))
13 slot₂/tmp
14 (= slot₁/x %₁₃)
15 slot₁/x
16 (return %₁₅)

########################################
# "complex lhs" of `::T` => type-assert, not decl
Expand Down
56 changes: 56 additions & 0 deletions test/closures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,62 @@ begin
end
""") === (1,2,3)

# Closure with return type must capture the return type
@test JuliaLowering.include_string(test_mod, """
let T = Int
function f_captured_return_type()::T
2.0
end
f_captured_return_type()
end
""") === 2

# Capturing a typed local
@test JuliaLowering.include_string(test_mod, """
let T = Int
x::T = 1.0
function f_captured_typed_local()
x = 2.0
end
f_captured_typed_local()
x
end
""") === 2

# Capturing a typed local where the type is a nontrivial expression
@test begin
res = JuliaLowering.include_string(test_mod, """
let T = Int, V=Vector
x::V{T} = [1,2]
function f_captured_typed_local_composite()
x = [100.0, 200.0]
end
f_captured_typed_local_composite()
x
end
""")
res == [100, 200] && eltype(res) == Int
end

# Evil case where we mutate `T` which is the type of `x`, such that x is
# eventually set to a Float64.
#
# Completely dynamic types for variables should be disallowed somehow?? For
# example, by emitting the expression computing the type of `x` alongside the
# newvar node. However, for now we verify that this potentially evil behavior
# is compatible with the existing implementation :)
@test JuliaLowering.include_string(test_mod, """
let T = Int
x::T = 1.0
function f_captured_mutating_typed_local()
x = 2
end
T = Float64
f_captured_mutating_typed_local()
x
end
""") === 2.0

# Anon function syntax
@test JuliaLowering.include_string(test_mod, """
begin
Expand Down
34 changes: 34 additions & 0 deletions test/closures_ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,40 @@ end
46 slot₃/f_kw_closure
47 (return %₄₆)

########################################
# Closure capturing a typed local must also capture the type expression
# [method_filter: #f_captured_typed_local##0]
let T=Blah
x::T = 1.0
function f_captured_typed_local()
x = 2.0
end
f_captured_typed_local()
x
end
#---------------------
slots: [slot₁/#self#(!read) slot₂/T(!read) slot₃/tmp(!read)]
1 2.0
2 (call core.getfield slot₁/#self# :x)
3 (call core.getfield slot₁/#self# :T)
4 (call core.isdefined %:contents)
5 (gotoifnot %₄ label₇)
6 (goto label₉)
7 (newvar slot₂/T)
8 slot₂/T
9 (call core.getfield %:contents)
10 (= slot₃/tmp %₁)
11 slot₃/tmp
12 (call core.isa %₁₁ %₉)
13 (gotoifnot %₁₂ label₁₅)
14 (goto label₁₈)
15 slot₃/tmp
16 (call top.convert %%₁₅)
17 (= slot₃/tmp (call core.typeassert %₁₆ %₉))
18 slot₃/tmp
19 (call core.setfield! %:contents %₁₈)
20 (return %₁)

########################################
# Error: Closure outside any top level context
# (Should only happen in a user-visible way when lowering code emitted
Expand Down
74 changes: 34 additions & 40 deletions test/decls_ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,18 @@ local x::T = 1
#---------------------
1 (newvar slot₁/x)
2 1
3 (= slot₂/tmp %₂)
4 slot₂/tmp
5 TestMod.T
6 (call core.isa %%)
3 TestMod.T
4 (= slot₂/tmp %₂)
5 slot₂/tmp
6 (call core.isa %%)
7 (gotoifnot %₆ label₉)
8 (goto label₁₄)
9 TestMod.T
10 slot₂/tmp
11 (call top.convert %%₁₀)
12 TestMod.T
13 (= slot₂/tmp (call core.typeassert %₁₁ %₁₂))
14 slot₂/tmp
15 (= slot₁/x %₁₄)
16 (return %₂)
8 (goto label₁₂)
9 slot₂/tmp
10 (call top.convert %%₉)
11 (= slot₂/tmp (call core.typeassert %₁₀ %₃))
12 slot₂/tmp
13 (= slot₁/x %₁₂)
14 (return %₂)

########################################
# const
Expand Down Expand Up @@ -133,35 +131,31 @@ end
8 --- method core.nothing %
slots: [slot₁/#self#(!read) slot₂/x slot₃/tmp(!read) slot₄/tmp(!read)]
1 1
2 (= slot₃/tmp %₁)
3 slot₃/tmp
4 TestMod.Int
5 (call core.isa %%)
2 TestMod.Int
3 (= slot₃/tmp %₁)
4 slot₃/tmp
5 (call core.isa %%)
6 (gotoifnot %₅ label₈)
7 (goto label₁)
8 TestMod.Int
9 slot₃/tmp
10 (call top.convert %%)
11 TestMod.Int
12 (= slot/tmp (call core.typeassert %₁₀ %₁₁))
13 slot₃/tmp
14 (= slot₂/x %₁₃)
15 2.0
16 (= slot₄/tmp %₁₅)
17 slot₄/tmp
18 TestMod.Int
19 (call core.isa %₁₇ %₁₈)
20 (gotoifnot %₁₉ label₂₂)
21 (goto label₂₇)
22 TestMod.Int
7 (goto label₁)
8 slot₃/tmp
9 (call top.convert %%₈)
10 (= slot₃/tmp (call core.typeassert %%₂))
11 slot₃/tmp
12 (= slot/x %₁₁)
13 2.0
14 TestMod.Int
15 (= slot₄/tmp %₁₃)
16 slot₄/tmp
17 (call core.isa %₁₆ %₁₄)
18 (gotoifnot %₁₇ label₂₀)
19 (goto label₂₃)
20 slot₄/tmp
21 (call top.convert %₁₄ %₂₀)
22 (= slot₄/tmp (call core.typeassert %₂₁ %₁₄))
23 slot₄/tmp
24 (call top.convert %₂₂ %₂₃)
25 TestMod.Int
26 (= slot₄/tmp (call core.typeassert %₂₄ %₂₅))
27 slot₄/tmp
28 (= slot₂/x %₂₇)
29 slot₂/x
30 (return %₂₉)
24 (= slot₂/x %₂₃)
25 slot₂/x
26 (return %₂₅)
9 TestMod.f
10 (return %₉)

Expand Down
24 changes: 11 additions & 13 deletions test/destructuring_ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -341,20 +341,18 @@ end
1 (newvar slot₁/x)
2 TestMod.rhs
3 (call top.getproperty %:x)
4 (= slot₂/tmp %₃)
5 slot₂/tmp
6 TestMod.T
7 (call core.isa %%)
4 TestMod.T
5 (= slot₂/tmp %₃)
6 slot₂/tmp
7 (call core.isa %%)
8 (gotoifnot %₇ label₁₀)
9 (goto label₁₅)
10 TestMod.T
11 slot₂/tmp
12 (call top.convert %₁₀ %₁₁)
13 TestMod.T
14 (= slot₂/tmp (call core.typeassert %₁₂ %₁₃))
15 slot₂/tmp
16 (= slot₁/x %₁₅)
17 (return %₂)
9 (goto label₁₃)
10 slot₂/tmp
11 (call top.convert %%₁₀)
12 (= slot₂/tmp (call core.typeassert %₁₁ %₄))
13 slot₂/tmp
14 (= slot₁/x %₁₃)
15 (return %₂)

########################################
# Error: Property destructuring with frankentuple
Expand Down

0 comments on commit c7c7bde

Please sign in to comment.