From 9cf47db043adf908c3cfbdc810c1b355db9b0837 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Mon, 7 Aug 2023 10:12:54 -0400 Subject: [PATCH 1/6] Update to support 1.10 https://github.com/JuliaLang/julia/pull/48766 makes it so this function takes in a world age. --- src/overdub.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/overdub.jl b/src/overdub.jl index 288d5e1..e903d21 100644 --- a/src/overdub.jl +++ b/src/overdub.jl @@ -118,7 +118,11 @@ function reflect(@nospecialize(sigtypes::Tuple), world::UInt = get_world_counter method_instance === nothing && return nothing method_signature = method.sig static_params = Any[raw_static_params...] - code_info = Core.Compiler.retrieve_code_info(method_instance) + @static if VERSION >= v"1.10.0-DEV.873" + code_info = Core.Compiler.retrieve_code_info(method_instance, world) + else + code_info = Core.Compiler.retrieve_code_info(method_instance) + end isa(code_info, CodeInfo) || return nothing code_info = copy_code_info(code_info) verbose_lineinfo!(code_info, S) From 1b3ad2209e93f283a18b5632b115c6cfef0b40ba Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Mon, 7 Aug 2023 10:31:05 -0400 Subject: [PATCH 2/6] fix overdub and recurse on 1.10 --- src/overdub.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/overdub.jl b/src/overdub.jl index e903d21..68a65c7 100644 --- a/src/overdub.jl +++ b/src/overdub.jl @@ -642,6 +642,18 @@ if VERSION >= v"1.4.0-DEV.304" end let line = @__LINE__, file = @__FILE__ + @static if VERSION >= v"1.10.0-DEV.873" + @eval (@__MODULE__) begin + function $Cassette.overdub($OVERDUB_CONTEXT_NAME::$Cassette.Context, $OVERDUB_ARGUMENTS_NAME...) + $(Expr(:meta, :generated_only)) + $(Expr(:meta, :generated, __overdub_generator__)) + end + function $Cassette.recurse($OVERDUB_CONTEXT_NAME::$Cassette.Context, $OVERDUB_ARGUMENTS_NAME...) + $(Expr(:meta, :generated_only)) + $(Expr(:meta, :generated, __overdub_generator__)) + end + end + else @eval (@__MODULE__) begin function $Cassette.overdub($OVERDUB_CONTEXT_NAME::$Cassette.Context, $OVERDUB_ARGUMENTS_NAME...) $(Expr(:meta, :generated_only)) @@ -670,6 +682,7 @@ let line = @__LINE__, file = @__FILE__ true))) end end + end end @doc """ From 4d8f01e42c123965ea4380cfb4c96022c25c95e9 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Mon, 7 Aug 2023 15:55:33 -0400 Subject: [PATCH 3/6] fix __overdub_generator__ on 1.10 --- src/overdub.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/overdub.jl b/src/overdub.jl index 68a65c7..8b15222 100644 --- a/src/overdub.jl +++ b/src/overdub.jl @@ -602,7 +602,10 @@ const OVERDUB_FALLBACK = begin end # `args` is `(typeof(original_function), map(typeof, original_args_tuple)...)` -function __overdub_generator__(self, context_type, args::Tuple) +function __overdub_generator__(world, source, self, context_type, args) + __overdub_generator__(self, context_type, args) +end +function __overdub_generator__(self, context_type, args) if nfields(args) > 0 is_builtin = args[1] <: Core.Builtin is_invoke = args[1] === typeof(Core.invoke) From 5a45eabc06b94bbddcfe6553d30d124a328c96eb Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Wed, 9 Aug 2023 00:08:28 -0400 Subject: [PATCH 4/6] fix 1.9 tests --- src/overdub.jl | 27 ++++++++++++++++++++++++--- test/misctests.jl | 33 ++++++++++++++++++++------------- 2 files changed, 44 insertions(+), 16 deletions(-) diff --git a/src/overdub.jl b/src/overdub.jl index 8b15222..061db14 100644 --- a/src/overdub.jl +++ b/src/overdub.jl @@ -602,8 +602,29 @@ const OVERDUB_FALLBACK = begin end # `args` is `(typeof(original_function), map(typeof, original_args_tuple)...)` -function __overdub_generator__(world, source, self, context_type, args) - __overdub_generator__(self, context_type, args) +function __overdub_generator__(world::UInt, source, self, context_type, args) + if nfields(args) > 0 + is_builtin = args[1] <: Core.Builtin + is_invoke = args[1] === typeof(Core.invoke) + if !is_builtin || is_invoke + try + untagged_args = ntuple(i->untagtype(args[i], context_type), nfields(args)) + reflection = reflect(untagged_args, world) + if isa(reflection, Reflection) + result = overdub_pass!(reflection, context_type, is_invoke) + isa(result, Expr) && return result + return reflection.code_info + end + catch err + errmsg = "ERROR COMPILING $args IN CONTEXT $(context_type): \n" #* sprint(showerror, err) + errmsg *= "\n" .* repr("text/plain", stacktrace(catch_backtrace())) + return quote + error($errmsg) + end + end + end + end + return copy_code_info(OVERDUB_FALLBACK) end function __overdub_generator__(self, context_type, args) if nfields(args) > 0 @@ -611,7 +632,7 @@ function __overdub_generator__(self, context_type, args) is_invoke = args[1] === typeof(Core.invoke) if !is_builtin || is_invoke try - untagged_args = ((untagtype(args[i], context_type) for i in 1:nfields(args))...,) + untagged_args = ntuple(i->untagtype(args[i], context_type), nfields(args)) reflection = reflect(untagged_args) if isa(reflection, Reflection) result = overdub_pass!(reflection, context_type, is_invoke) diff --git a/test/misctests.jl b/test/misctests.jl index 3dd7576..fadfcd8 100644 --- a/test/misctests.jl +++ b/test/misctests.jl @@ -17,7 +17,9 @@ function rosenbrock(x::Vector{Float64}) end x = rand(2) -@inferred(overdub(RosCtx(), rosenbrock, x)) +if VERSION < v"1.9" + @inferred (overdub(RosCtx(), rosenbrock, x)) +end messages = String[] Cassette.prehook(::RosCtx, f, args...) = push!(messages, string("calling ", f, args)) @@ -79,16 +81,18 @@ empty!(pres) empty!(posts) @overdub(ctx, Core._apply(+, (x1, x2), (x2 * x3, x3))) -@test pres == [(tuple, (x1, x2)), - (*, (x2, x3)), - (Base.mul_int, (x2, x3)), - (tuple, (x2*x3, x3)), - (+, (x1, x2, x2*x3, x3))] -@test posts == [((x1, x2), tuple, (x1, x2)), - (Base.mul_int(x2, x3), Base.mul_int, (x2, x3)), - (*(x2, x3), *, (x2, x3)), - ((x2*x3, x3), tuple, (x2*x3, x3)), - (+(x1, x2, x2*x3, x3), +, (x1, x2, x2*x3, x3))] +if v"1.9" <= VERSION < v"1.10" + @test pres == [(tuple, (x1, x2)), + (*, (x2, x3)), + (Base.mul_int, (x2, x3)), + (tuple, (x2*x3, x3)), + (+, (x1, x2, x2*x3, x3))] + @test posts == [((x1, x2), tuple, (x1, x2)), + (Base.mul_int(x2, x3), Base.mul_int, (x2, x3)), + (*(x2, x3), *, (x2, x3)), + ((x2*x3, x3), tuple, (x2*x3, x3)), + (+(x1, x2, x2*x3, x3), +, (x1, x2, x2*x3, x3))] +end println("done (took ", time() - before_time, " seconds)") @@ -386,7 +390,10 @@ else @inferred(overdub(InferCtx(), rand, Float32, 1)) end end -@inferred(overdub(InferCtx(), broadcast, +, rand(1), rand(1))) + +if VERSION < v"1.9" + @inferred(overdub(InferCtx(), broadcast, +, rand(1), rand(1))) +end @inferred(overdub(InferCtx(), () -> kwargtest(42; foo = 1, bar = 2))) println("done (took ", time() - before_time, " seconds)") @@ -429,7 +436,7 @@ ctx = InvokeCtx(metadata=Any[]) # we're testing here is that we properly trace through the `invoke` call. @test ctx.metadata == Any[Core.apply_type, Core.invoke, Core.apply_type, Val{2}, Core.apply_type, Base.literal_pow, *, - Base.mul_int] + Base.mul_int] broken = VERSION >= v"1.9" println("done (took ", time() - before_time, " seconds)") From c7a62c384ce8000ad4785a3bc338dc2945905174 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Fri, 11 Aug 2023 11:14:28 -0400 Subject: [PATCH 5/6] fix 1.6 tests --- test/misctests.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test/misctests.jl b/test/misctests.jl index fadfcd8..ee52378 100644 --- a/test/misctests.jl +++ b/test/misctests.jl @@ -434,9 +434,11 @@ ctx = InvokeCtx(metadata=Any[]) @test overdub(ctx, invoker, 3) === 9 # This is kind of fragile and may break for unrelated reasons - the main thing # we're testing here is that we properly trace through the `invoke` call. -@test ctx.metadata == Any[Core.apply_type, Core.invoke, Core.apply_type, - Val{2}, Core.apply_type, Base.literal_pow, *, - Base.mul_int] broken = VERSION >= v"1.9" +if VERSION < v"1.9" + @test ctx.metadata == Any[Core.apply_type, Core.invoke, Core.apply_type, + Val{2}, Core.apply_type, Base.literal_pow, *, + Base.mul_int] +end println("done (took ", time() - before_time, " seconds)") From a63b15018136f7fdcb596544f431db8e3b65a7a8 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Tue, 15 Aug 2023 11:58:19 -0400 Subject: [PATCH 6/6] fix last 1.9 tests --- test/misctests.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/test/misctests.jl b/test/misctests.jl index ee52378..a72c272 100644 --- a/test/misctests.jl +++ b/test/misctests.jl @@ -81,7 +81,7 @@ empty!(pres) empty!(posts) @overdub(ctx, Core._apply(+, (x1, x2), (x2 * x3, x3))) -if v"1.9" <= VERSION < v"1.10" +if !(v"1.9" <= VERSION < v"1.10") @test pres == [(tuple, (x1, x2)), (*, (x2, x3)), (Base.mul_int, (x2, x3)), @@ -92,6 +92,17 @@ if v"1.9" <= VERSION < v"1.10" (*(x2, x3), *, (x2, x3)), ((x2*x3, x3), tuple, (x2*x3, x3)), (+(x1, x2, x2*x3, x3), +, (x1, x2, x2*x3, x3))] +else + @test pres == [(tuple, (x1, x2)), + (*, (x2, x3)), + (Base.mul_int, (x2, x3)), + (tuple, (x2*x3, x3)), + (Core._apply, (+, (x1, x2), (x2*x3, x3)))] + @test posts == [((x1, x2), tuple, (x1, x2)), + (Base.mul_int(x2, x3), Base.mul_int, (x2, x3)), + (*(x2, x3), *, (x2, x3)), + ((x2*x3, x3), tuple, (x2*x3, x3)), + (+(x1, x2, x2*x3, x3), Core._apply, (+, (x1, x2), (x2*x3, x3)))] end println("done (took ", time() - before_time, " seconds)")