diff --git a/src/overdub.jl b/src/overdub.jl index 288d5e1..061db14 100644 --- a/src/overdub.jl +++ b/src/overdub.jl @@ -118,7 +118,11 @@ function reflect(@nospecialize(sigtypes::Tuple), world::UInt = get_world_counter method_instance === nothing && return nothing method_signature = method.sig static_params = Any[raw_static_params...] - code_info = Core.Compiler.retrieve_code_info(method_instance) + @static if VERSION >= v"1.10.0-DEV.873" + code_info = Core.Compiler.retrieve_code_info(method_instance, world) + else + code_info = Core.Compiler.retrieve_code_info(method_instance) + end isa(code_info, CodeInfo) || return nothing code_info = copy_code_info(code_info) verbose_lineinfo!(code_info, S) @@ -598,13 +602,37 @@ const OVERDUB_FALLBACK = begin end # `args` is `(typeof(original_function), map(typeof, original_args_tuple)...)` -function __overdub_generator__(self, context_type, args::Tuple) +function __overdub_generator__(world::UInt, source, self, context_type, args) + if nfields(args) > 0 + is_builtin = args[1] <: Core.Builtin + is_invoke = args[1] === typeof(Core.invoke) + if !is_builtin || is_invoke + try + untagged_args = ntuple(i->untagtype(args[i], context_type), nfields(args)) + reflection = reflect(untagged_args, world) + if isa(reflection, Reflection) + result = overdub_pass!(reflection, context_type, is_invoke) + isa(result, Expr) && return result + return reflection.code_info + end + catch err + errmsg = "ERROR COMPILING $args IN CONTEXT $(context_type): \n" #* sprint(showerror, err) + errmsg *= "\n" .* repr("text/plain", stacktrace(catch_backtrace())) + return quote + error($errmsg) + end + end + end + end + return copy_code_info(OVERDUB_FALLBACK) +end +function __overdub_generator__(self, context_type, args) if nfields(args) > 0 is_builtin = args[1] <: Core.Builtin is_invoke = args[1] === typeof(Core.invoke) if !is_builtin || is_invoke try - untagged_args = ((untagtype(args[i], context_type) for i in 1:nfields(args))...,) + untagged_args = ntuple(i->untagtype(args[i], context_type), nfields(args)) reflection = reflect(untagged_args) if isa(reflection, Reflection) result = overdub_pass!(reflection, context_type, is_invoke) @@ -638,6 +666,18 @@ if VERSION >= v"1.4.0-DEV.304" end let line = @__LINE__, file = @__FILE__ + @static if VERSION >= v"1.10.0-DEV.873" + @eval (@__MODULE__) begin + function $Cassette.overdub($OVERDUB_CONTEXT_NAME::$Cassette.Context, $OVERDUB_ARGUMENTS_NAME...) + $(Expr(:meta, :generated_only)) + $(Expr(:meta, :generated, __overdub_generator__)) + end + function $Cassette.recurse($OVERDUB_CONTEXT_NAME::$Cassette.Context, $OVERDUB_ARGUMENTS_NAME...) + $(Expr(:meta, :generated_only)) + $(Expr(:meta, :generated, __overdub_generator__)) + end + end + else @eval (@__MODULE__) begin function $Cassette.overdub($OVERDUB_CONTEXT_NAME::$Cassette.Context, $OVERDUB_ARGUMENTS_NAME...) $(Expr(:meta, :generated_only)) @@ -666,6 +706,7 @@ let line = @__LINE__, file = @__FILE__ true))) end end + end end @doc """ diff --git a/test/misctests.jl b/test/misctests.jl index 3dd7576..a72c272 100644 --- a/test/misctests.jl +++ b/test/misctests.jl @@ -17,7 +17,9 @@ function rosenbrock(x::Vector{Float64}) end x = rand(2) -@inferred(overdub(RosCtx(), rosenbrock, x)) +if VERSION < v"1.9" + @inferred (overdub(RosCtx(), rosenbrock, x)) +end messages = String[] Cassette.prehook(::RosCtx, f, args...) = push!(messages, string("calling ", f, args)) @@ -79,16 +81,29 @@ empty!(pres) empty!(posts) @overdub(ctx, Core._apply(+, (x1, x2), (x2 * x3, x3))) -@test pres == [(tuple, (x1, x2)), - (*, (x2, x3)), - (Base.mul_int, (x2, x3)), - (tuple, (x2*x3, x3)), - (+, (x1, x2, x2*x3, x3))] -@test posts == [((x1, x2), tuple, (x1, x2)), - (Base.mul_int(x2, x3), Base.mul_int, (x2, x3)), - (*(x2, x3), *, (x2, x3)), - ((x2*x3, x3), tuple, (x2*x3, x3)), - (+(x1, x2, x2*x3, x3), +, (x1, x2, x2*x3, x3))] +if !(v"1.9" <= VERSION < v"1.10") + @test pres == [(tuple, (x1, x2)), + (*, (x2, x3)), + (Base.mul_int, (x2, x3)), + (tuple, (x2*x3, x3)), + (+, (x1, x2, x2*x3, x3))] + @test posts == [((x1, x2), tuple, (x1, x2)), + (Base.mul_int(x2, x3), Base.mul_int, (x2, x3)), + (*(x2, x3), *, (x2, x3)), + ((x2*x3, x3), tuple, (x2*x3, x3)), + (+(x1, x2, x2*x3, x3), +, (x1, x2, x2*x3, x3))] +else + @test pres == [(tuple, (x1, x2)), + (*, (x2, x3)), + (Base.mul_int, (x2, x3)), + (tuple, (x2*x3, x3)), + (Core._apply, (+, (x1, x2), (x2*x3, x3)))] + @test posts == [((x1, x2), tuple, (x1, x2)), + (Base.mul_int(x2, x3), Base.mul_int, (x2, x3)), + (*(x2, x3), *, (x2, x3)), + ((x2*x3, x3), tuple, (x2*x3, x3)), + (+(x1, x2, x2*x3, x3), Core._apply, (+, (x1, x2), (x2*x3, x3)))] +end println("done (took ", time() - before_time, " seconds)") @@ -386,7 +401,10 @@ else @inferred(overdub(InferCtx(), rand, Float32, 1)) end end -@inferred(overdub(InferCtx(), broadcast, +, rand(1), rand(1))) + +if VERSION < v"1.9" + @inferred(overdub(InferCtx(), broadcast, +, rand(1), rand(1))) +end @inferred(overdub(InferCtx(), () -> kwargtest(42; foo = 1, bar = 2))) println("done (took ", time() - before_time, " seconds)") @@ -427,9 +445,11 @@ ctx = InvokeCtx(metadata=Any[]) @test overdub(ctx, invoker, 3) === 9 # This is kind of fragile and may break for unrelated reasons - the main thing # we're testing here is that we properly trace through the `invoke` call. -@test ctx.metadata == Any[Core.apply_type, Core.invoke, Core.apply_type, - Val{2}, Core.apply_type, Base.literal_pow, *, - Base.mul_int] +if VERSION < v"1.9" + @test ctx.metadata == Any[Core.apply_type, Core.invoke, Core.apply_type, + Val{2}, Core.apply_type, Base.literal_pow, *, + Base.mul_int] +end println("done (took ", time() - before_time, " seconds)")