Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update for 1.10 #205

Merged
merged 6 commits into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions src/overdub.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,11 @@ function reflect(@nospecialize(sigtypes::Tuple), world::UInt = get_world_counter
method_instance === nothing && return nothing
method_signature = method.sig
static_params = Any[raw_static_params...]
code_info = Core.Compiler.retrieve_code_info(method_instance)
@static if VERSION >= v"1.10.0-DEV.873"
code_info = Core.Compiler.retrieve_code_info(method_instance, world)
else
code_info = Core.Compiler.retrieve_code_info(method_instance)
end
isa(code_info, CodeInfo) || return nothing
code_info = copy_code_info(code_info)
verbose_lineinfo!(code_info, S)
Expand Down Expand Up @@ -598,13 +602,37 @@ const OVERDUB_FALLBACK = begin
end

# `args` is `(typeof(original_function), map(typeof, original_args_tuple)...)`
function __overdub_generator__(self, context_type, args::Tuple)
function __overdub_generator__(world::UInt, source, self, context_type, args)
if nfields(args) > 0
is_builtin = args[1] <: Core.Builtin
is_invoke = args[1] === typeof(Core.invoke)
if !is_builtin || is_invoke
try
untagged_args = ntuple(i->untagtype(args[i], context_type), nfields(args))
reflection = reflect(untagged_args, world)
if isa(reflection, Reflection)
result = overdub_pass!(reflection, context_type, is_invoke)
isa(result, Expr) && return result
return reflection.code_info
end
catch err
errmsg = "ERROR COMPILING $args IN CONTEXT $(context_type): \n" #* sprint(showerror, err)
errmsg *= "\n" .* repr("text/plain", stacktrace(catch_backtrace()))
return quote
error($errmsg)
end
end
end
end
return copy_code_info(OVERDUB_FALLBACK)
end
function __overdub_generator__(self, context_type, args)
if nfields(args) > 0
is_builtin = args[1] <: Core.Builtin
is_invoke = args[1] === typeof(Core.invoke)
if !is_builtin || is_invoke
try
untagged_args = ((untagtype(args[i], context_type) for i in 1:nfields(args))...,)
untagged_args = ntuple(i->untagtype(args[i], context_type), nfields(args))
reflection = reflect(untagged_args)
if isa(reflection, Reflection)
result = overdub_pass!(reflection, context_type, is_invoke)
Expand Down Expand Up @@ -638,6 +666,18 @@ if VERSION >= v"1.4.0-DEV.304"
end

let line = @__LINE__, file = @__FILE__
@static if VERSION >= v"1.10.0-DEV.873"
@eval (@__MODULE__) begin
function $Cassette.overdub($OVERDUB_CONTEXT_NAME::$Cassette.Context, $OVERDUB_ARGUMENTS_NAME...)
$(Expr(:meta, :generated_only))
$(Expr(:meta, :generated, __overdub_generator__))
end
function $Cassette.recurse($OVERDUB_CONTEXT_NAME::$Cassette.Context, $OVERDUB_ARGUMENTS_NAME...)
$(Expr(:meta, :generated_only))
$(Expr(:meta, :generated, __overdub_generator__))
end
end
else
@eval (@__MODULE__) begin
function $Cassette.overdub($OVERDUB_CONTEXT_NAME::$Cassette.Context, $OVERDUB_ARGUMENTS_NAME...)
$(Expr(:meta, :generated_only))
Expand Down Expand Up @@ -666,6 +706,7 @@ let line = @__LINE__, file = @__FILE__
true)))
end
end
end
end

@doc """
Expand Down
50 changes: 35 additions & 15 deletions test/misctests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ function rosenbrock(x::Vector{Float64})
end

x = rand(2)
@inferred(overdub(RosCtx(), rosenbrock, x))
if VERSION < v"1.9"
@inferred (overdub(RosCtx(), rosenbrock, x))
end

messages = String[]
Cassette.prehook(::RosCtx, f, args...) = push!(messages, string("calling ", f, args))
Expand Down Expand Up @@ -79,16 +81,29 @@ empty!(pres)
empty!(posts)

@overdub(ctx, Core._apply(+, (x1, x2), (x2 * x3, x3)))
@test pres == [(tuple, (x1, x2)),
(*, (x2, x3)),
(Base.mul_int, (x2, x3)),
(tuple, (x2*x3, x3)),
(+, (x1, x2, x2*x3, x3))]
@test posts == [((x1, x2), tuple, (x1, x2)),
(Base.mul_int(x2, x3), Base.mul_int, (x2, x3)),
(*(x2, x3), *, (x2, x3)),
((x2*x3, x3), tuple, (x2*x3, x3)),
(+(x1, x2, x2*x3, x3), +, (x1, x2, x2*x3, x3))]
if !(v"1.9" <= VERSION < v"1.10")
@test pres == [(tuple, (x1, x2)),
(*, (x2, x3)),
(Base.mul_int, (x2, x3)),
(tuple, (x2*x3, x3)),
(+, (x1, x2, x2*x3, x3))]
@test posts == [((x1, x2), tuple, (x1, x2)),
(Base.mul_int(x2, x3), Base.mul_int, (x2, x3)),
(*(x2, x3), *, (x2, x3)),
((x2*x3, x3), tuple, (x2*x3, x3)),
(+(x1, x2, x2*x3, x3), +, (x1, x2, x2*x3, x3))]
else
@test pres == [(tuple, (x1, x2)),
(*, (x2, x3)),
(Base.mul_int, (x2, x3)),
(tuple, (x2*x3, x3)),
(Core._apply, (+, (x1, x2), (x2*x3, x3)))]
@test posts == [((x1, x2), tuple, (x1, x2)),
(Base.mul_int(x2, x3), Base.mul_int, (x2, x3)),
(*(x2, x3), *, (x2, x3)),
((x2*x3, x3), tuple, (x2*x3, x3)),
(+(x1, x2, x2*x3, x3), Core._apply, (+, (x1, x2), (x2*x3, x3)))]
end

println("done (took ", time() - before_time, " seconds)")

Expand Down Expand Up @@ -386,7 +401,10 @@ else
@inferred(overdub(InferCtx(), rand, Float32, 1))
end
end
@inferred(overdub(InferCtx(), broadcast, +, rand(1), rand(1)))

if VERSION < v"1.9"
@inferred(overdub(InferCtx(), broadcast, +, rand(1), rand(1)))
end
@inferred(overdub(InferCtx(), () -> kwargtest(42; foo = 1, bar = 2)))

println("done (took ", time() - before_time, " seconds)")
Expand Down Expand Up @@ -427,9 +445,11 @@ ctx = InvokeCtx(metadata=Any[])
@test overdub(ctx, invoker, 3) === 9
# This is kind of fragile and may break for unrelated reasons - the main thing
# we're testing here is that we properly trace through the `invoke` call.
@test ctx.metadata == Any[Core.apply_type, Core.invoke, Core.apply_type,
Val{2}, Core.apply_type, Base.literal_pow, *,
Base.mul_int]
if VERSION < v"1.9"
@test ctx.metadata == Any[Core.apply_type, Core.invoke, Core.apply_type,
Val{2}, Core.apply_type, Base.literal_pow, *,
Base.mul_int]
end

println("done (took ", time() - before_time, " seconds)")

Expand Down