Skip to content

Commit

Permalink
update for 1.10 (#205)
Browse files Browse the repository at this point in the history
oscardssmith authored Aug 17, 2023
1 parent fae23e7 commit 4dbe9b1
Showing 2 changed files with 79 additions and 18 deletions.
47 changes: 44 additions & 3 deletions src/overdub.jl
Original file line number Diff line number Diff line change
@@ -118,7 +118,11 @@ function reflect(@nospecialize(sigtypes::Tuple), world::UInt = get_world_counter
method_instance === nothing && return nothing
method_signature = method.sig
static_params = Any[raw_static_params...]
code_info = Core.Compiler.retrieve_code_info(method_instance)
@static if VERSION >= v"1.10.0-DEV.873"
code_info = Core.Compiler.retrieve_code_info(method_instance, world)
else
code_info = Core.Compiler.retrieve_code_info(method_instance)
end
isa(code_info, CodeInfo) || return nothing
code_info = copy_code_info(code_info)
verbose_lineinfo!(code_info, S)
@@ -598,13 +602,37 @@ const OVERDUB_FALLBACK = begin
end

# `args` is `(typeof(original_function), map(typeof, original_args_tuple)...)`
function __overdub_generator__(self, context_type, args::Tuple)
function __overdub_generator__(world::UInt, source, self, context_type, args)
if nfields(args) > 0
is_builtin = args[1] <: Core.Builtin
is_invoke = args[1] === typeof(Core.invoke)
if !is_builtin || is_invoke
try
untagged_args = ntuple(i->untagtype(args[i], context_type), nfields(args))
reflection = reflect(untagged_args, world)
if isa(reflection, Reflection)
result = overdub_pass!(reflection, context_type, is_invoke)
isa(result, Expr) && return result
return reflection.code_info
end
catch err
errmsg = "ERROR COMPILING $args IN CONTEXT $(context_type): \n" #* sprint(showerror, err)
errmsg *= "\n" .* repr("text/plain", stacktrace(catch_backtrace()))
return quote
error($errmsg)
end
end
end
end
return copy_code_info(OVERDUB_FALLBACK)
end
function __overdub_generator__(self, context_type, args)
if nfields(args) > 0
is_builtin = args[1] <: Core.Builtin
is_invoke = args[1] === typeof(Core.invoke)
if !is_builtin || is_invoke
try
untagged_args = ((untagtype(args[i], context_type) for i in 1:nfields(args))...,)
untagged_args = ntuple(i->untagtype(args[i], context_type), nfields(args))
reflection = reflect(untagged_args)
if isa(reflection, Reflection)
result = overdub_pass!(reflection, context_type, is_invoke)
@@ -638,6 +666,18 @@ if VERSION >= v"1.4.0-DEV.304"
end

let line = @__LINE__, file = @__FILE__
@static if VERSION >= v"1.10.0-DEV.873"
@eval (@__MODULE__) begin
function $Cassette.overdub($OVERDUB_CONTEXT_NAME::$Cassette.Context, $OVERDUB_ARGUMENTS_NAME...)
$(Expr(:meta, :generated_only))
$(Expr(:meta, :generated, __overdub_generator__))
end
function $Cassette.recurse($OVERDUB_CONTEXT_NAME::$Cassette.Context, $OVERDUB_ARGUMENTS_NAME...)
$(Expr(:meta, :generated_only))
$(Expr(:meta, :generated, __overdub_generator__))
end
end
else
@eval (@__MODULE__) begin
function $Cassette.overdub($OVERDUB_CONTEXT_NAME::$Cassette.Context, $OVERDUB_ARGUMENTS_NAME...)
$(Expr(:meta, :generated_only))
@@ -666,6 +706,7 @@ let line = @__LINE__, file = @__FILE__
true)))
end
end
end
end

@doc """
50 changes: 35 additions & 15 deletions test/misctests.jl
Original file line number Diff line number Diff line change
@@ -17,7 +17,9 @@ function rosenbrock(x::Vector{Float64})
end

x = rand(2)
@inferred(overdub(RosCtx(), rosenbrock, x))
if VERSION < v"1.9"
@inferred (overdub(RosCtx(), rosenbrock, x))
end

messages = String[]
Cassette.prehook(::RosCtx, f, args...) = push!(messages, string("calling ", f, args))
@@ -79,16 +81,29 @@ empty!(pres)
empty!(posts)

@overdub(ctx, Core._apply(+, (x1, x2), (x2 * x3, x3)))
@test pres == [(tuple, (x1, x2)),
(*, (x2, x3)),
(Base.mul_int, (x2, x3)),
(tuple, (x2*x3, x3)),
(+, (x1, x2, x2*x3, x3))]
@test posts == [((x1, x2), tuple, (x1, x2)),
(Base.mul_int(x2, x3), Base.mul_int, (x2, x3)),
(*(x2, x3), *, (x2, x3)),
((x2*x3, x3), tuple, (x2*x3, x3)),
(+(x1, x2, x2*x3, x3), +, (x1, x2, x2*x3, x3))]
if !(v"1.9" <= VERSION < v"1.10")
@test pres == [(tuple, (x1, x2)),
(*, (x2, x3)),
(Base.mul_int, (x2, x3)),
(tuple, (x2*x3, x3)),
(+, (x1, x2, x2*x3, x3))]
@test posts == [((x1, x2), tuple, (x1, x2)),
(Base.mul_int(x2, x3), Base.mul_int, (x2, x3)),
(*(x2, x3), *, (x2, x3)),
((x2*x3, x3), tuple, (x2*x3, x3)),
(+(x1, x2, x2*x3, x3), +, (x1, x2, x2*x3, x3))]
else
@test pres == [(tuple, (x1, x2)),
(*, (x2, x3)),
(Base.mul_int, (x2, x3)),
(tuple, (x2*x3, x3)),
(Core._apply, (+, (x1, x2), (x2*x3, x3)))]
@test posts == [((x1, x2), tuple, (x1, x2)),
(Base.mul_int(x2, x3), Base.mul_int, (x2, x3)),
(*(x2, x3), *, (x2, x3)),
((x2*x3, x3), tuple, (x2*x3, x3)),
(+(x1, x2, x2*x3, x3), Core._apply, (+, (x1, x2), (x2*x3, x3)))]
end

println("done (took ", time() - before_time, " seconds)")

@@ -386,7 +401,10 @@ else
@inferred(overdub(InferCtx(), rand, Float32, 1))
end
end
@inferred(overdub(InferCtx(), broadcast, +, rand(1), rand(1)))

if VERSION < v"1.9"
@inferred(overdub(InferCtx(), broadcast, +, rand(1), rand(1)))
end
@inferred(overdub(InferCtx(), () -> kwargtest(42; foo = 1, bar = 2)))

println("done (took ", time() - before_time, " seconds)")
@@ -427,9 +445,11 @@ ctx = InvokeCtx(metadata=Any[])
@test overdub(ctx, invoker, 3) === 9
# This is kind of fragile and may break for unrelated reasons - the main thing
# we're testing here is that we properly trace through the `invoke` call.
@test ctx.metadata == Any[Core.apply_type, Core.invoke, Core.apply_type,
Val{2}, Core.apply_type, Base.literal_pow, *,
Base.mul_int]
if VERSION < v"1.9"
@test ctx.metadata == Any[Core.apply_type, Core.invoke, Core.apply_type,
Val{2}, Core.apply_type, Base.literal_pow, *,
Base.mul_int]
end

println("done (took ", time() - before_time, " seconds)")

3 comments on commit 4dbe9b1

@avik-pal
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vchuravy or @oscardssmith can we cut a release for this?

@oscardssmith
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should

@avik-pal
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did a release (hopefully no one will scream at me)

Please sign in to comment.