From c4c6e48e75a586925ae9f35b095bc7e6e5ce3278 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Mon, 2 Sep 2024 20:01:57 -0500 Subject: [PATCH 01/87] Don't consider size of memset to be writing --- src/compiler.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 041203c37b3..65cd00a10d0 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -6369,6 +6369,13 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if isa(user, LLVM.CallInst) called = LLVM.called_operand(user) if isa(called, LLVM.Function) + intr = LLVM.API.LLVMGetIntrinsicID(called) + if intr == LLVM.Intrinsic("llvm.memset").id + if cur != operands(user)[1] + continue + end + end + nm = LLVM.name(called) if nm == "ijl_alloc_array_1d" || nm == "jl_alloc_array_1d" || nm == "ijl_alloc_array_2d" || nm == "jl_alloc_array_2d" || From 545f3d138c81d50045d13ce862da013859bc56e3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 2 Sep 2024 20:21:25 -0500 Subject: [PATCH 02/87] Fix select index offset (#1778) --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 65cd00a10d0..020a2d0b59b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2517,7 +2517,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err else shadowres = LLVM.UndefValue(value_type(lhs)) for idx in 1:width - shadowres = insert_value!(prevbb, shadowres, select!(prevbb, new_from_original(gutils, operands(cur)[1]), extract_value!(prevbb, lhs, idx), extract_value!(prevbb, rhs, idx)), idx) + shadowres = insert_value!(prevbb, shadowres, select!(prevbb, new_from_original(gutils, operands(cur)[1]), extract_value!(prevbb, lhs, idx-1), extract_value!(prevbb, rhs, idx-1)), idx-1) if isa(shadowres, LLVM.Instruction) push!(created, shadowres) end From 9db8a4b4b7f1da5ce1cfb3a33cbe492690f5a6d2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 3 Sep 2024 07:16:27 -0500 
Subject: [PATCH 03/87] Update Project.toml --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 82196219d29..8229dfc1709 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.35" +version = "0.12.36" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -35,7 +35,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.8" -Enzyme_jll = "0.0.146" +Enzyme_jll = "0.0.147" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" From 54fdd094e4615c1085c1e46ea40df85baae4250c Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 3 Sep 2024 10:52:34 -0500 Subject: [PATCH 04/87] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8229dfc1709..02636694ba3 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.8" -Enzyme_jll = "0.0.147" +Enzyme_jll = "0.0.146" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" From f520dd4912f3becc2c5286e7ebd9cc3a9fd2f869 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 3 Sep 2024 13:07:49 -0500 Subject: [PATCH 05/87] Add forward mode svec_ref (#1782) --- src/rules/llvmrules.jl | 2 +- src/rules/typeunstablerules.jl | 68 +++++++++++++++++++++++++++------- 2 files changed, 56 insertions(+), 14 deletions(-) diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 31b2454a1be..9c4feb126cb 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -149,7 +149,7 @@ end if in(name, ("ijl_f__apply_iterate", "jl_f__apply_iterate")) return common_apply_iterate_augfwd(2, B, orig, gutils, normalR, shadowR, 
tapeR) end - if in(name, ("ijl_f__svec_rev", "jl_f__svec_ref")) + if in(name, ("ijl_f__svec_ref", "jl_f__svec_ref")) return common_f_svec_ref_augfwd(2, B, orig, gutils, normalR, shadowR, tapeR) end if in(name, ("ijl_f_finalizer", "jl_f_finalizer")) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 42fcbe18cff..5372a677265 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -1330,27 +1330,69 @@ end common_setfield_rev(1, B, orig, gutils, tape) end - +function error_if_differentiable(::Type{T}) where T + seen = () + areg = active_reg_inner(T, seen, nothing) + if areg != AnyState + throw(AssertionError("Found unhandled differentiable variable in jl_f_svec_ref $T")) + end + nothing +end function common_f_svec_ref_fwd(offset, B, orig, gutils, normalR, shadowR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - emit_error(B, orig, "Enzyme: unhandled forward for jl_f__svec_ref") - normal = (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing - if shadowR != C_NULL && normal !== nothing - unsafe_store!(shadowR, normal.ref) + + width = get_width(gutils) + + origmi, origh, origkey = operands(orig)[offset:end-1] + + shadowh = invert_pointer(gutils, origh, B) + + newvals = API.CValueType[API.VT_Primal, API.VT_Shadow, API.VT_Primal] + + if offset != 1 + pushfirst!(newvals, API.VT_Primal) end - return false -end + + mi = new_from_original(gutils, origmi) -function error_if_differentiable(::Type{T}) where T - seen = () - areg = active_reg_inner(T, seen, nothing, #=justActive=#Val(true)) - if areg != AnyState - throw(AssertionError("Found unhandled differentiable variable in jl_f_svec_ref $T")) + mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) + + shadowres = if width == 1 + newops = LLVM.Value[mi, shadowh, new_from_original(gutils, origkey)] + if offset != 1 + pushfirst!(newops, operands(orig)[1]) + end + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + callconv!(cal, callconv(orig)) + + if is_constant_value(gutils, origh) + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, error_if_differentiable), emit_jltypeof!(B, cal)]) + end + cal + else + ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) + shadow = LLVM.UndefValue(ST) + for j in 1:width + newops = LLVM.Value[mi, extract_value!(B, shadowh, j-1), new_from_original(gutils, origkey)] + if offset != 1 + pushfirst!(newops, operands(orig)[1]) + end + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + callconv!(cal, callconv(orig)) + if is_constant_value(gutils, origh) + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, error_if_differentiable), emit_jltypeof!(B, cal)]) + end + shadow = insert_value!(B, shadow, cal, j-1) + end + shadow end - nothing + + unsafe_store!(shadowR, shadowres.ref) + + return false end function common_f_svec_ref_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) From 
d883c6ac18204bacf3edd275be51f160b5af18f0 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 3 Sep 2024 17:41:21 -0500 Subject: [PATCH 06/87] Do write barrier after mixed duplicated allocation upgrade (#1785) * Do write barrier after mixed duplicated allocation upgrade * Add dump post wrap option --- src/compiler.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 020a2d0b59b..6d3161b25f1 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4002,6 +4002,7 @@ include("rules/activityrules.jl") @inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: BatchDuplicatedNoNeed = API.DFT_DUP_NONEED const DumpPreEnzyme = Ref(false) +const DumpPostWrap = Ref(false) function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wrap, modifiedBetween, returnPrimal, expectedTapeType, loweredArgs, boxedArgs) if DumpPreEnzyme[] @@ -4189,6 +4190,9 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr else @assert "Unhandled derivative mode", mode end + if DumpPostWrap[] + API.EnzymeDumpModuleRef(mod.ref) + end API.EnzymeLogicErasePreprocessedFunctions(logic) adjointfname = adjointf == nothing ? nothing : LLVM.name(adjointf) augmented_primalfname = augmented_primalf == nothing ? 
nothing : LLVM.name(augmented_primalf) @@ -4495,9 +4499,10 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, convty = convert(LLVMType, T′; allow_boxed=true) if (T <: MixedDuplicated || T <: BatchMixedDuplicated) && !isboxed # && (isa(llty, LLVM.ArrayType) || isa(llty, LLVM.StructType)) - al = emit_allocobj!(builder, Base.RefValue{T′}) + al0 = al = emit_allocobj!(builder, Base.RefValue{T′}) al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al)))) store!(builder, params[i], al) + emit_writebarrier!(builder, get_julia_inner_types(builder, al0, params[i])) al = addrspacecast!(builder, al, LLVM.PointerType(llty, Derived)) push!(realparms, al) else From d7cb43bb53011c768f1b7f68859dce7399cdaef2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 3 Sep 2024 17:41:41 -0500 Subject: [PATCH 07/87] Handle batch closures (#1784) --- src/compiler.jl | 4 +++- test/runtests.jl | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 6d3161b25f1..a9bcaddf53e 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -6872,8 +6872,10 @@ end argexpr = :(fn.dval) if isboxed push!(types, Any) - else + elseif width == 1 push!(types, F) + else + push!(types, NTuple{width, F}) end push!(ccexprs, argexpr) end diff --git a/test/runtests.jl b/test/runtests.jl index a869443a092..250fa20c975 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -313,6 +313,23 @@ end BatchDuplicated(ones(3), (ones(3), ones(3)))) end +struct MyClosure{A} + a::A +end + +function (mc::MyClosure)(x) + # computes x^2 using internal storage + mc.a[1] = x + return mc.a[1]^2 +end + +@testset "Batch Closure" begin + g = MyClosure([0.0]) + g_and_dgs = BatchDuplicated(g, (make_zero(g), make_zero(g))) + x_and_dxs = BatchDuplicated(3.0, (5.0, 7.0)) + autodiff(Forward, g_and_dgs, BatchDuplicated, x_and_dxs) # error +end + # @testset "Split Tape" begin # f(x) = x[1] * x[1] From 
75a2f4c9fd072e4e6a11c52058c757aee9620a95 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 3 Sep 2024 22:37:23 -0500 Subject: [PATCH 08/87] Improve zero-set location error (#1788) --- src/rules/llvmrules.jl | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 9c4feb126cb..fb93016063b 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -327,7 +327,12 @@ end elSize = LLVM.zext!(B, elSize, LLVM.IntType(8*sizeof(Csize_t))) len = get_array_len(B, shadowin) length = LLVM.mul!(B, len, elSize) - GPUCompiler.@safe_warn "TODO forward zero-set of arraycopy used memset rather than runtime type" + bt = GPUCompiler.backtrace(orig) + btstr = sprint() do io + print(io,"\nCaused by:") + Base.show_backtrace(io, bt) + end + GPUCompiler.@safe_warn "TODO forward zero-set of arraycopy used memset rather than runtime type $btstr" LLVM.memset!(B, get_array_data(B, shadowres), LLVM.ConstantInt(i8, 0, false), length, algn) end if API.runtimeActivity() @@ -345,7 +350,12 @@ end elSize = LLVM.zext!(B, elSize, LLVM.IntType(8*sizeof(Csize_t))) len = get_array_len(B, ev) length = LLVM.mul!(B, len, elSize) - GPUCompiler.@safe_warn "TODO forward zero-set of arraycopy used memset rather than runtime type" + bt = GPUCompiler.backtrace(orig) + btstr = sprint() do io + print(io,"\nCaused by:") + Base.show_backtrace(io, bt) + end + GPUCompiler.@safe_warn "TODO forward zero-set of arraycopy used memset rather than runtime type $btstr" LLVM.memset!(B, get_array_data(B, callv), LLVM.ConstantInt(i8, 0, false), length, algn) end if API.runtimeActivity() From cdc790d0d25d2938fd520be2645dca6bbf33d711 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 3 Sep 2024 22:54:24 -0500 Subject: [PATCH 09/87] CompatHelper: bump compat for Enzyme_jll to 0.0.148, (keep existing compat) (#1787) Co-authored-by: CompatHelper Julia --- Project.toml | 30 
+++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 02636694ba3..b84ba12996d 100644 --- a/Project.toml +++ b/Project.toml @@ -16,26 +16,12 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -[weakdeps] -BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - -[extensions] -EnzymeBFloat16sExt = "BFloat16s" -EnzymeChainRulesCoreExt = "ChainRulesCore" -EnzymeLogExpFunctionsExt = "LogExpFunctions" -EnzymeSpecialFunctionsExt = "SpecialFunctions" -EnzymeStaticArraysExt = "StaticArrays" - [compat] BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.8" -Enzyme_jll = "0.0.146" +Enzyme_jll = "0.0.146, 0.0.148" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" @@ -45,9 +31,23 @@ SpecialFunctions = "1, 2" StaticArrays = "1" julia = "1.6" +[extensions] +EnzymeBFloat16sExt = "BFloat16s" +EnzymeChainRulesCoreExt = "ChainRulesCore" +EnzymeLogExpFunctionsExt = "LogExpFunctions" +EnzymeSpecialFunctionsExt = "SpecialFunctions" +EnzymeStaticArraysExt = "StaticArrays" + [extras] BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[weakdeps] +BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +StaticArrays = 
"90137ffa-7385-5640-81b9-e52037218182" From 8dde0343d7b4fdc92ecf4fe4fafd7e7df7cf1427 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 3 Sep 2024 23:38:08 -0500 Subject: [PATCH 10/87] Add names to object emission (#1789) --- src/compiler.jl | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index a9bcaddf53e..ad2d943c213 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -748,7 +748,7 @@ declare_allocobj!(mod) = get_function!(mod, "julia.gc_alloc_obj") do LLVM.FunctionType(T_prjlvalue, [T_ppjlvalue, T_size_t, T_prjlvalue]) end end -function emit_allocobj!(B, tag::LLVM.Value, Size::LLVM.Value, needs_workaround::Bool) +function emit_allocobj!(B, tag::LLVM.Value, Size::LLVM.Value, needs_workaround::Bool, name::String="") curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -792,12 +792,12 @@ function emit_allocobj!(B, tag::LLVM.Value, Size::LLVM.Value, needs_workaround:: alloc_obj, alty = declare_allocobj!(mod) @static if VERSION < v"1.8.0" - return call!(B, alty, alloc_obj, [ptls, Size, tag]) + return call!(B, alty, alloc_obj, [ptls, Size, tag], name) else - return call!(B, alty, alloc_obj, [ct, Size, tag]) + return call!(B, alty, alloc_obj, [ct, Size, tag], name) end end -function emit_allocobj!(B, T::DataType) +function emit_allocobj!(B, T::DataType, name::String="") curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -811,7 +811,7 @@ function emit_allocobj!(B, T::DataType) T_size_t = convert(LLVM.LLVMType, UInt) Size = LLVM.ConstantInt(T_size_t, sizeof(T)) - emit_allocobj!(B, tag, Size, #=needs_workaround=#false) + emit_allocobj!(B, tag, Size, #=needs_workaround=#false, name) end declare_pointerfromobjref!(mod) = get_function!(mod, "julia.pointer_from_objref") do T_jlvalue = LLVM.StructType(LLVMType[]) @@ -4499,7 +4499,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, convty = convert(LLVMType, T′; 
allow_boxed=true) if (T <: MixedDuplicated || T <: BatchMixedDuplicated) && !isboxed # && (isa(llty, LLVM.ArrayType) || isa(llty, LLVM.StructType)) - al0 = al = emit_allocobj!(builder, Base.RefValue{T′}) + al0 = al = emit_allocobj!(builder, Base.RefValue{T′}, "mixedparameter") al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al)))) store!(builder, params[i], al) emit_writebarrier!(builder, get_julia_inner_types(builder, al0, params[i])) @@ -4649,7 +4649,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, T_prjlvalue))) for idx in 1:width pv = (width == 1) ? eval : extract_value!(builder, eval, idx-1) - al0 = al = emit_allocobj!(builder, Base.RefValue{eltype(rettype)}) + al0 = al = emit_allocobj!(builder, Base.RefValue{eltype(rettype)}, "batchmixedret") llty = value_type(pv) al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al)))) store!(builder, pv, al) @@ -5236,9 +5236,10 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function if arg.arg_i in loweredArgs push!(nops, load!(builder, convert(LLVMType, arg.typ), parm)) elseif arg.arg_i in raisedArgs - obj = emit_allocobj!(builder, arg.typ) + obj = emit_allocobj!(builder, arg.typ, "raisedArg") bc = bitcast!(builder, obj, LLVM.PointerType(value_type(parm), addrspace(value_type(obj)))) store!(builder, parm, bc) + emit_writebarrier!(builder, get_julia_inner_types(builder, obj, parm)) addr = addrspacecast!(builder, bc, LLVM.PointerType(value_type(parm), Derived)) push!(nops, addr) else @@ -5374,7 +5375,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function ret!(builder, fill_val) else nobj = if sretPtr !== nothing - obj = emit_allocobj!(builder, jlrettype) + obj = emit_allocobj!(builder, jlrettype, "boxunion") llty = convert(LLVMType, jlrettype) ld = load!(builder, llty, bitcast!(builder, sretPtr, LLVM.PointerType(llty, 
addrspace(value_type(sretPtr))))) store!(builder, ld, bitcast!(builder, obj, LLVM.PointerType(llty, addrspace(value_type(obj))))) From a5ec75f2a9000d90a107a4b37f11d50f8f13671e Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 5 Sep 2024 23:01:19 -0500 Subject: [PATCH 11/87] Consider constant fp in runtime activity (#1797) * Consider constant fp in runtime activity * fix --- src/compiler.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index ad2d943c213..94a012bd2c8 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2459,6 +2459,9 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err return make_batched(ncur, prevbb) end end + if isa(cur, LLVM.ConstantFP) + return make_batched(ConstantFP(value_type(cur), 0), prevbb) + end if isa(cur, LLVM.ConstantDataSequential) cvals = LLVM.Value[] changed = false From 0307b78de83cff587be7f098afc016db1f5a6451 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Fri, 6 Sep 2024 00:02:46 -0400 Subject: [PATCH 12/87] Suggest workaround in error for overwritten active by ref (#1791) --- src/rules/customrules.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 749ec36cfde..0f77d37a9d3 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -136,7 +136,12 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)]) if value_type(val) != eltype(value_type(ptr)) if overwritten[end] - emit_error(B, orig, "Enzyme: active by ref type $Ty is overwritten in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))") + emit_error( + B, + orig, + "Enzyme: active by ref type $Ty is overwritten in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr)). 
" + * "As a workaround until support for this is added, try passing values as separate arguments rather than as an aggregate of type $Ty.", + ) end if arty == eltype(value_type(val)) val = load!(B, arty, val) From b91fb0798532912bd7d666ebf5666a4769267e08 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 5 Sep 2024 23:08:46 -0500 Subject: [PATCH 13/87] Fix custom active reverse mode check (#1798) --- src/rules/customrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 0f77d37a9d3..96286239876 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -1052,7 +1052,7 @@ end idx+=1 end else - Tys = (A <: Active ? eltype(A) : Nothing for A in activity[2+isKWCall:end]) + Tys = (A <: Active ? (width == 1 ? eltype(A) : NTuple{Int(width), eltype(A)}) : Nothing for A in activity[2+isKWCall:end]) ST = Tuple{Tys...} if rev_RT != ST emit_error(B, orig, "Enzyme: Reverse pass custom rule " * string(rev_TT) * " return type mismatch, expected "*string(ST)*" found "* string(rev_RT)) From 754937bacb860d6235c9d3ea86104649b838c5ff Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 6 Sep 2024 10:50:25 -0500 Subject: [PATCH 14/87] Look for more writebarrier opportunities (#1800) * Look for more writebarrier opportunities * Update compiler.jl --- src/compiler.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 94a012bd2c8..c5684c2e7e7 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4505,7 +4505,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, al0 = al = emit_allocobj!(builder, Base.RefValue{T′}, "mixedparameter") al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al)))) store!(builder, params[i], al) - emit_writebarrier!(builder, get_julia_inner_types(builder, al0, params[i])) + emit_writebarrier!(builder, get_julia_inner_types(builder, al0, params[i])) al = 
addrspacecast!(builder, al, LLVM.PointerType(llty, Derived)) push!(realparms, al) else @@ -5382,6 +5382,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function llty = convert(LLVMType, jlrettype) ld = load!(builder, llty, bitcast!(builder, sretPtr, LLVM.PointerType(llty, addrspace(value_type(sretPtr))))) store!(builder, ld, bitcast!(builder, obj, LLVM.PointerType(llty, addrspace(value_type(obj))))) + emit_writebarrier!(builder, get_julia_inner_types(builder, obj, ld)) # memcpy!(builder, bitcast!(builder, obj, LLVM.PointerType(T_int8, addrspace(value_type(obj)))), 0, bitcast!(builder, sretPtr, LLVM.PointerType(T_int8)), 0, LLVM.ConstantInt(T_int64, sizeof(jlrettype))) obj else From 14851efd29d85a5c0775ff14a409aadb3f4cf4f2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 12 Sep 2024 09:04:02 -0400 Subject: [PATCH 15/87] Restrict version to 1.10+ (#1809) * Restrict version to 1.10+ * fix * fixup * Update CI.yml * Update Project.toml * Update Project.toml --- .github/workflows/CI.yml | 45 +--- Project.toml | 6 +- lib/EnzymeTestUtils/Project.toml | 2 +- src/Enzyme.jl | 56 ++--- src/compiler.jl | 348 +++++++------------------------ src/compiler/interpreter.jl | 100 +-------- src/compiler/orcv1.jl | 181 ---------------- src/compiler/orcv2.jl | 9 +- src/compiler/reflection.jl | 29 +-- src/compiler/utils.jl | 67 ------ src/compiler/validation.jl | 67 ++---- src/internal_rules.jl | 16 +- src/rules/jitrules.jl | 14 -- src/rules/parallelrules.jl | 28 +-- src/typetree.jl | 12 +- src/utils.jl | 106 +--------- test/DiffTests.jl | 18 -- test/applyiter.jl | 2 - test/internal_rules.jl | 6 +- test/mixed.jl | 4 - test/rrules.jl | 3 - test/runtests.jl | 169 ++++++--------- 22 files changed, 184 insertions(+), 1104 deletions(-) delete mode 100644 src/compiler/orcv1.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 5093cf3a5a5..60d713c5297 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -21,10 +21,6 @@ 
jobs: fail-fast: false matrix: version: - - '1.6' - - '1.7' - - '1.8' - - '1.9' - '1.10' - ~1.11.0-0 - 'nightly' @@ -42,46 +38,11 @@ jobs: arch: x64 libEnzyme: local include: - - os: ubuntu-20.04 - arch: x86 - libEnzyme: packaged - version: '1.6' - assertions: false - - os: ubuntu-20.04 - arch: x86 - libEnzyme: packaged - version: '1.7' - assertions: false - - os: ubuntu-20.04 - arch: x86 - libEnzyme: packaged - version: '1.8' - assertions: false - - os: ubuntu-20.04 - arch: x86 - libEnzyme: packaged - version: '1.9' - assertions: false - os: ubuntu-20.04 arch: x86 libEnzyme: packaged version: '1.10' assertions: false - - os: ubuntu-20.04 - arch: x64 - libEnzyme: packaged - version: '1.7' - assertions: true - - os: ubuntu-20.04 - arch: x64 - libEnzyme: packaged - version: '1.8' - assertions: true - - os: ubuntu-20.04 - arch: x64 - libEnzyme: packaged - version: '1.9' - assertions: true - os: ubuntu-20.04 arch: x64 libEnzyme: packaged @@ -125,7 +86,8 @@ jobs: shell: julia --color=yes --project=. 
{0} run: | using Pkg - Pkg.develop(path="lib/EnzymeCore") + Pkg.develop([PackageSpec(; path) for path in ("lib/EnzymeCore", "lib/EnzymeTestUtils")]) + Pkg.instantiate() env: JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager - name: Build libEnzyme @@ -172,9 +134,6 @@ jobs: fail-fast: false matrix: version: - - '1.7' - - '1.8' - - '1.9' - '1.10' - ~1.11.0-0 - 'nightly' diff --git a/Project.toml b/Project.toml index b84ba12996d..15890547e1b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.36" +version = "0.13.0" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -23,13 +23,13 @@ ChainRulesCore = "1" EnzymeCore = "0.7.8" Enzyme_jll = "0.0.146, 0.0.148" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" -LLVM = "6.1, 7, 8, 9" +LLVM = "6.1, 7, 8, =9.0" LogExpFunctions = "0.3" ObjectFile = "0.4" Preferences = "1.4" SpecialFunctions = "1, 2" StaticArrays = "1" -julia = "1.6" +julia = "1.10" [extensions] EnzymeBFloat16sExt = "BFloat16s" diff --git a/lib/EnzymeTestUtils/Project.toml b/lib/EnzymeTestUtils/Project.toml index 38b783facc5..80dd2ede758 100644 --- a/lib/EnzymeTestUtils/Project.toml +++ b/lib/EnzymeTestUtils/Project.toml @@ -13,7 +13,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] ConstructionBase = "1.4.1" -Enzyme = "0.11, 0.12" +Enzyme = "0.11, 0.12, 0.13" EnzymeCore = "0.5, 0.6, 0.7" FiniteDifferences = "0.12.12" MetaTesting = "0.1" diff --git a/src/Enzyme.jl b/src/Enzyme.jl index b7d86b47057..450d96ffb01 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -249,11 +249,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) end rt = if A isa UnionAll - @static if VERSION >= v"1.8.0" - Compiler.primal_return_type(rmode, Val(codegen_world_age(FTy, tt)), FTy, tt) - else - Core.Compiler.return_type(f.val, tt) - end + Compiler.primal_return_type(rmode, Val(codegen_world_age(FTy, tt)), FTy, tt) else 
eltype(A) end @@ -339,7 +335,7 @@ Like [`autodiff`](@ref) but will try to guess the activity of the return value. """ @inline function autodiff(mode::CMode, f::FA, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, CMode<:Mode, Nargs} tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} - rt = if mode isa ReverseMode && VERSION >= v"1.8.0" + rt = if mode isa ReverseMode Compiler.primal_return_type(mode, Val(codegen_world_age(eltype(FA), tt)), eltype(FA), tt) else Core.Compiler.return_type(f.val, tt) @@ -556,7 +552,7 @@ Like [`autodiff_deferred`](@ref) but will try to guess the activity of the retur @inline function autodiff_deferred(mode::M, f::FA, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, M<:Mode, Nargs} tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} - rt = if mode isa ReverseMode && VERSION >= v"1.8.0" + rt = if mode isa ReverseMode Compiler.primal_return_type(mode, Val(codegen_world_age(eltype(FA), tt)), eltype(FA), tt) else Core.Compiler.return_type(f.val, tt) @@ -903,11 +899,7 @@ result, ∂v, ∂A end rt = if RT isa UnionAll - @static if VERSION < v"1.8-" - throw(MethodError(autodiff_deferred_thunk, (mode, tt, fa, a2, args...))) - else - RT{Core.Compiler.return_type(Tuple{eltype(FA), map(eltype, args)...})} - end + RT{Core.Compiler.return_type(Tuple{eltype(FA), map(eltype, args)...})} else @assert RT isa DataType RT @@ -1243,13 +1235,9 @@ of shape `size(input)` of values of the output type. 
inshape = size(x) outshape = size(cols[1]) # st : outshape x total inputs - st = @static if VERSION >= v"1.9" - Base.stack(cols) - else - reshape(cat(cols..., dims=length(outshape)), (outshape..., inshape...)) - end + st = Base.stack(cols) - st3 = if length(inshape) <= 1 || VERSION < v"1.9" + st3 = if length(inshape) <= 1 st else reshape(st, (outshape..., inshape...)) @@ -1279,13 +1267,9 @@ end inshape = size(x) outshape = size(cols[1]) # st : outshape x total inputs - st = @static if VERSION >= v"1.9" - Base.stack(cols) - else - reshape(cat(cols..., dims=length(outshape)), (outshape..., inshape...)) - end + st = Base.stack(cols) - st3 = if length(inshape) <= 1 || VERSION < v"1.9" + st3 = if length(inshape) <= 1 st else reshape(st, (outshape..., inshape...)) @@ -1311,13 +1295,9 @@ end inshape = size(x) outshape = size(cols[1]) # st : outshape x total inputs - st = @static if VERSION >= v"1.9" - Base.stack(cols) - else - reshape(cat(cols..., dims=length(outshape)), (outshape..., inshape...)) - end + st = Base.stack(cols) - st3 = if length(inshape) <= 1 || VERSION < v"1.9" + st3 = if length(inshape) <= 1 st else reshape(st, (outshape..., inshape...)) @@ -1416,13 +1396,9 @@ of shape `size(output)` of values of the input type. 
if x isa AbstractArray inshape = size(x) - st = @static if VERSION >= v"1.9" - Base.stack(rows) - else - reshape(cat(rows..., dims=length(inshape)), (inshape..., outshape...)) - end + st = Base.stack(rows) - st2 = if length(outshape) == 1 || VERSION < v"1.9" + st2 = if length(outshape) == 1 st else reshape(st, (inshape..., outshape...)) @@ -1469,13 +1445,9 @@ end outshape = tmp[1][2] if x isa AbstractArray inshape = size(x) - st = @static if VERSION >= v"1.9" - Base.stack(rows) - else - reshape(cat(rows..., dims=length(inshape)), (inshape..., outshape...)) - end + st = Base.stack(rows) - st2 = if length(outshape) == 1 || VERSION < v"1.9" + st2 = if length(outshape) == 1 st else reshape(st, (inshape..., outshape...)) diff --git a/src/compiler.jl b/src/compiler.jl index c5684c2e7e7..f82ae6c1358 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -31,27 +31,14 @@ function cpu_name() end function cpu_features() - if VERSION >= v"1.10.0-beta1" - return ccall(:jl_get_cpu_features, String, ()) - end - - @static if Sys.ARCH == :x86_64 || - Sys.ARCH == :x86 - return "+mmx,+sse,+sse2,+fxsr,+cx8" # mandated by Julia - else - return "" - end + return ccall(:jl_get_cpu_features, String, ()) end import GPUCompiler: @safe_debug, @safe_info, @safe_warn, @safe_error include("compiler/utils.jl") -if v"8" <= LLVM.version() < v"12" - include("compiler/orcv1.jl") -else - include("compiler/orcv2.jl") -end +include("compiler/orcv2.jl") include("gradientutils.jl") @@ -97,11 +84,9 @@ Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( typeof(Base.FastMath.cosh_fast) => (:cosh, 1, nothing), typeof(Base.tanh) => (:tanh, 1, nothing), typeof(Base.ldexp) => (:ldexp, 2, nothing), - typeof(Base.FastMath.tanh_fast) => (:tanh, 1, nothing) + typeof(Base.FastMath.tanh_fast) => (:tanh, 1, nothing), + typeof(Base.fma_emulated) => (:fma, 3, nothing) ) -@static if VERSION >= v"1.8.0" - known_ops[typeof(Base.fma_emulated)] = (:fma, 3, nothing) -end @inline function 
find_math_method(@nospecialize(func), sparam_vals) if func ∈ keys(known_ops) name, arity, toinject = known_ops[func] @@ -425,39 +410,7 @@ end @inline is_arrayorvararg_ty(::Type{IdDict{K, V} where K}) where {V} = true @inline function datatype_fieldcount(t::Type{T}) where T - @static if VERSION < v"1.10.0" - NT = @static if VERSION < v"1.9.0" - Base.NamedTuple_typename - else - Base._NAMEDTUPLE_NAME - end - if t.name === NT - names, types = t.parameters[1], t.parameters[2] - if names isa Tuple - return length(names) - end - if types isa DataType && types <: Tuple - return datatype_fieldcount(types) - end - return nothing - else - @static if VERSION < v"1.7.0" - if t.abstract || (t.name === Tuple.name && Base.isvatuple(t)) - return nothing - end - else - if isabstracttype(t) || (t.name === Tuple.name && Base.isvatuple(t)) - return nothing - end - end - end - if isdefined(t, :types) - return length(t.types) - end - return length(t.name.names) - else - return Base.datatype_fieldcount(t) - end + return Base.datatype_fieldcount(t) end @inline function staticInTup(::Val{T}, tup::NTuple{N, Val}) where {T, N} @@ -608,24 +561,20 @@ end throw(AssertionError("Type $T is not concrete type or concrete tuple")) end - @static if VERSION < v"1.7.0" - nT = T + nT = if T <: Tuple && T != Tuple && !(T isa UnionAll) + Tuple{(ntuple(length(T.parameters)) do i + Base.@_inline_meta + sT = T.parameters[i] + if sT isa TypeVar + Any + elseif sT isa Core.TypeofVararg + Any + else + sT + end + end)...} else - nT = if T <: Tuple && T != Tuple && !(T isa UnionAll) - Tuple{(ntuple(length(T.parameters)) do i - Base.@_inline_meta - sT = T.parameters[i] - if sT isa TypeVar - Any - elseif sT isa Core.TypeofVararg - Any - else - sT - end - end)...} - else - T - end + T end if staticInTup(Val(nT), seen) @@ -740,13 +689,8 @@ declare_allocobj!(mod) = get_function!(mod, "julia.gc_alloc_obj") do T_ppjlvalue = LLVM.PointerType(LLVM.PointerType(T_jlvalue)) T_size_t = convert(LLVM.LLVMType, Int) - @static if 
VERSION < v"1.8.0" - T_int8 = LLVM.Int8Type() - T_pint8 = LLVM.PointerType(T_int8) - LLVM.FunctionType(T_prjlvalue, [T_pint8, T_size_t, T_prjlvalue]) - else - LLVM.FunctionType(T_prjlvalue, [T_ppjlvalue, T_size_t, T_prjlvalue]) - end + + LLVM.FunctionType(T_prjlvalue, [T_ppjlvalue, T_size_t, T_prjlvalue]) end function emit_allocobj!(B, tag::LLVM.Value, Size::LLVM.Value, needs_workaround::Bool, name::String="") curent_bb = position(B) @@ -760,21 +704,16 @@ function emit_allocobj!(B, tag::LLVM.Value, Size::LLVM.Value, needs_workaround:: T_int8 = LLVM.Int8Type() T_pint8 = LLVM.PointerType(T_int8) - @static if VERSION < v"1.7.0" - ptls = reinsert_gcmarker!(fn, B) - ptls = bitcast!(B, ptls, T_pint8) - else - pgcstack = reinsert_gcmarker!(fn, B) - ct = inbounds_gep!(B, - T_pjlvalue, - bitcast!(B, pgcstack, T_ppjlvalue), - [LLVM.ConstantInt(current_task_offset())]) - ptls_field = inbounds_gep!(B, - T_pjlvalue, - ct, [LLVM.ConstantInt(current_ptls_offset())]) - T_ppint8 = LLVM.PointerType(T_pint8) - ptls = load!(B, T_pint8, bitcast!(B, ptls_field, T_ppint8)) - end + pgcstack = reinsert_gcmarker!(fn, B) + ct = inbounds_gep!(B, + T_pjlvalue, + bitcast!(B, pgcstack, T_ppjlvalue), + [LLVM.ConstantInt(current_task_offset())]) + ptls_field = inbounds_gep!(B, + T_pjlvalue, + ct, [LLVM.ConstantInt(current_ptls_offset())]) + T_ppint8 = LLVM.PointerType(T_pint8) + ptls = load!(B, T_pint8, bitcast!(B, ptls_field, T_ppint8)) if needs_workaround T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) @@ -791,11 +730,7 @@ function emit_allocobj!(B, tag::LLVM.Value, Size::LLVM.Value, needs_workaround:: alloc_obj, alty = declare_allocobj!(mod) - @static if VERSION < v"1.8.0" - return call!(B, alty, alloc_obj, [ptls, Size, tag], name) - else - return call!(B, alty, alloc_obj, [ct, Size, tag], name) - end + return call!(B, alty, alloc_obj, [ct, Size, tag], name) end function emit_allocobj!(B, T::DataType, name::String="") curent_bb = position(B) @@ -832,19 +767,11 @@ declare_writebarrier!(mod) = 
get_function!(mod, "julia.write_barrier") do T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) LLVM.FunctionType(LLVM.VoidType(), [T_prjlvalue]; vararg=true) end -@static if VERSION < v"1.8.0" -declare_apply_generic!(mod) = get_function!(mod, "jl_apply_generic") do - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, LLVM.PointerType(T_prjlvalue), LLVM.Int32Type()]) -end -else declare_apply_generic!(mod) = get_function!(mod, "ijl_apply_generic") do T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, LLVM.PointerType(T_prjlvalue), LLVM.Int32Type()]) end -end declare_juliacall!(mod) = get_function!(mod, "julia.call") do T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) @@ -877,17 +804,10 @@ function emit_getfield!(B::LLVM.IRBuilder, val::LLVM.Value, fld::LLVM.Value)::LL args = [val, fld] - @static if VERSION < v"1.9.0-" - FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue]; vararg=true) - inv = bitcast!(B, inv, LLVM.PointerType(FT)) - res = call!(B, FT, inv, args) - LLVM.callconv!(res, 37) - else - julia_call, FT = get_function!(mod, "julia.call", - LLVM.FunctionType(T_prjlvalue, - [LLVM.PointerType(gen_FT), T_prjlvalue]; vararg=true)) - res = call!(B, FT, julia_call, LLVM.Value[inv, args...]) - end + julia_call, FT = get_function!(mod, "julia.call", + LLVM.FunctionType(T_prjlvalue, + [LLVM.PointerType(gen_FT), T_prjlvalue]; vararg=true)) + res = call!(B, FT, julia_call, LLVM.Value[inv, args...]) return res end @@ -930,11 +850,7 @@ function emit_box_int32!(B::LLVM.IRBuilder, val::LLVM.Value)::LLVM.Value T_int32 = LLVM.Int32Type() FT = LLVM.FunctionType(T_prjlvalue, [T_int32]) - @static if VERSION < v"1.8-" - box_int32, _ = get_function!(mod, "jl_box_int32", FT) - else - box_int32, _ = get_function!(mod, "ijl_box_int32", FT) - end + box_int32, _ 
= get_function!(mod, "ijl_box_int32", FT) call!(B, FT, box_int32, [val]) end @@ -948,11 +864,7 @@ function emit_box_int64!(B::LLVM.IRBuilder, val::LLVM.Value)::LLVM.Value T_int64 = LLVM.Int64Type() FT = LLVM.FunctionType(T_prjlvalue, [T_int64]) - @static if VERSION < v"1.8-" - box_int64, _ = get_function!(mod, "jl_box_int64", FT) - else - box_int64, _ = get_function!(mod, "ijl_box_int64", FT) - end + box_int64, _ = get_function!(mod, "ijl_box_int64", FT) call!(B, FT, box_int64, [val]) end @@ -967,25 +879,13 @@ function emit_apply_generic!(B::LLVM.IRBuilder, args)::LLVM.Value T_int32 = LLVM.Int32Type() gen_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_pprjlvalue, T_int32]) - @static if VERSION < v"1.8-" - inv, _ = get_function!(mod, "jl_apply_generic", gen_FT) - else - inv, _ = get_function!(mod, "ijl_apply_generic", gen_FT) - end + inv, _ = get_function!(mod, "ijl_apply_generic", gen_FT) - @static if VERSION < v"1.9.0-" - FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue]; vararg=true) - inv = bitcast!(B, inv, LLVM.PointerType(FT)) - # call cc37 nonnull {}* bitcast ({}* ({}*, {}**, i32)* @jl_f_apply_type to {}* ({}*, {}*, {}*, {}*)*)({}* null, {}* inttoptr (i64 140150176657296 to {}*), {}* %4, {}* inttoptr (i64 140149987564368 to {}*)) - res = call!(B, FT, inv, args) - LLVM.callconv!(res, 37) - else - # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) @julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) - julia_call, FT = get_function!(mod, "julia.call", - LLVM.FunctionType(T_prjlvalue, - [LLVM.PointerType(gen_FT), T_prjlvalue]; vararg=true)) - res = call!(B, FT, julia_call, LLVM.Value[inv, args...]) - end + # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) 
@julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) + julia_call, FT = get_function!(mod, "julia.call", + LLVM.FunctionType(T_prjlvalue, + [LLVM.PointerType(gen_FT), T_prjlvalue]; vararg=true)) + res = call!(B, FT, julia_call, LLVM.Value[inv, args...]) return res end @@ -1001,25 +901,13 @@ function emit_invoke!(B::LLVM.IRBuilder, args)::LLVM.Value # {} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32, {} addrspace(10)*)* @ijl_invoke gen_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_pprjlvalue, T_int32, T_prjlvalue]) - @static if VERSION < v"1.8-" - inv = get_function!(mod, "jl_invoke", gen_FT) - else - inv = get_function!(mod, "ijl_invoke", gen_FT) - end + inv = get_function!(mod, "ijl_invoke", gen_FT) - @static if VERSION < v"1.9.0-" - FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue]; vararg=true) - inv = bitcast!(B, inv, LLVM.PointerType(FT)) - # call cc37 nonnull {}* bitcast ({}* ({}*, {}**, i32)* @jl_f_apply_type to {}* ({}*, {}*, {}*, {}*)*)({}* null, {}* inttoptr (i64 140150176657296 to {}*), {}* %4, {}* inttoptr (i64 140149987564368 to {}*)) - res = call!(B, FT, inv, args) - LLVM.callconv!(res, 38) - else - # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) @julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) - julia_call, FT = get_function!(mod, "julia.call2", - LLVM.FunctionType(T_prjlvalue, - [LLVM.PointerType(generic_FT), T_prjlvalue]; vararg=true)) - res = call!(B, FT, julia_call, [inv, args...]) - end + # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) 
@julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) + julia_call, FT = get_function!(mod, "julia.call2", + LLVM.FunctionType(T_prjlvalue, + [LLVM.PointerType(generic_FT), T_prjlvalue]; vararg=true)) + res = call!(B, FT, julia_call, [inv, args...]) return res end @@ -1104,7 +992,6 @@ function Base.showerror(io::IO, ece::EnzymeNoDerivativeError) print(io, msg, '\n') end -@static if VERSION >= v"1.8.0" const JuliaEnzymeNameMap = Dict{String, Any}( "enz_val_true" => Val(true), "enz_val_false" => Val(false), @@ -1119,9 +1006,6 @@ const JuliaEnzymeNameMap = Dict{String, Any}( "enz_no_shadow_exc" => EnzymeNoShadowError, "enz_no_derivative_exc" => EnzymeNoDerivativeError, ) -else -const JuliaEnzymeNameMap = Dict{String, Any}() -end const JuliaGlobalNameMap = Dict{String, Any}( "jl_type_type" => Type, @@ -1204,17 +1088,11 @@ const JuliaGlobalNameMap = Dict{String, Any}( "jl_nothing" => nothing, "jl_anytuple_type" => Tuple, + "jl_vararg_type" => Core.TypeofVararg, + "jl_opaque_closure_type" => Core.OpaqueClosure, + "jl_array_uint64_type" => Array{UInt64, 1}, + "jl_binding_type" => Core.Binding ) -@static if VERSION >= v"1.7.0" - JuliaGlobalNameMap["jl_vararg_type"] = Core.TypeofVararg - JuliaGlobalNameMap["jl_opaque_closure_type"] = Core.OpaqueClosure -end -@static if VERSION >= v"1.8.0" - JuliaGlobalNameMap["jl_array_uint64_type"] = Array{UInt64, 1} -end -@static if VERSION >= v"1.10.0" - JuliaGlobalNameMap["jl_binding_type"] = Core.Binding -end include("absint.jl") @@ -1248,19 +1126,11 @@ function emit_apply_type!(B::LLVM.IRBuilder, Ty, args)::LLVM.Value f_apply_type, _ = get_function!(mod, "jl_f_apply_type", generic_FT) Ty = unsafe_to_llvm(B, Ty) - @static if VERSION < v"1.9.0-" - FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue]; vararg=true) - f_apply_type = bitcast!(B, f_apply_type, LLVM.PointerType(FT)) - # call cc37 nonnull {}* bitcast ({}* ({}*, {}**, 
i32)* @jl_f_apply_type to {}* ({}*, {}*, {}*, {}*)*)({}* null, {}* inttoptr (i64 140150176657296 to {}*), {}* %4, {}* inttoptr (i64 140149987564368 to {}*)) - tag = call!(B, FT, f_apply_type, LLVM.Value[LLVM.PointerNull(T_prjlvalue), Ty, args...]) - LLVM.callconv!(tag, 37) - else - # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) @julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) - julia_call, FT = get_function!(mod, "julia.call", - LLVM.FunctionType(T_prjlvalue, - [LLVM.PointerType(generic_FT), T_prjlvalue]; vararg=true)) - tag = call!(B, FT, julia_call, LLVM.Value[f_apply_type, LLVM.PointerNull(T_prjlvalue), Ty, args...]) - end + # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) @julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) + julia_call, FT = get_function!(mod, "julia.call", + LLVM.FunctionType(T_prjlvalue, + [LLVM.PointerType(generic_FT), T_prjlvalue]; vararg=true)) + tag = call!(B, FT, julia_call, LLVM.Value[f_apply_type, LLVM.PointerNull(T_prjlvalue), Ty, args...]) return tag end @@ -1293,19 +1163,11 @@ function emit_tuple!(B, args)::LLVM.Value generic_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_pprjlvalue, T_int32]) f_apply_type, _ = get_function!(mod, "jl_f_tuple", generic_FT) - @static if VERSION < v"1.9.0-" - FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue]; vararg=true) - f_apply_type = bitcast!(B, f_apply_type, LLVM.PointerType(FT)) - # call cc37 nonnull {}* bitcast ({}* ({}*, {}**, i32)* @jl_f_apply_type to {}* ({}*, {}*, {}*, {}*)*)({}* null, {}* inttoptr (i64 140150176657296 to {}*), {}* %4, {}* inttoptr (i64 140149987564368 to {}*)) - tag = call!(B, FT, f_apply_type, LLVM.Value[LLVM.PointerNull(T_prjlvalue), args...]) - LLVM.callconv!(tag, 37) - else - # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, 
{}*, ...) @julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) - julia_call, FT = get_function!(mod, "julia.call", - LLVM.FunctionType(T_prjlvalue, - [LLVM.PointerType(generic_FT), T_prjlvalue]; vararg=true)) - tag = call!(B, FT, julia_call, LLVM.Value[f_apply_type, LLVM.PointerNull(T_prjlvalue), args...]) - end + # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) @julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) + julia_call, FT = get_function!(mod, "julia.call", + LLVM.FunctionType(T_prjlvalue, + [LLVM.PointerType(generic_FT), T_prjlvalue]; vararg=true)) + tag = call!(B, FT, julia_call, LLVM.Value[f_apply_type, LLVM.PointerNull(T_prjlvalue), args...]) return tag end @@ -1361,24 +1223,15 @@ function emit_methodinstance!(B::LLVM.IRBuilder, func, args)::LLVM.Value T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - @static if VERSION < v"1.8.0-" - worlds, FT = get_function!(mod, "jl_gf_invoke_lookup_worlds", - LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, sizeT, psizeT, psizeT])) - else worlds, FT = get_function!(mod, "jl_gf_invoke_lookup_worlds", LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue, sizeT, psizeT, psizeT])) - end EB = LLVM.IRBuilder() position!(EB, first(LLVM.instructions(LLVM.entry(fn)))) minworld = alloca!(EB, sizeT) maxworld = alloca!(EB, sizeT) store!(B, LLVM.ConstantInt(sizeT, 0), minworld) store!(B, LLVM.ConstantInt(sizeT, -1), maxworld) - @static if VERSION < v"1.8.0-" - methodmatch = call!(B, FT, worlds, LLVM.Value[tag, LLVM.ConstantInt(sizeT, world), minworld, maxworld]) - else methodmatch = call!(B, FT, worlds, LLVM.Value[tag, unsafe_to_llvm(B, nothing), LLVM.ConstantInt(sizeT, world), minworld, maxworld]) - end # emit_jl!(B, methodmatch) # emit_jl!(B, emit_jltypeof!(B, methodmatch)) 
offset = 1 @@ -2849,38 +2702,10 @@ function from_tape_type(::Type{B}) where {B<:Tuple} end # See get_current_task_from_pgcstack (used from 1.7+) -if VERSION >= v"1.9.1" - current_task_offset() = -(unsafe_load(cglobal(:jl_task_gcstack_offset, Cint)) ÷ sizeof(Ptr{Cvoid})) -elseif VERSION >= v"1.9.0" - if Sys.WORD_SIZE == 64 - current_task_offset() = -13 - else - current_task_offset() = -18 - end -else - if Sys.WORD_SIZE == 64 - current_task_offset() = -12 #1.8/1.7 - else - current_task_offset() = -17 #1.8/1.7 - end -end +current_task_offset() = -(unsafe_load(cglobal(:jl_task_gcstack_offset, Cint)) ÷ sizeof(Ptr{Cvoid})) # See get_current_ptls_from_task (used from 1.7+) -if VERSION >= v"1.9.1" - current_ptls_offset() = unsafe_load(cglobal(:jl_task_ptls_offset, Cint)) ÷ sizeof(Ptr{Cvoid}) -elseif VERSION >= v"1.9.0" - if Sys.WORD_SIZE == 64 - current_ptls_offset() = 15 - else - current_ptls_offset() = 20 - end -else - if Sys.WORD_SIZE == 64 - current_ptls_offset() = 14 # 1.8/1.7 - else - current_ptls_offset() = 19 - end -end +current_ptls_offset() = unsafe_load(cglobal(:jl_task_ptls_offset, Cint)) ÷ sizeof(Ptr{Cvoid}) function store_nonjl_types!(B, startval, p) T_jlvalue = LLVM.StructType(LLVMType[]) @@ -3309,7 +3134,7 @@ function julia_allocator(B, LLVMType, Count, AlignedSize, IsDefault, ZI) # Check if Julia version has https://github.com/JuliaLang/julia/pull/46914 # and also https://github.com/JuliaLang/julia/pull/47076 # and also https://github.com/JuliaLang/julia/pull/48620 - @static if VERSION >= v"1.10.0-DEV.569" + @static if VERSION >= v"1.10.5" needs_dynamic_size_workaround = false else needs_dynamic_size_workaround = !isa(Size, LLVM.ConstantInt) || convert(Int, Size) != 1 @@ -3555,9 +3380,7 @@ function enzyme_ci_cache(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) end end -@static if VERSION < v"1.8" GPUCompiler.ci_cache(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = enzyme_ci_cache(job) -end 
GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = Interpreter.EnzymeInterpreter(enzyme_ci_cache(job), GPUCompiler.method_table(job), job.world, job.config.params.mode) @@ -5601,14 +5424,12 @@ end using Random # returns arg, return function no_type_setting(@nospecialize(specTypes); world=nothing) - @static if VERSION >= v"1.7.0-" - # Even though the julia type here is ptr{int8}, the actual data can be something else - if specTypes.parameters[1] == typeof(Random.XoshiroSimd.xoshiro_bulk_simd) - return (true, false) - end - if specTypes.parameters[1] == typeof(Random.XoshiroSimd.xoshiro_bulk_nosimd) - return (true, false) - end + # Even though the julia type here is ptr{int8}, the actual data can be something else + if specTypes.parameters[1] == typeof(Random.XoshiroSimd.xoshiro_bulk_simd) + return (true, false) + end + if specTypes.parameters[1] == typeof(Random.XoshiroSimd.xoshiro_bulk_nosimd) + return (true, false) end return (false, false) end @@ -7214,11 +7035,6 @@ function _link(job, (mod, adjoint_name, primal_name, TapeType)) # Now invoke the JIT jitted_mod = JIT.add!(mod) - #if VERSION >= v"1.9.0-DEV.115" - # LLVM.dispose(ctx) - #else - # # we cannot dispose of the global unique context - #end adjoint_addr = JIT.lookup(jitted_mod, adjoint_name) adjoint_ptr = pointer(adjoint_addr) @@ -7382,38 +7198,26 @@ end @inline function thunk(mi::Core.MethodInstance, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, ABI, ErrIfFuncWritten} ts_ctx = JuliaContext() - ctx = @static if VERSION >= v"1.9.0-DEV.115" - context(ts_ctx) - else - ts_ctx - end + ctx = context(ts_ctx) activate(ctx) try return thunkbase(ctx, mi, Val(#=World=#nothing), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), 
ABI, Val(ErrIfFuncWritten)) finally deactivate(ctx) - @static if VERSION >= v"1.9.0-DEV.115" - dispose(ts_ctx) - end + dispose(ts_ctx) end end @inline @generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI, ErrIfFuncWritten} mi = fspec(eltype(FA), TT, World) ts_ctx = JuliaContext() - ctx = @static if VERSION >= v"1.9.0-DEV.115" - context(ts_ctx) - else - ts_ctx - end + ctx = context(ts_ctx) activate(ctx) res = try thunkbase(ctx, mi, Val(World), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI, Val(ErrIfFuncWritten)) finally deactivate(ctx) - @static if VERSION >= v"1.9.0-DEV.115" - dispose(ts_ctx) - end + dispose(ts_ctx) end return quote Base.@_inline_meta diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index e1652c58956..08b42d587b3 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -52,8 +52,7 @@ function EnzymeInterpreter(cache_or_token, mt::Union{Nothing,Core.MethodTable}, # parameters for inference and optimization InferenceParams(unoptimize_throw_blocks=false), - VERSION >= v"1.8.0-DEV.486" ? OptimizationParams() : - OptimizationParams(unoptimize_throw_blocks=false), + OptimizationParams(), mode ) end @@ -82,9 +81,7 @@ Core.Compiler.may_compress(interp::EnzymeInterpreter) = true # but as far as I understand Enzyme wants "always inlining, except special cased functions", # so I guess we really don't want to discard sources? 
Core.Compiler.may_discard_trees(interp::EnzymeInterpreter) = false -if VERSION >= v"1.7.0-DEV.577" Core.Compiler.verbose_stmt_info(interp::EnzymeInterpreter) = false -end if isdefined(Base.Experimental, Symbol("@overlay")) Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) = @@ -123,21 +120,7 @@ function is_primitive_func(@nospecialize(TT)) end function isKWCallSignature(@nospecialize(TT)) - if VERSION >= v"1.9.0-DEV.1598" - return TT <: Tuple{typeof(Core.kwcall), Any, Any, Vararg} - else - if hasproperty(TT, :parameters) && length(TT.parameters) >= 3 - kwftype = TT.parameters[1] - ft = TT.parameters[3] - if ccall(:jl_argument_method_table, Any, (Any,), ft) === nothing - return false - end - if Core.kwftype(ft) == kwftype - return true - end - end - return false - end + return TT <: Tuple{typeof(Core.kwcall), Any, Any, Vararg} end function simplify_kw(specTypes) @@ -149,8 +132,6 @@ function simplify_kw(specTypes) end # https://github.com/JuliaLang/julia/pull/46965 -@static if VERSION ≥ v"1.9.0-DEV.1535" - import Core.Compiler: CallInfo function Core.Compiler.inlining_policy(interp::EnzymeInterpreter, @nospecialize(src), @nospecialize(info::CallInfo), stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) @@ -190,81 +171,4 @@ function Core.Compiler.inlining_policy(interp::EnzymeInterpreter, src::Any, info::CallInfo, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) end -# https://github.com/JuliaLang/julia/pull/41328 -elseif isdefined(Core.Compiler, :is_stmt_inline) - -function Core.Compiler.inlining_policy(interp::EnzymeInterpreter, - @nospecialize(src), stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) - - method_table = Core.Compiler.method_table(interp) - specTypes = simplify_kw(mi.specTypes) - - if is_primitive_func(specTypes) - return nothing - end - - if is_alwaysinline_func(specTypes) - @assert src !== nothing - return src - end - - if EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, 
method_table) - return nothing - end - if interp.mode == API.DEM_ForwardMode - if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table) - return nothing - end - else - if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table) - return nothing - end - end - - return Base.@invoke Core.Compiler.inlining_policy(interp::AbstractInterpreter, - src::Any, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) -end - -elseif isdefined(Core.Compiler, :inlining_policy) - -import Core.Compiler: InliningTodo, InliningState -struct EnzymeInliningPolicy - interp::EnzymeInterpreter -end -(::EnzymeInliningPolicy)(@nospecialize(src)) = Core.Compiler.default_inlining_policy(src) -Core.Compiler.inlining_policy(interp::EnzymeInterpreter) = EnzymeInliningPolicy(interp) - -function Core.Compiler.resolve_todo(todo::InliningTodo, state::InliningState{S, T, <:EnzymeInliningPolicy}) where {S<:Union{Nothing, Core.Compiler.EdgeTracker}, T} - mi = todo.mi - specTypes = simplify_kw(mi.specTypes) - - if is_primitive_func(specTypes) - return Core.Compiler.compileable_specialization(state.et, todo.spec.match) - end - - if is_alwaysinline_func(specTypes) - @assert false "Need to mark resolve_todo function as alwaysinline, but don't know how" - end - - interp = state.policy.interp - method_table = Core.Compiler.method_table(interp) - if EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table) - return Core.Compiler.compileable_specialization(state.et, todo.spec.match) - end - if interp.mode == API.DEM_ForwardMode - if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table) - return Core.Compiler.compileable_specialization(state.et, todo.spec.match) - end - else - if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table) - return Core.Compiler.compileable_specialization(state.et, todo.spec.match) - end - end - - return Base.@invoke Core.Compiler.resolve_todo( - todo::InliningTodo, 
state::InliningState) -end - -end # @static if isdefined(Core.Compiler, :is_stmt_inline) - end diff --git a/src/compiler/orcv1.jl b/src/compiler/orcv1.jl deleted file mode 100644 index bcac867e73f..00000000000 --- a/src/compiler/orcv1.jl +++ /dev/null @@ -1,181 +0,0 @@ -module JIT - -using LLVM -using Libdl -import LLVM: TargetMachine - -import GPUCompiler: CompilerJob, JuliaContext -import ..Compiler -import ..Compiler: API, cpu_name, cpu_features - -export get_trampoline - -# We have one global JIT and TM -const jit = Ref{OrcJIT}() -const tm = Ref{TargetMachine}() - -get_tm() = tm[] - -function __init__() - opt_level = Base.JLOptions().opt_level - if opt_level < 2 - optlevel = LLVM.API.LLVMCodeGenLevelNone - elseif opt_level == 2 - optlevel = LLVM.API.LLVMCodeGenLevelDefault - else - optlevel = LLVM.API.LLVMCodeGenLevelAggressive - end - - tm[] = LLVM.JITTargetMachine(LLVM.triple(), cpu_name(), cpu_features(); optlevel) - LLVM.asm_verbosity!(tm[], true) - - jit[] = OrcJIT(tm[]) # takes ownership of tm - - if haskey(ENV, "ENABLE_GDBLISTENER") - LLVM.register!(jit[], LLVM.GDBRegistrationListener()) - end - atexit() do - dispose(jit[]) - end -end - -mutable struct CallbackContext - job::CompilerJob - stub::Symbol - l_job::ReentrantLock - addr::Ptr{Cvoid} - CallbackContext(job, stub, l_job) = new(job, stub, l_job, C_NULL) -end - -const l_outstanding = Base.ReentrantLock() -const outstanding = Base.IdSet{CallbackContext}() - -# Setup the lazy callback for creating a module -function callback(orc_ref::LLVM.API.LLVMOrcJITStackRef, callback_ctx::Ptr{Cvoid}) - JuliaContext() do ctx - orc = OrcJIT(orc_ref) - cc = Base.unsafe_pointer_to_objref(callback_ctx)::CallbackContext - - # 1. Lock job - lock(cc.l_job) - - # 2. lookup if we are the first - lock(l_outstanding) - if in(cc, outstanding) - delete!(outstanding, cc) - else - unlock(l_outstanding) - unlock(cc.l_job) - - # 3. We are the second callback to run, but we raced the other one - # thus we return the addr from them. 
- @assert cc.addr != C_NULL - return UInt64(reinterpret(UInt, cc.addr)) - end - unlock(l_outstanding) - - try - thunk = Compiler._link(cc.job, Compiler._thunk(cc.job)) - mode = cc.job.config.params.mode - use_primal = mode == API.DEM_ReverseModePrimal - cc.addr = use_primal ? thunk.primal : thunk.adjoint - - # 4. Update the stub pointer to point to the recently compiled module - set_stub!(orc, string(cc.stub), cc.addr) - finally - unlock(cc.l_job) - end - - # 5. Return the address of the implementation, since we are going to call it now - @assert cc.addr != C_NULL - return UInt64(reinterpret(UInt, cc.addr)) - end -end - -function get_trampoline(job) - l_job = Base.ReentrantLock() - - cc = CallbackContext(job, gensym(:func), l_job) - lock(l_outstanding) - push!(outstanding, cc) - unlock(l_outstanding) - - c_callback = @cfunction(callback, UInt64, (LLVM.API.LLVMOrcJITStackRef, Ptr{Cvoid})) - - orc = jit[] - addr_adjoint = callback!(orc, c_callback, pointer_from_objref(cc)) - create_stub!(orc, string(cc.stub), addr_adjoint) - - return address(orc, string(cc.stub)) -end - - -function resolver(name, ctx) - name = unsafe_string(name) - ptr = try - ## Step 0: Should have already resolved it iff it was in the - ## same module - ## Step 1: See if it's something known to the execution enging - # TODO: Do we need to do this? - # address(jit[], name) - - ## Step 2: Search the program symbols - # - # SearchForAddressOfSymbol expects an unmangled 'C' symbol name. - # Iff we are on Darwin, strip the leading '_' off. 
- @static if Sys.isapple() - if name[1] == '_' - name = name[2:end] - end - end - - found = false - val = nothing - - hnd = unsafe_load(cglobal(:jl_libjulia_handle, Ptr{Cvoid})) - for (k, v) in Compiler.JuliaGlobalNameMap - if "ejl_"*k == name - val = unsafe_load(Base.reinterpret(Ptr{Ptr{Cvoid}}, Libdl.dlsym(hnd, k))) - found = true - break - end - end - - if !found - for (k, v) in Compiler.JuliaEnzymeNameMap - if "ejl_"*k == name - val = Compiler.unsafe_to_ptr(v) - found = true - break - end - end - end - - if found - val - else - LLVM.API.LLVMSearchForAddressOfSymbol(name) - end - ## Step 4: Lookup in libatomic - # TODO: Do we need to do this? - catch ex - @error "Enzyme: Lookup failed" name exception=(ex, Base.catch_backtrace()) - C_NULL - end - if ptr === C_NULL - @show name - error("Enzyme: Symbol lookup failed. Aborting!") - end - - return UInt64(reinterpret(UInt, ptr)) -end - -function add!(mod) - return compile!(jit[], mod, @cfunction(resolver, UInt64, (Cstring, Ptr{Cvoid}))) -end - -function lookup(jitted_mod, name) - return LLVM.addressin(jit[], jitted_mod, name) -end - -end diff --git a/src/compiler/orcv2.jl b/src/compiler/orcv2.jl index d36b1ca1c10..40d13eea805 100644 --- a/src/compiler/orcv2.jl +++ b/src/compiler/orcv2.jl @@ -1,3 +1,4 @@ + module JIT using LLVM @@ -9,7 +10,7 @@ import ..Compiler import ..Compiler: API, cpu_name, cpu_features @inline function use_ojit() - return (VERSION >= v"1.10.0-DEV.1395") && !Sys.iswindows() + return !Sys.iswindows() end export get_trampoline @@ -132,11 +133,7 @@ function __init__() jit[] = CompilerInstance(lljit, nothing, nothing) end - hnd = @static if VERSION >= v"1.10" - unsafe_load(cglobal(:jl_libjulia_handle, Ptr{Cvoid})) - else - Libdl.dlopen("libjulia") - end + hnd = unsafe_load(cglobal(:jl_libjulia_handle, Ptr{Cvoid})) for (k, v) in Compiler.JuliaGlobalNameMap ptr = unsafe_load(Base.reinterpret(Ptr{Ptr{Cvoid}}, Libdl.dlsym(hnd, k))) LLVM.define(jd_main, absolute_symbol_materialization(mangle(lljit, 
"ejl_"*k), ptr)) diff --git a/src/compiler/reflection.jl b/src/compiler/reflection.jl index 5b7100f8876..944b0b24989 100644 --- a/src/compiler/reflection.jl +++ b/src/compiler/reflection.jl @@ -35,7 +35,6 @@ function reflect(@nospecialize(func), @nospecialize(A), @nospecialize(types); return llvmf, mod end -# For VERSION >= v"1.9.0-DEV.516" struct jl_llvmf_dump TSM::LLVM.API.LLVMOrcThreadSafeModuleRef F::LLVM.API.LLVMValueRef @@ -46,30 +45,12 @@ function enzyme_code_llvm(io::IO, @nospecialize(func), @nospecialize(A), @nospec raw::Bool=false, debuginfo::Symbol=:default, dump_module::Bool=false, mode=API.DEM_ReverseModeCombined) JuliaContext() do ctx entry_fn, ir = reflect(func, A, types; optimize, run_enzyme, second_stage, mode) - @static if VERSION >= v"1.9.0-DEV.516" - ts_mod = ThreadSafeModule(ir) - if VERSION >= v"1.9.0-DEV.672" - GC.@preserve ts_mod entry_fn begin - value = Ref(jl_llvmf_dump(ts_mod.ref, entry_fn.ref)) - str = ccall(:jl_dump_function_ir, Ref{String}, - (Ptr{jl_llvmf_dump}, Bool, Bool, Ptr{UInt8}), - value, !raw, dump_module, debuginfo) - end - else - GC.@preserve ts_mod entry_fn begin - # N.B. 
jl_dump_function_ir will `Libc.free` the passed-in pointer - value_ptr = reinterpret(Ptr{jl_llvmf_dump}, - Libc.malloc(sizeof(jl_llvmf_dump))) - unsafe_store!(value_ptr, jl_llvmf_dump(ts_mod.ref, entry_fn.ref)) - str = ccall(:jl_dump_function_ir, Ref{String}, - (Ptr{jl_llvmf_dump}, Bool, Bool, Ptr{UInt8}), - value_ptr, !raw, dump_module, debuginfo) - end - end - else + ts_mod = ThreadSafeModule(ir) + GC.@preserve ts_mod entry_fn begin + value = Ref(jl_llvmf_dump(ts_mod.ref, entry_fn.ref)) str = ccall(:jl_dump_function_ir, Ref{String}, - (LLVM.API.LLVMValueRef, Bool, Bool, Ptr{UInt8}), - entry_fn, !raw, dump_module, debuginfo) + (Ptr{jl_llvmf_dump}, Bool, Bool, Ptr{UInt8}), + value, !raw, dump_module, debuginfo) end print(io, str) end diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index e4825e52265..6615b6bd405 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -261,72 +261,6 @@ T_ppjlvalue() = LLVM.PointerType(LLVM.PointerType(LLVM.StructType(LLVMType[]))) return v end -if VERSION < v"1.7.0-DEV.1205" - -declare_ptls!(mod) = get_function!(mod, "julia.ptls_states", LLVM.FunctionType(LLVM.PointerType(T_ppjlvalue()))) - -function emit_ptls!(B) - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - func, fty = declare_ptls!(mod) - return call!(B, fty, func) -end - -function get_ptls(func) - entry_bb = first(blocks(func)) - ptls_func = declare_ptls!(LLVM.parent(func)) - - for I in instructions(entry_bb) - if I isa LLVM.CallInst && called_operand(I) == ptls_func - return I - end - end - return nothing -end - -function reinsert_gcmarker!(func, PB=nothing) - ptls = get_ptls(func) - if isnothing(ptls) - B = IRBuilder() - entry_bb = first(blocks(func)) - if !isempty(instructions(entry_bb)) - position!(B, first(instructions(entry_bb))) - else - position!(B, entry_bb) - end - emit_ptls!(B) - else - entry_bb = first(blocks(func)) - fst = first(instructions(entry_bb)) - if fst != ptls - API.moveBefore(ptls, fst, PB === nothing 
? C_NULL : PB.ref) - end - ptls - end -end - -function unique_gcmarker!(func) - entry_bb = first(blocks(func)) - ptls_func = declare_ptls!(LLVM.parent(func)) - - found = LLVM.CallInst[] - for I in instructions(entry_bb) - if I isa LLVM.CallInst && called_operand(I) == ptls_func - push!(found, I) - end - end - if length(found) > 1 - for i in 2:length(found) - LLVM.replace_uses!(found[i], found[1]) - Base.unsafe_delete!(entry_bb, found[i]) - end - end - return nothing -end - -else - function declare_pgcstack!(mod) get_function!(mod, "julia.get_pgcstack", LLVM.FunctionType(LLVM.PointerType(T_ppjlvalue()))) end @@ -398,7 +332,6 @@ function unique_gcmarker!(func) end return nothing end -end @inline AnonymousStruct(::Type{U}) where U<:Tuple = NamedTuple{ntuple(i->Symbol(i), Val(length(U.parameters))), U} diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index b95d343bfb3..51aeacf675a 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -9,56 +9,19 @@ module FFI using LinearAlgebra using ObjectFile using Libdl - @static if VERSION >= v"1.7" - function __init__() - @static if VERSION > v"1.8" - global blas_handle = Libdl.dlopen(BLAS.libblastrampoline) - else - global blas_handle = Libdl.dlopen(BLAS.libblas) - end - end - function get_blas_symbols() - symbols = BLAS.get_config().exported_symbols - if BLAS.USE_BLAS64 - return map(n->n*"64_", symbols) - end - return symbols - end - - function lookup_blas_symbol(name) - Libdl.dlsym(blas_handle::Ptr{Cvoid}, name; throw_error=false) - end - else - function __init__() - global blas_handle = Libdl.dlopen(BLAS.libblas) - end - function get_blas_symbols() - symbols = Set{String}() - path = Libdl.dlpath(BLAS.libblas) - ignoreSymbols = Set(String["", "edata", "_edata", "end", "_end", "_bss_start", "__bss_start", ".text", ".data"]) - for meta in readmeta(open(path, "r")) - for s in Symbols(meta) - name = symbol_name(s) - if !Sys.iswindows() && BLAS.vendor() == :openblas64 - endswith(name, 
"64_") || continue - else - endswith(name, "_") || continue - end - if !in(name, ignoreSymbols) - push!(symbols, name) - end - end - end - symbols = collect(symbols) - if Sys.iswindows() && BLAS.vendor() == :openblas64 - return map(n->n*"64_", symbols) - end - return symbols + function __init__() + global blas_handle = Libdl.dlopen(BLAS.libblastrampoline) + end + function get_blas_symbols() + symbols = BLAS.get_config().exported_symbols + if BLAS.USE_BLAS64 + return map(n->n*"64_", symbols) end + return symbols + end - function lookup_blas_symbol(name) - Libdl.dlsym(blas_handle::Ptr{Cvoid}, name; throw_error=false) - end + function lookup_blas_symbol(name) + Libdl.dlsym(blas_handle::Ptr{Cvoid}, name; throw_error=false) end end @@ -361,11 +324,9 @@ end return has_method(sig, mt.world, nothing) end -@static if VERSION >= v"1.7" @inline function has_method(sig, world::UInt, mt::Core.Compiler.OverlayMethodTable) return has_method(sig, mt.mt, mt.world) || has_method(sig, nothing, mt.world) end -end @inline function is_inactive(tys, world::UInt, mt) specTypes = Interpreter.simplify_kw(Tuple{tys...}) @@ -739,11 +700,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) frames = ccall(:jl_lookup_code_address, Any, (Ptr{Cvoid}, Cint,), ptr, 0) if length(frames) >= 1 - @static if VERSION >= v"1.4.0-DEV.123" - fn, file, line, linfo, fromC, inlined = last(frames) - else - fn, file, line, linfo, fromC, inlined, ip = last(frames) - end + fn, file, line, linfo, fromC, inlined = last(frames) # Remember pointer in our global map fn = FFI.memoize!(ptr, string(fn)) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index bcb3b1c4132..8b772e8e24d 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -107,9 +107,7 @@ function EnzymeRules.inactive(::typeof(Base.startswith), ::AbstractString, args. return nothing end -if VERSION >= v"1.9" - Enzyme.EnzymeRules.inactive_noinl(::typeof(Core._compute_sparams), args...) 
= nothing -end +Enzyme.EnzymeRules.inactive_noinl(::typeof(Core._compute_sparams), args...) = nothing @inline EnzymeRules.inactive_type(v::Type{Nothing}) = true @inline EnzymeRules.inactive_type(v::Type{Union{}}) = true @@ -379,15 +377,6 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT} nothing end -@static if VERSION < v"1.8.0" - UT = Union{ - LinearAlgebra.Diagonal{eltype(AT), onedimensionalize(BT)}, - LinearAlgebra.LowerTriangular{eltype(AT), AT}, - LinearAlgebra.UpperTriangular{eltype(AT), AT}, - LinearAlgebra.LU{eltype(AT), AT}, - LinearAlgebra.QRCompactWY{eltype(AT), AT} - } -else UT = Union{ LinearAlgebra.Diagonal{eltype(AT), onedimensionalize(BT)}, LinearAlgebra.LowerTriangular{eltype(AT), AT}, @@ -395,7 +384,6 @@ else LinearAlgebra.LU{eltype(AT), AT, Vector{Int}}, LinearAlgebra.QRPivoted{eltype(AT), AT, onedimensionalize(BT), Vector{Int}} } -end cache = NamedTuple{(Symbol("1"),Symbol("2"), Symbol("3"), Symbol("4")), Tuple{ eltype(RT), @@ -532,7 +520,6 @@ _zero_unused_elements!(X, ::LowerTriangular) = tril!(X) _zero_unused_elements!(X, ::UnitUpperTriangular) = triu!(X, 1) _zero_unused_elements!(X, ::UnitLowerTriangular) = tril!(X, -1) -@static if VERSION >= v"1.7-" # Force a rule around hvcat_fill as it is type unstable if the tuple is not of the same type (e.g., int, float, int, float) function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} primal = if EnzymeRules.needs_primal(config) @@ -581,7 +568,6 @@ function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Ty end return (nothing, nothing) end -end function EnzymeRules.forward( ::Const{typeof(sort!)}, diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index bdcdd79b258..49622ada901 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -1302,23 +1302,9 @@ function generic_setup(orig, func, ReturnType, gutils, 
start, B::LLVM.IRBuilder, end pushfirst!(vals, etup) - @static if VERSION < v"1.7.0-" || true - else - mi = emit_methodinstance!(B, func, vals) - end - pushfirst!(vals, unsafe_to_llvm(B, func)) - @static if VERSION < v"1.7.0-" || true - else - pushfirst!(vals, mi) - end - - @static if VERSION < v"1.7.0-" || true cal = emit_apply_generic!(B, vals) - else - cal = emit_invoke!(B, vals) - end debug_from_orig!(gutils, cal, orig) diff --git a/src/rules/parallelrules.jl b/src/rules/parallelrules.jl index 1db4cd8d0bc..54208fe21cc 100644 --- a/src/rules/parallelrules.jl +++ b/src/rules/parallelrules.jl @@ -167,13 +167,8 @@ end # TODO actually do modifiedBetween -@static if VERSION < v"1.8-" - e_tt = Tuple{} - modifiedBetween = (mode != API.DEM_ForwardMode, ) -else e_tt = Tuple{Const{Int}} modifiedBetween = (mode != API.DEM_ForwardMode, false) -end world = enzyme_extract_world(LLVM.parent(position(B))) @@ -374,10 +369,7 @@ end push!(vals, tape) end - @static if VERSION < v"1.8-" - else - push!(vals, new_from_original(gutils, operands(orig)[end-1])) - end + push!(vals, new_from_original(gutils, operands(orig)[end-1])) return refed, LLVM.name(subfunc), dfuncT, vals, thunkTy, TapeType, copies end @@ -392,11 +384,7 @@ end _, sname, dfuncT, vals, thunkTy, _, _ = threadsfor_common(orig, gutils, B, API.DEM_ForwardMode) -@static if VERSION < v"1.8-" - tt = Tuple{thunkTy, dfuncT} -else tt = Tuple{thunkTy, dfuncT, Bool} -end mode = get_mode(gutils) world = enzyme_extract_world(LLVM.parent(position(B))) entry = nested_codegen!(mode, mod, runtime_pfor_fwd, tt, world) @@ -431,17 +419,7 @@ end byRef, sname, dfuncT, vals, thunkTy, _, copies = threadsfor_common(orig, gutils, B, API.DEM_ReverseModePrimal) -@static if VERSION < v"1.8-" - if byRef - emit_error(B, orig, "Enzyme: active variable in Threads.@threads closure "*(string(eltype(eltype(dfuncT))))*" not supported") - end -end - -@static if VERSION < v"1.8-" - tt = Tuple{thunkTy, dfuncT, Val{any_jltypes(EnzymeRules.tape_type(thunkTy))}, 
Val{byRef}} -else tt = Tuple{thunkTy, dfuncT, Val{any_jltypes(EnzymeRules.tape_type(thunkTy))}, Val{byRef}, Bool} -end mode = get_mode(gutils) world = enzyme_extract_world(LLVM.parent(position(B))) entry = nested_codegen!(mode, mod, runtime_pfor_augfwd, tt, world) @@ -489,11 +467,7 @@ end Vector{TapeType} end -@static if VERSION < v"1.8-" - tt = Tuple{thunkTy, dfuncT, Val{any_jltypes(EnzymeRules.tape_type(thunkTy))}, Val{byRef}, STT } -else tt = Tuple{thunkTy, dfuncT, Val{any_jltypes(EnzymeRules.tape_type(thunkTy))}, Val{byRef}, STT, Bool} -end mode = get_mode(gutils) entry = nested_codegen!(mode, mod, runtime_pfor_rev, tt, world) push!(function_attributes(entry), EnumAttribute("alwaysinline")) diff --git a/src/typetree.jl b/src/typetree.jl index 73b296b95bc..40b01edcce3 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -199,11 +199,7 @@ else end end -if VERSION >= v"1.7.0-DEV.204" - import Base: ismutabletype -else - ismutabletype(T) = isa(T, DataType) && T.mutable -end +import Base: ismutabletype function typetree_inner(@nospecialize(T::Type), ctx, dl, seen::TypeTreeTable) if T isa UnionAll || T isa Union || T == Union{} || Base.isabstracttype(T) @@ -214,10 +210,8 @@ function typetree_inner(@nospecialize(T::Type), ctx, dl, seen::TypeTreeTable) return TypeTree() end - @static if VERSION >= v"1.7.0" - if is_concrete_tuple(T) && any(T2 isa Core.TypeofVararg for T2 in T.parameters) - return TypeTree() - end + if is_concrete_tuple(T) && any(T2 isa Core.TypeofVararg for T2 in T.parameters) + return TypeTree() end if T <: AbstractFloat diff --git a/src/utils.jl b/src/utils.jl index cc2af40c746..ac312e82951 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -124,11 +124,7 @@ function hasfieldcount(@nospecialize(dt)) return true end -if VERSION <= v"1.6" - allocatedinline(@nospecialize(T)) = T.isinlinealloc -else - import Base: allocatedinline -end +import Base: allocatedinline #Excerpt from https://github.com/JuliaGPU/GPUCompiler.jl/blob/v0.19.4/src/jlgen.jl # !!! 
warning "codegen_world_age below is fundamentally unsound." @@ -154,8 +150,6 @@ using Base: _methods_by_ftype # directly, instead use `cached_compilation` which handles invalidation for you. -if VERSION >= v"1.10.0-DEV.873" - # on 1.10 (JuliaLang/julia#48611) the generated function knows which world it was invoked in function _generated_ex(world, source, ex) @@ -178,16 +172,9 @@ function codegen_world_age_generator(world::UInt, source, self, ft::Type, tt::Ty min_world = Ref{UInt}(typemin(UInt)) max_world = Ref{UInt}(typemax(UInt)) has_ambig = Ptr{Int32}(C_NULL) # don't care about ambiguous results - mthds = if VERSION >= v"1.7.0-DEV.1297" - Base._methods_by_ftype(sig, #=mt=# nothing, #=lim=# -1, - world, #=ambig=# false, - min_world, max_world, has_ambig) - # XXX: use the correct method table to support overlaying kernels - else - Base._methods_by_ftype(sig, #=lim=# -1, + mthds = Base._methods_by_ftype(sig, #=mt=# nothing, #=lim=# -1, world, #=ambig=# false, min_world, max_world, has_ambig) - end mthds === nothing && return _generated_ex(world, source, method_error) length(mthds) == 1 || return _generated_ex(world, source, method_error) @@ -234,95 +221,6 @@ end $(Expr(:meta, :generated, codegen_world_age_generator)) end -else - -# on older versions of Julia we fall back to looking up the current world. this may be wrong -# when the generator is invoked in a different world (TODO: when does this happen?) 
- -function codegen_world_age_generator(self, ft::Type, tt::Type) - @nospecialize - @assert Core.Compiler.isType(ft) && Core.Compiler.isType(tt) - ft = ft.parameters[1] - tt = tt.parameters[1] - - # validation - ft <: Core.Builtin && error("$(GPUCompiler.unsafe_function_from_type(ft)) is not a generic function") - - # look up the method - method_error = :(throw(MethodError(ft, tt))) - sig = Tuple{ft, tt.parameters...} - min_world = Ref{UInt}(typemin(UInt)) - max_world = Ref{UInt}(typemax(UInt)) - has_ambig = Ptr{Int32}(C_NULL) # don't care about ambiguous results - mthds = if VERSION >= v"1.7.0-DEV.1297" - Base._methods_by_ftype(sig, #=mt=# nothing, #=lim=# -1, - #=world=# typemax(UInt), #=ambig=# false, - min_world, max_world, has_ambig) - # XXX: use the correct method table to support overlaying kernels - else - Base._methods_by_ftype(sig, #=lim=# -1, - #=world=# typemax(UInt), #=ambig=# false, - min_world, max_world, has_ambig) - end - # XXX: using world=-1 is wrong, but the current world isn't exposed to this generator - mthds === nothing && return method_error - length(mthds) == 1 || return method_error - - # look up the method and code instance - mtypes, msp, m = mthds[1] - mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any), m, mtypes, msp) - ci = retrieve_code_info(mi)::CodeInfo - - # prepare a new code info - new_ci = copy(ci) - empty!(new_ci.code) - empty!(new_ci.codelocs) - resize!(new_ci.linetable, 1) # see note below - empty!(new_ci.ssaflags) - new_ci.ssavaluetypes = 0 - new_ci.min_world = min_world[] - new_ci.max_world = max_world[] - new_ci.edges = MethodInstance[mi] - # XXX: setting this edge does not give us proper method invalidation, see - # JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel. - # invoking `code_llvm` also does the necessary codegen, as does calling the - # underlying C methods -- which GPUCompiler does, so everything Just Works. 
- - # prepare the slots - new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt] - new_ci.slotflags = UInt8[0x00 for i = 1:3] - - # return the current world age (which is not technically the codegen world age, - # but works well enough for invalidation purposes) - push!(new_ci.code, ReturnNode(Base.get_world_counter())) - push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code` - push!(new_ci.codelocs, 1) # see note below - new_ci.ssavaluetypes += 1 - - # NOTE: we keep the first entry of the original linetable, and use it for location info - # on the call to check_cache. we can't not have a codeloc (using 0 causes - # corruption of the back trace), and reusing the target function's info - # has as advantage that we see the name of the kernel in the backtraces. - - return new_ci -end - -@eval function codegen_world_age(ft, tt) - $(Expr(:meta, :generated_only)) - $(Expr(:meta, - :generated, - Expr(:new, - Core.GeneratedFunctionStub, - :codegen_world_age_generator, - Any[:methodinstance, :ft, :tt], - Any[], - @__LINE__, - QuoteNode(Symbol(@__FILE__)), - true))) -end - -end - export codegen_world_age diff --git a/test/DiffTests.jl b/test/DiffTests.jl index 98851f15599..9dd5abdb4d3 100644 --- a/test/DiffTests.jl +++ b/test/DiffTests.jl @@ -29,13 +29,8 @@ num2num_3(x) = 10.31^(x + x) - x num2num_4(x) = 1.0 num2num_5(x) = 1. / (1. 
+ exp(-x)) -@static if sizeof(Int) == Int64 || VERSION ≥ v"1.7-" const NUMBER_TO_NUMBER_FUNCS = (num2num_1, num2num_2, num2num_3, num2num_4, num2num_5, identity) -else -const NUMBER_TO_NUMBER_FUNCS = (num2num_1, num2num_2, num2num_3, - num2num_4, identity) -end ####################### # f(x::Number)::Array # @@ -120,25 +115,12 @@ end self_weighted_logit(x) = inv(1.0 + exp(-dot(x, x))) -@static if VERSION ≥ v"1.10-" # vec2num_6 fails due to #708 # rosenbrock_4 fails on nightly for unknown reasons const VECTOR_TO_NUMBER_FUNCS = (vec2num_1, vec2num_2, vec2num_3, vec2num_4, vec2num_5, #=vec2num_6,=# vec2num_7, rosenbrock_1, rosenbrock_2, rosenbrock_3, #=rosenbrock_4,=# ackley, self_weighted_logit, first) -elseif sizeof(Int) == Int64 || VERSION ≥ v"1.7-" -# vec2num_6 fails due to #708 -const VECTOR_TO_NUMBER_FUNCS = (vec2num_1, vec2num_2, vec2num_3, vec2num_4, vec2num_5, - #=vec2num_6,=# vec2num_7, rosenbrock_1, rosenbrock_2, - rosenbrock_3, rosenbrock_4, ackley, self_weighted_logit, - first) -else -const VECTOR_TO_NUMBER_FUNCS = (#=vec2num_1,=# vec2num_2, vec2num_3, vec2num_4, vec2num_5, - #=vec2num_6,=# vec2num_7, rosenbrock_1, rosenbrock_2, - rosenbrock_3, rosenbrock_4, #=ackley,=# self_weighted_logit, - first) -end ######################## # f(x::Matrix)::Number # ######################## diff --git a/test/applyiter.jl b/test/applyiter.jl index 11e9ebf37c5..5b55617e553 100644 --- a/test/applyiter.jl +++ b/test/applyiter.jl @@ -305,7 +305,6 @@ end Enzyme.autodiff(Forward, metaconcat, Const(a)) -@static if VERSION ≥ v"1.7-" dres, = Enzyme.autodiff(Forward, midconcat, Duplicated(1.0, 7.0), Duplicated(a, da)) @test length(dres) == 5 @test dres[1] ≈ 7.0 @@ -351,7 +350,6 @@ end @test dres[3] == "b" @test dres[4] == "c" @test dres[5] == "d" -end y = [(-92.0, -93.0), (-97.9, -911.2)] dy = [(-913.7, -915.2), (-9100.02, -9304.1)] diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 2b1e9bc6213..b9a705941c9 100644 --- a/test/internal_rules.jl +++ 
b/test/internal_rules.jl @@ -62,9 +62,7 @@ end end @test autodiff(Forward, f4, Duplicated(1.5, 1.0))[1] == 1.5 - @static if VERSION < v"1.7-" || VERSION >= v"1.8-" - @test autodiff(Forward, f4, BatchDuplicated(1.5, (1.0, 2.0)))[1] == (var"1"=1.5, var"2"=3.0) - end + @test autodiff(Forward, f4, BatchDuplicated(1.5, (1.0, 2.0)))[1] == (var"1"=1.5, var"2"=3.0) @test autodiff(Reverse, f4, Active(1.5))[1][1] == 1.5 @test autodiff(Reverse, f4, Active(4.0))[1][1] == 0.5 @test autodiff(Reverse, f4, Active(6.0))[1][1] == 0.0 @@ -285,7 +283,6 @@ end end end -@static if VERSION > v"1.8" @testset "Cholesky" begin function symmetric_definite(n :: Int=10) α = one(Float64) @@ -619,7 +616,6 @@ end end end end -end @testset "rand and randn rules" begin # Distributed as x + unit normal + uniform diff --git a/test/mixed.jl b/test/mixed.jl index dae06230738..4de521414b4 100644 --- a/test/mixed.jl +++ b/test/mixed.jl @@ -26,7 +26,6 @@ end @test 6.2 ≈ Enzyme.autodiff(Reverse, outmixedmul2, Const, Duplicated(res, dres), Active(3.1))[1][2] end -@static if VERSION >= v"1.8-" @testset "Batched Byref Mixed Activity" begin res = Ref(4.7) dres = Ref(1.0) @@ -35,7 +34,6 @@ end @test 6.2 ≈ sig[1][2][1] @test 3*6.2 ≈ sig[1][2][2] end -end function tupmixedmul(x::Float64) vec = [x] @@ -59,7 +57,6 @@ end @test 6.2 ≈ Enzyme.autodiff(Reverse, outtupmixedmul, Const, Duplicated(res, dres), Active(3.1))[1][2] end -@static if VERSION >= v"1.8-" @testset "Batched Byref Tuple Mixed Activity" begin res = Ref(4.7) dres = Ref(1.0) @@ -68,4 +65,3 @@ end @test 6.2 ≈ sig[1][2][1] @test 3*6.2 ≈ sig[1][2][2] end -end diff --git a/test/rrules.jl b/test/rrules.jl index be4c4f14242..6c2a965b0e0 100644 --- a/test/rrules.jl +++ b/test/rrules.jl @@ -295,7 +295,6 @@ function plaquette_sum(U) end -@static if VERSION >= v"1.9" @testset "No caching byref julia" begin U = Complex{Float64}[3.0 + 4.0im] dU = Complex{Float64}[0.0] @@ -304,8 +303,6 @@ end @test dU[1] ≈ 7 * ( 3.0 + 4.0im ) end -end - struct Closure 
v::Vector{Float64} diff --git a/test/runtests.jl b/test/runtests.jl index 250fa20c975..dc826cd5b53 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,13 +1,3 @@ -# HACK: work around Pkg.jl#2500 -if VERSION < v"1.8-" -test_project = Base.active_project() -preferences_file = joinpath(dirname(@__DIR__), "LocalPreferences.toml") -test_preferences_file = joinpath(dirname(test_project), "LocalPreferences.toml") -if isfile(preferences_file) && !isfile(test_preferences_file) - cp(preferences_file, test_preferences_file) -end -end - # # work around https://github.com/JuliaLang/Pkg.jl/issues/1585 # using Pkg # Pkg.develop(PackageSpec(; path=joinpath(dirname(@__DIR__), "lib", "EnzymeTestUtils"))) @@ -80,7 +70,8 @@ function test_matrix_to_number(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1) @test isapproxfn((Enzyme.Forward, f), dx_fwd, dx_fd; rtol=rtol, atol=atol, kwargs...) end -Aqua.test_all(Enzyme, unbound_args=false, piracies=false, deps_compat=false) +# Aqua.test_all(Enzyme, unbound_args=false, piracies=false, deps_compat=false, stale_deps=(;:ignore=>[:EnzymeTestUtils])) +# Aqua.test_all(Enzyme, unbound_args=false, piracies=false, deps_compat=false, stale_deps=(;:ignore=>[:EnzymeTestUtils])) include("abi.jl") include("typetree.jl") @@ -91,12 +82,9 @@ include("typetree.jl") include("kwrules.jl") include("kwrrules.jl") include("internal_rules.jl") - @static if VERSION ≥ v"1.9-" - # XXX invalidation does not work on Julia 1.8 - include("ruleinvalidation.jl") - end + include("ruleinvalidation.jl") end -@static if VERSION ≥ v"1.7-" || !Sys.iswindows() +@static if !Sys.iswindows() include("blas.jl") end @@ -394,17 +382,11 @@ make3() = (1.0, 2.0, 3.0) test_scalar(cbrt, 1.0f0; rtol = 1.0e-5, atol = 1.0e-5) test_scalar(Base.sinh, 1.0) test_scalar(Base.cosh, 1.0) - if sizeof(Int) == Int64 || VERSION ≥ v"1.7-" test_scalar(Base.sinc, 2.2) - end test_scalar(Base.FastMath.sinh_fast, 1.0) test_scalar(Base.FastMath.cosh_fast, 1.0) - if sizeof(Int) == Int64 || VERSION ≥ 
v"1.7-" test_scalar(Base.FastMath.exp_fast, 1.0) - end - if sizeof(Int) == Int64 || VERSION ≥ v"1.7-" test_scalar(Base.exp10, 1.0) - end test_scalar(Base.exp2, 1.0) test_scalar(Base.expm1, 1.0) test_scalar(x->rem(x, 1), 0.7) @@ -454,11 +436,7 @@ end Const{typeof(dot)}, Active, Duplicated{typeof(thunk_A)} ) @test Tuple{Float64,Float64} === TapeType - Ret = if VERSION < v"1.8-" - Active{Float64} - else - Active - end + Ret = Active fwd, rev = Enzyme.autodiff_deferred_thunk( ReverseSplitWithPrimal, TapeType, @@ -474,31 +452,28 @@ end @test all(dA .== def_dA) @test all(dA .== thunk_dA) - @static if VERSION < v"1.8-" - else - function kernel(len, A) - for i in 1:len - A[i] *= A[i] - end + function kernel(len, A) + for i in 1:len + A[i] *= A[i] end + end - A = Array{Float64}(undef, 64) - dA = Array{Float64}(undef, 64) + A = Array{Float64}(undef, 64) + dA = Array{Float64}(undef, 64) - A .= (1:1:64) - dA .= 1 + A .= (1:1:64) + dA .= 1 - function aug_fwd(ctx, f::FT, ::Val{ModifiedBetween}, args...) where {ModifiedBetween, FT} - TapeType = Enzyme.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...) - forward, reverse = Enzyme.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType, Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...) - forward(Const(f), Const(ctx), args...)[1] - return nothing - end + function aug_fwd(ctx, f::FT, ::Val{ModifiedBetween}, args...) where {ModifiedBetween, FT} + TapeType = Enzyme.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...) + forward, reverse = Enzyme.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType, Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...) 
+ forward(Const(f), Const(ctx), args...)[1] + return nothing + end - ModifiedBetween = Val((false, false, true)) + ModifiedBetween = Val((false, false, true)) - aug_fwd(64, kernel, ModifiedBetween, Duplicated(A, dA)) - end + aug_fwd(64, kernel, ModifiedBetween, Duplicated(A, dA)) end @@ -880,34 +855,31 @@ end @test autodiff(Forward, arsum, Duplicated(inp, dinp))[1] ≈ 2.0 - # On Julia 1.6 the gradients are wrong (1.0 too large) and on 1.7 it errors - @static if VERSION ≥ v"1.8-" - function f1(m) - s = 0.0 - for (i, col) in enumerate(eachcol(m)) - s += i * sum(col) - end - return s + function f1(m) + s = 0.0 + for (i, col) in enumerate(eachcol(m)) + s += i * sum(col) end + return s + end - m = Float64[1 2 3; 4 5 6; 7 8 9] - dm = zero(m) - autodiff(Reverse, f1, Active, Duplicated(m, dm)) - @test dm == Float64[1 2 3; 1 2 3; 1 2 3] + m = Float64[1 2 3; 4 5 6; 7 8 9] + dm = zero(m) + autodiff(Reverse, f1, Active, Duplicated(m, dm)) + @test dm == Float64[1 2 3; 1 2 3; 1 2 3] - function f2(m) - s = 0.0 - for (i, col) in enumerate(eachrow(m)) - s += i * sum(col) - end - return s + function f2(m) + s = 0.0 + for (i, col) in enumerate(eachrow(m)) + s += i * sum(col) end - - dm = zero(m) - autodiff(Reverse, f2, Active, Duplicated(m, dm)) - @test dm == Float64[1 1 1; 2 2 2; 3 3 3] + return s end + dm = zero(m) + autodiff(Reverse, f2, Active, Duplicated(m, dm)) + @test dm == Float64[1 1 1; 2 2 2; 3 3 3] + function my_conv_3(x, w) y = zeros(Float64, 2, 3, 4, 5) for hi in axes(y, 3) @@ -2300,7 +2272,6 @@ function bc2_loss_function(x, scale, bias) return sum(abs2, bc2_affine_normalize(identity, x_, xmean, xvar, scale_, bias_, 1e-5)) end -@static if VERSION ≥ v"1.8-" @testset "Broadcast noalias" begin x = ones(30) @@ -2315,7 +2286,6 @@ end Enzyme.autodiff(Reverse, bc2_loss_function, Active, Duplicated(x, Enzyme.make_zero(x)), Duplicated(sc, Enzyme.make_zero(sc)), Duplicated(bi, Enzyme.make_zero(bi))) end -end function solve_cubic_eq(poly::AbstractVector{Complex{T}}) where T a1 = 1 / 
@inbounds poly[1] @@ -2870,14 +2840,12 @@ end @test y[1] == [0.0, 1.0, 0.0] @test y[2] == [0.0, 0.0, 1.0] -@static if VERSION ≥ v"1.9-" x = @SArray [5.0 0.0 6.0] dx = Enzyme.gradient(Forward, prod, x) @test dx[1] ≈ 0 @test dx[2] ≈ 30 @test dx[3] ≈ 0 end -end function sparse_eval(x::Vector{Float64}) @@ -2887,7 +2855,6 @@ function sparse_eval(x::Vector{Float64}) return A[1] end -@static if VERSION ≥ v"1.7-" @testset "Type Unstable SparseArrays" begin x = [3.1, 2.7, 8.2] dx = [0.0, 0.0, 0.0] @@ -2897,7 +2864,6 @@ end @test x ≈ [3.1, 2.7, 8.2] @test dx ≈ [-1.0, 43.74, 0] end -end @testset "Simple Jacobian" begin @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0) ≈ 2.0 @@ -3357,11 +3323,7 @@ end @test res[1] ≈ 0.2 # broken as the return of an apply generic is {primal, primal} # but since the return is abstractfloat doing the - @static if VERSION ≥ v"1.9-" && !(VERSION ≥ v"1.10-" ) - @test_broken res[2] ≈ 1.0 - else - @test res[2] ≈ 1.0 - end + @test res[2] ≈ 1.0 end @inline function uns_mymean(f, A, ::Type{T}, c) where T @@ -3412,7 +3374,6 @@ end @test dx ≈ Float64[1.0] end -@static if VERSION < v"1.8-" || VERSION >= v"1.9-" @inline extract_bc(bc, ::Val{:north}) = (bc.north) @inline extract_bc(bc, ::Val{:top}) = (bc.top) @@ -3437,7 +3398,6 @@ end Enzyme.API.looseTypeAnalysis!(false) end -end @testset "Static activity" begin @@ -3539,11 +3499,9 @@ end @test res.x == 5.0 - if VERSION > v"1.10-" - res = autodiff(Reverse, g, Active, Active(Moo(3.0, "a")))[1][1] + res = autodiff(Reverse, g, Active, Active(Moo(3.0, "a")))[1][1] - @test res.x == 5.0 - end + @test res.x == 5.0 end @testset "Type preservation" begin @@ -3800,15 +3758,11 @@ end @test autodiff(Reverse, f8, Active, Active(1.5))[1][1] == 0 @test autodiff(Forward, f8, Duplicated(1.5, 1.0))[1] == 0 - # On Julia 1.6 the gradients are wrong (0.7 not 1.2) and on 1.7 it errors - @static if VERSION ≥ v"1.8-" - f9(x) = sum(quantile([1.0, x], [0.5, 0.7])) - @test autodiff(Reverse, f9, Active, Active(2.0))[1][1] == 1.2 - 
@test autodiff(Forward, f9, Duplicated(2.0, 1.0))[1] == 1.2 - end + f9(x) = sum(quantile([1.0, x], [0.5, 0.7])) + @test autodiff(Reverse, f9, Active, Active(2.0))[1][1] == 1.2 + @test autodiff(Forward, f9, Duplicated(2.0, 1.0))[1] == 1.2 end -@static if VERSION >= v"1.7-" @testset "hvcat_fill" begin ar = Matrix{Float64}(undef, 2, 3) dar = [1.0 2.0 3.0; 4.0 5.0 6.0] @@ -3824,26 +3778,19 @@ end end # TEST EXTENSIONS -@static if VERSION ≥ v"1.9-" - using SpecialFunctions - @testset "SpecialFunctions ext" begin - lgabsg(x) = SpecialFunctions.logabsgamma(x)[1] - test_scalar(lgabsg, 1.0; rtol = 1.0e-5, atol = 1.0e-5) - test_scalar(lgabsg, 1.0f0; rtol = 1.0e-5, atol = 1.0e-5) - end - - using ChainRulesCore - @testset "ChainRulesCore ext" begin - include("ext/chainrulescore.jl") - end - include("ext/logexpfunctions.jl") - - @testset "BFloat16s ext" begin - include("ext/bfloat16s.jl") - end +using SpecialFunctions +@testset "SpecialFunctions ext" begin + lgabsg(x) = SpecialFunctions.logabsgamma(x)[1] + test_scalar(lgabsg, 1.0; rtol = 1.0e-5, atol = 1.0e-5) + test_scalar(lgabsg, 1.0f0; rtol = 1.0e-5, atol = 1.0e-5) end - - +using ChainRulesCore +@testset "ChainRulesCore ext" begin + include("ext/chainrulescore.jl") end +include("ext/logexpfunctions.jl") +@testset "BFloat16s ext" begin + include("ext/bfloat16s.jl") +end From dffb431ddae5bae755fd646485b83b96c02fd07b Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 12 Sep 2024 09:15:26 -0500 Subject: [PATCH 16/87] Update Project.toml --- lib/EnzymeTestUtils/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/EnzymeTestUtils/Project.toml b/lib/EnzymeTestUtils/Project.toml index 80dd2ede758..05e5e6b94e3 100644 --- a/lib/EnzymeTestUtils/Project.toml +++ b/lib/EnzymeTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeTestUtils" uuid = "12d8515a-0907-448a-8884-5fe00fdf1c5a" authors = ["Seth Axen ", "William Moses ", "Valentin Churavy "] -version = "0.1.7" +version = "0.1.8" [deps] 
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" From b140e0e1f3cc70d8135b3f78f371285a675fd188 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 12 Sep 2024 17:04:56 -0500 Subject: [PATCH 17/87] Fix MixedDuplicated ABI error on primalerror (#1815) --- src/compiler.jl | 10 +++++----- test/mixed.jl | 17 +++++++++++++++++ 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index f82ae6c1358..ba538364094 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4928,7 +4928,7 @@ function get_return_info(jlrettype)::Tuple{Union{Nothing, Type}, Union{Nothing, end # Modified from GPUCompiler/src/irgen.jl:365 lower_byval -function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function, actualRetType::Type, RetActivity, TT) +function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function, actualRetType::Type, RetActivity, TT, run_enzyme) entry_ft = LLVM.function_type(entry_f) RT = LLVM.return_type(entry_ft) @@ -4985,7 +4985,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function push!(wrapper_types, typ) push!(wrapper_attrs, LLVM.Attribute[]) elseif arg.cc != GPUCompiler.BITS_REF - if TT != nothing && (TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated) + if TT != nothing && (TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated) && run_enzyme push!(boxedArgs, arg.arg_i) push!(raisedArgs, arg.arg_i) push!(wrapper_types, LLVM.PointerType(typ, Derived)) @@ -4996,7 +4996,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function end else # bits ref, and not boxed - if TT != nothing && (TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated) + if TT != nothing && (TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated) && run_enzyme push!(boxedArgs, arg.arg_i) 
push!(wrapper_types, typ) push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("noalias")]) @@ -5931,7 +5931,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; sret = get_return_info(k.ci.rettype)[2] !== nothing if sret cur = llvmfn == primalf - llvmfn, _, boxedArgs, loweredArgs = lower_convention(mi.specTypes, mod, llvmfn, k.ci.rettype, Duplicated, nothing) + llvmfn, _, boxedArgs, loweredArgs = lower_convention(mi.specTypes, mod, llvmfn, k.ci.rettype, Duplicated, nothing, params.run_enzyme) if cur primalf = llvmfn lowerConvention = false @@ -6002,7 +6002,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; primalf, returnRoots = primalf, false if lowerConvention - primalf, returnRoots, boxedArgs, loweredArgs = lower_convention(source_sig, mod, primalf, actualRetType, job.config.params.rt, TT) + primalf, returnRoots, boxedArgs, loweredArgs = lower_convention(source_sig, mod, primalf, actualRetType, job.config.params.rt, TT, params.run_enzyme) end if primal_job.config.target isa GPUCompiler.NativeCompilerTarget diff --git a/test/mixed.jl b/test/mixed.jl index 4de521414b4..dc4c510b234 100644 --- a/test/mixed.jl +++ b/test/mixed.jl @@ -65,3 +65,20 @@ end @test 6.2 ≈ sig[1][2][1] @test 3*6.2 ≈ sig[1][2][2] end + +struct Foobar + x::Int + y::Int + z::Int + q::Int + r::Float64 +end + +function bad_abi(fb) + v = fb.x + throw(AssertionError("saw bad val $v")) +end + +@testset "Mixed PrimalError" begin + @test_throws AssertionError autodiff(Reverse, bad_abi, MixedDuplicated(Foobar(2, 3, 4, 5, 6.0), Ref(Foobar(2, 3, 4, 5, 6.0)))) +end \ No newline at end of file From e63c1b75f1e5d1158722e02c6d048dfe9fbe30ae Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> Date: Sat, 14 Sep 2024 02:43:38 +0900 Subject: [PATCH 18/87] adjustments to the latest inlining interface changes (#1350) * adjustments to the latest inlining interface changes * Update src/compiler/interpreter.jl * 
rebase * Update interpreter.jl * fix `inlining_policy` overload --------- Co-authored-by: William S. Moses --- .gitignore | 1 + src/compiler/interpreter.jl | 113 +++++++++++++++++++++++------------- 2 files changed, 75 insertions(+), 39 deletions(-) diff --git a/.gitignore b/.gitignore index 594a8584c48..e7ee8ed2f51 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *.jl.cov *.jl.mem /Manifest.toml +/Manifest-v*.toml /test/Manifest.toml /docs/Manifest.toml /docs/build/ diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 08b42d587b3..46ca95ab326 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -68,20 +68,17 @@ else end # No need to do any locking since we're not putting our results into the runtime cache -Core.Compiler.lock_mi_inference(interp::EnzymeInterpreter, mi::MethodInstance) = nothing -Core.Compiler.unlock_mi_inference(interp::EnzymeInterpreter, mi::MethodInstance) = nothing +Core.Compiler.lock_mi_inference(::EnzymeInterpreter, ::MethodInstance) = nothing +Core.Compiler.unlock_mi_inference(::EnzymeInterpreter, ::MethodInstance) = nothing -function Core.Compiler.add_remark!(interp::EnzymeInterpreter, sv::InferenceState, msg) -end - -Core.Compiler.may_optimize(interp::EnzymeInterpreter) = true -Core.Compiler.may_compress(interp::EnzymeInterpreter) = true +Core.Compiler.may_optimize(::EnzymeInterpreter) = true +Core.Compiler.may_compress(::EnzymeInterpreter) = true # From @aviatesk: # `may_discard_trees = true`` means a complicated (in terms of inlineability) source will be discarded, # but as far as I understand Enzyme wants "always inlining, except special cased functions", # so I guess we really don't want to discard sources? 
-Core.Compiler.may_discard_trees(interp::EnzymeInterpreter) = false -Core.Compiler.verbose_stmt_info(interp::EnzymeInterpreter) = false +Core.Compiler.may_discard_trees(::EnzymeInterpreter) = false +Core.Compiler.verbose_stmt_info(::EnzymeInterpreter) = false if isdefined(Base.Experimental, Symbol("@overlay")) Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) = @@ -123,7 +120,7 @@ function isKWCallSignature(@nospecialize(TT)) return TT <: Tuple{typeof(Core.kwcall), Any, Any, Vararg} end -function simplify_kw(specTypes) +function simplify_kw(@nospecialize specTypes) if isKWCallSignature(specTypes) return Base.tuple_type_tail(Base.tuple_type_tail(specTypes)) else @@ -131,44 +128,82 @@ function simplify_kw(specTypes) end end -# https://github.com/JuliaLang/julia/pull/46965 import Core.Compiler: CallInfo -function Core.Compiler.inlining_policy(interp::EnzymeInterpreter, - @nospecialize(src), @nospecialize(info::CallInfo), stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) - +struct NoInlineCallInfo <: CallInfo + info::CallInfo # wrapped call + tt # ::Type + kind::Symbol + NoInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt), kind::Symbol) = new(info, tt, kind) +end +Core.Compiler.nsplit_impl(info::NoInlineCallInfo) = Core.Compiler.nsplit(info.info) +Core.Compiler.getsplit_impl(info::NoInlineCallInfo, idx::Int) = Core.Compiler.getsplit(info.info, idx) +Core.Compiler.getresult_impl(info::NoInlineCallInfo, idx::Int) = Core.Compiler.getresult(info.info, idx) +struct AlwaysInlineCallInfo <: CallInfo + info::CallInfo # wrapped call + tt # ::Type + AlwaysInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt)) = new(info, tt) +end +Core.Compiler.nsplit_impl(info::AlwaysInlineCallInfo) = Core.Compiler.nsplit(info.info) +Core.Compiler.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int) = Core.Compiler.getsplit(info.info, idx) +Core.Compiler.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) = 
Core.Compiler.getresult(info.info, idx) + +using Core.Compiler: ArgInfo, StmtInfo, AbsIntState +function Core.Compiler.abstract_call_gf_by_type(interp::EnzymeInterpreter, @nospecialize(f), + arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype), sv::AbsIntState, max_methods::Int) + ret = @invoke Core.Compiler.abstract_call_gf_by_type(interp::AbstractInterpreter, f::Any, + arginfo::ArgInfo, si::StmtInfo, atype::Any, sv::AbsIntState, max_methods::Int) + callinfo = ret.info method_table = Core.Compiler.method_table(interp) - specTypes = simplify_kw(mi.specTypes) - + specTypes = simplify_kw(atype) if is_primitive_func(specTypes) - @safe_debug "Blocking inlining for primitive func" mi.specTypes - return nothing - end - - if is_alwaysinline_func(specTypes) - @safe_debug "Forcing inlining for primitive func" mi.specTypes - @assert src !== nothing - return src + callinfo = NoInlineCallInfo(callinfo, atype, :primitive) + elseif is_alwaysinline_func(specTypes) + callinfo = AlwaysInlineCallInfo(callinfo, atype) + elseif EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table) + callinfo = NoInlineCallInfo(callinfo, atype, :inactive) + elseif interp.mode == API.DEM_ForwardMode + if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table) + callinfo = NoInlineCallInfo(callinfo, atype, :frule) + end + elseif EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table) + callinfo = NoInlineCallInfo(callinfo, atype, :rrule) end - - if EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table) - @safe_debug "Blocking inlining due to inactive rule" mi.specTypes - return nothing + @static if VERSION ≥ v"1.11-" + return Core.Compiler.CallMeta(ret.rt, ret.exct, ret.effects, callinfo) + else + return Core.Compiler.CallMeta(ret.rt, ret.effects, callinfo) end +end - if interp.mode == API.DEM_ForwardMode - if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table) - @safe_debug "Blocking 
inlining due to frule" mi.specTypes - return nothing - end +let # overload `inlining_policy` + @static if VERSION ≥ v"1.11.0-DEV.879" + sigs_ex = :(interp::EnzymeInterpreter, @nospecialize(src), @nospecialize(info::Core.Compiler.CallInfo), stmt_flag::UInt32) + args_ex = :(interp::AbstractInterpreter, src::Any, info::Core.Compiler.CallInfo, stmt_flag::UInt32) else - if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table) - @safe_debug "Blocking inling due to rrule" mi.specTypes + sigs_ex = :(interp::EnzymeInterpreter, + @nospecialize(src), @nospecialize(info::Core.Compiler.CallInfo), stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) + args_ex = :(interp::AbstractInterpreter, + src::Any, info::Core.Compiler.CallInfo, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) + end + @eval function Core.Compiler.inlining_policy($(sigs_ex.args...)) + if info isa NoInlineCallInfo + if info.kind === :primitive + @safe_debug "Blocking inlining for primitive func" info.tt + elseif info.kind === :inactive + @safe_debug "Blocking inlining due to inactive rule" info.tt + elseif info.kind === :frule + @safe_debug "Blocking inlining due to frule" info.tt + else + @assert info.kind === :rrule + @safe_debug "Blocking inlining due to rrule" info.tt + end return nothing + elseif info isa AlwaysInlineCallInfo + @safe_debug "Forcing inlining for primitive func" info.tt + return src end + return @invoke Core.Compiler.inlining_policy($(args_ex.args...)) end - - return Base.@invoke Core.Compiler.inlining_policy(interp::AbstractInterpreter, - src::Any, info::CallInfo, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) end -end +end # module Interpreter From 3528723712fd6af6f25822ad9f3f4a214be0b4f4 Mon Sep 17 00:00:00 2001 From: ExpandingMan Date: Fri, 13 Sep 2024 18:10:49 -0400 Subject: [PATCH 19/87] more comprehensive unit tests for gradient and jacobian (#1773) * more comprehensive unit tests for gradient and jacobian * More extensive 
sugar tests * more comprehensive unit tests for gradient and jacobian * More extensive sugar tests * hopefully working now? * Update Project.toml * try fixing tests on 1.6 * Update Project.toml --------- Co-authored-by: Billy Moses Co-authored-by: William Moses --- src/Enzyme.jl | 10 +- test/runtests.jl | 275 ++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 267 insertions(+), 18 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 450d96ffb01..a5b949c60a1 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1341,7 +1341,7 @@ For functions who return other types, this function will retun an array or tuple of shape `size(output)` of values of the input type. ``` """ -@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RABI, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{chunk}) where {F, X, chunk, n_out_val, RABI<:ABI, ErrIfFuncWritten} +@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RABI, #=Holomorphic=#false, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{chunk}) where {F, X, chunk, n_out_val, RABI<:ABI, ErrIfFuncWritten} num = ((n_out_val + chunk - 1) ÷ chunk) if chunk == 0 @@ -1417,7 +1417,7 @@ of shape `size(output)` of values of the input type. end end -@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RABI, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{1} = Val(1)) where {F, X, n_out_val,RABI<:ABI, ErrIfFuncWritten} +@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RABI, #=Holomorphic=#false, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{1} = Val(1)) where {F, X, n_out_val,RABI<:ABI, ErrIfFuncWritten} XT = Core.Typeof(x) MD = Compiler.active_reg_inner(XT, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState tt′ = MD ? 
Tuple{MixedDuplicated{XT}} : Tuple{Duplicated{XT}} @@ -1466,12 +1466,12 @@ end end end -@inline function jacobian(::ReverseMode{ReturnPrimal,RABI, ErrIfFuncWritten}, f::F, x::X) where {ReturnPrimal, F, X, RABI<:ABI, ErrIfFuncWritten} +@inline function jacobian(::ReverseMode{ReturnPrimal,RABI,Holomorphic,ErrIfFuncWritten}, f::F, x::X) where {ReturnPrimal, F, X, RABI<:ABI, Holomorphic, ErrIfFuncWritten} res = f(x) jac = if res isa AbstractArray - jacobian(ReverseMode{false,RABI, ErrIfFuncWritten}(), f, x, Val(length(jac))) + jacobian(ReverseMode{false,RABI, Holomorphic, ErrIfFuncWritten}(), f, x, Val(length(res))) elseif res isa AbstractFloat - gradient(ReverseMode{false,RABI, ErrIfFuncWritten}(), f, x) + gradient(ReverseMode{false,RABI, Holomorphic, ErrIfFuncWritten}(), f, x) else throw(AssertionError("Unsupported return type of function for reverse-mode jacobian, $(Core.Typeof(res))")) end diff --git a/test/runtests.jl b/test/runtests.jl index dc826cd5b53..6ffd3dd09c5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,6 +16,15 @@ using InlineStrings using Enzyme_jll @info "Testing against" Enzyme_jll.libEnzyme +# symbol is \simeq +# this is basically a more flexible version of ≈ +(≃)(a, b) = (≈)(a, b) +(≃)(a::Tuple, b::Tuple) = all(xy -> xy[1] ≃ xy[2], zip(a,b)) +function (≃)(a::AbstractArray{<:Tuple}, b::AbstractArray{<:Tuple}) + size(a) == size(b) || return false + all(xy -> xy[1] ≃ xy[2], zip(a,b)) +end + function isapproxfn(fn, args...; kwargs...) isapprox(args...; kwargs...) 
end @@ -2865,6 +2874,259 @@ end @test dx ≈ [-1.0, 43.74, 0] end + +# these are used in gradient and jacobian tests +struct InpStruct + i1::Float64 + i2::Float64 + i3::Float64 +end +struct OutStruct + i1::Float64 + i2::Float64 + i3::Float64 +end + +for A ∈ (:InpStruct, :OutStruct) + @eval (≃)(a::$A, b::$A) = (a.i1 ≃ b.i1) && (a.i2 ≃ b.i2) && (a.i3 ≃ b.i3) + @eval function (≃)(a::AbstractArray{<:$A}, b::AbstractArray{<:$A}) + size(a) == size(b) || return false + all(xy -> xy[1] ≃ xy[2], zip(a, b)) + end +end + + +#NOTE: this is needed because of problems with hvcat on 1.10 and something inexplicable on 1.6 +# suffice it to say it's not good that this is required, please remove when possible +mkarray(sz, args...) = reshape(vcat(args...), sz) + +@testset "Gradient and Jacobian Outputs" begin + + scalar = 3.0 + + # ∂ scalar / ∂ scalar + @test Enzyme.gradient(Enzyme.Forward, x -> x^2, scalar) ≈ 6.0 + @test Enzyme.gradient(Enzyme.Reverse, x -> x^2, scalar) ≈ 6.0 + @test Enzyme.jacobian(Enzyme.Forward, x -> x^2, scalar) ≈ 6.0 + @test Enzyme.jacobian(Enzyme.Reverse, x -> x^2, scalar) ≈ 6.0 + @test Enzyme.gradient(Enzyme.Forward, x -> 2*x, scalar) ≈ 2.0 + @test Enzyme.gradient(Enzyme.Reverse, x -> 2*x, scalar) ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Forward, x -> 2*x, scalar) ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Reverse, x -> 2*x, scalar) ≈ 2.0 + + # ∂ vector / ∂ scalar + @test Enzyme.gradient(Enzyme.Forward, x -> [2*x, x^2], scalar) ≈ [2.0, 6.0] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [2*x, x^2], scalar) ≈ [2.0, 6.0] + + @test Enzyme.jacobian(Enzyme.Forward, x -> [2*x, x^2], scalar) ≈ [2.0, 6.0] + @test Enzyme.jacobian(Enzyme.Reverse, x -> [2*x, x^2], scalar) ≈ [2.0, 6.0] + + + # ∂ tuple / ∂ scalar + @test Enzyme.gradient(Enzyme.Forward, x -> (2*x, x^2), scalar) ≃ (2.0, 6.0) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (2*x, x^2), scalar) ≈ [2.0, 6.0] + + @test Enzyme.jacobian(Enzyme.Forward, x -> (2*x, x^2), scalar) ≃ (2.0, 6.0) + @test_broken 
Enzyme.jacobian(Enzyme.Reverse, x -> (2*x, x^2), scalar) ≃ (2.0, 6.0) + + mkarray1 = x -> mkarray((2,2),2*x,sin(x),x^2,exp(x)) + + # ∂ matrix / ∂ scalar + @test Enzyme.gradient(Enzyme.Forward, mkarray1, scalar) ≈ [2.0 6.0; cos(scalar) exp(scalar)] + @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray1, scalar) ≈ [2.0 6.0; cos(scalar) exp(scalar)] + + @test Enzyme.jacobian(Enzyme.Forward, mkarray1, scalar) ≈ [2.0 6.0; cos(scalar) exp(scalar)] + @test Enzyme.jacobian(Enzyme.Reverse, mkarray1, scalar) ≈ [2.0 6.0; cos(scalar) exp(scalar)] + + # ∂ struct / ∂ scalar + @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x, x^2, x^3), scalar) == OutStruct(1.0,2*scalar,3*scalar^2) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> InpStruct(x, x^2, x^3), scalar) == (OutStruct(1.0,2.0,3.0),) + @test Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x, x^2, x^3), scalar) == OutStruct(1.0,2*scalar,3*scalar^2) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> InpStruct(x, x^2, x^3), scalar) == (OutStruct(1.0,2.0,3.0),) + + + + vector = [2.7, 3.1] + + # ∂ scalar / ∂ vector + @test Enzyme.gradient(Enzyme.Forward, x -> x[1] * x[2], vector) ≃ (vector[2],vector[1]) + @test Enzyme.gradient(Enzyme.Reverse, x -> x[1] * x[2], vector) ≈ [vector[2], vector[1]] + @test Enzyme.jacobian(Enzyme.Forward, x -> x[1] * x[2], vector) ≈ [vector[2], vector[1]] + @test Enzyme.jacobian(Enzyme.Reverse, x -> x[1] * x[2], vector) ≈ [vector[2], vector[1]] + + + # ∂ vector / ∂ vector + @test Enzyme.gradient(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector) ≃ + ([vector[2], -sin(vector[1])], [vector[1], 1.0]) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector) ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) + @test Enzyme.jacobian(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector) ≈ + [vector[2] vector[1]; -sin(vector[1]) 1.0] + @test Enzyme.jacobian(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector) ≈ + [vector[2] vector[1]; 
-sin(vector[1]) 1.0] + + # ∂ tuple / ∂ vector + @test Enzyme.gradient(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector) ≃ + ((vector[2], -sin(vector[1])), (vector[1], 1.0)) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector) ≈ + ([vector[2], -sin(vector[1])], [vector[1], 1.0]) + @test Enzyme.jacobian(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector) ≃ + [(vector[2], -sin(vector[1])), (vector[1], 1.0)] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector) + + mkarray2 = x -> mkarray((2,2), x[1]*x[2], exp(x[2]), cos(x[1])+x[2], x[1]) + + # ∂ matrix / ∂ vector + @test Enzyme.gradient(Enzyme.Forward, mkarray2, vector) ≃ + ([vector[2] -sin(vector[1]); 0.0 1.0], [vector[1] 1.0; exp(vector[2]) 0.0]) + @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray2, vector) + @test Enzyme.jacobian(Enzyme.Forward, mkarray2, vector) ≈ + mkarray((2,2,2), vector[2], 0.0, -sin(vector[1]), 1.0, vector[1], exp(vector[2]), 1.0, 0.0) + @test Enzyme.jacobian(Enzyme.Reverse, mkarray2, vector) ≈ + mkarray((2,2,2), vector[2], 0.0, -sin(vector[1]), 1.0, vector[1], exp(vector[2]), 1.0, 0.0) + + # ∂ struct / ∂ vector + @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), vector) ≃ + (OutStruct(vector[2], -sin(vector[1]), 0.0), OutStruct(vector[1], 1.0, exp(vector[2]))) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector) ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) + + @test Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), vector) ≃ + [OutStruct(vector[2], -sin(vector[1]), 0.0), OutStruct(vector[1], 1.0, exp(vector[2]))] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector) ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) + + + + tuplev = (2.7, 3.1) + + # ∂ scalar / ∂ tuple + @test Enzyme.gradient(Enzyme.Forward, x -> x[1] * 
x[2], tuplev) ≃ (tuplev[2],tuplev[1]) + @test Enzyme.gradient(Enzyme.Reverse, x -> x[1] * x[2], tuplev) ≃ (tuplev[2],tuplev[1]) + @test Enzyme.jacobian(Enzyme.Forward, x -> x[1] * x[2], tuplev) ≃ (tuplev[2],tuplev[1]) + @test Enzyme.jacobian(Enzyme.Reverse, x -> x[1] * x[2], tuplev) ≃ (tuplev[2],tuplev[1]) + + # ∂ vector / ∂ tuple + @test Enzyme.gradient(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev) ≃ + ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev) ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev) ≈ + [tuplev[2] tuplev[1]; -sin(tuplev[1]) 1.0] + @test Enzyme.jacobian(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev) ≃ + [(tuplev[2], tuplev[1]), (-sin(tuplev[1]), 1.0)] + + # ∂ tuple / ∂ tuple + @test Enzyme.gradient(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev) ≃ + ((vector[2], -sin(vector[1])), (vector[1], 1.0)) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev) ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + @test Enzyme.jacobian(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev) ≃ + ((tuplev[2], -sin(tuplev[1])), (tuplev[1], 1.0)) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev) ≈ + [tuplev[2] tuplev[1]; -sin(tuplev[1]) 1.0] + + # ∂ matrix / ∂ tuple + @test Enzyme.gradient(Enzyme.Forward, mkarray2, tuplev) ≃ + ([tuplev[2] -sin(tuplev[1]); 0.0 1.0], [tuplev[1] 1.0; exp(tuplev[2]) 0.0]) + @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray2, tuplev) + @test_broken Enzyme.jacobian(Enzyme.Forward, mkarray2, tuplev) ≈ + [tuplev[2] -sin(tuplev[1]); 0.0 1.0;;; tuplev[1] 1.0; exp(tuplev[2]) 0.0] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> mkarray2, tuplev) ≈ + [tuplev[2] -sin(tuplev[1]); 0.0 1.0;;; tuplev[1] 1.0; exp(tuplev[2]) 0.0] 
+ + # ∂ struct / ∂ tuple + @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), tuplev) ≃ + (OutStruct(tuplev[2], -sin(tuplev[1]), 0.0), OutStruct(tuplev[1], 1.0, exp(tuplev[2]))) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev) ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), tuplev) ≃ + [OutStruct(tuplev[2], -sin(tuplev[1]), 0.0), OutStruct(tuplev[1], 1.0, exp(tuplev[2]))] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev) ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + + + + matrix = [2.7 3.1; 4.7 5.6] + + # ∂ scalar / ∂ matrix + @test Enzyme.gradient(Enzyme.Forward, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix) ≃ + (matrix[1,2], matrix[2,2], matrix[1,1], matrix[2,1]) + @test Enzyme.gradient(Enzyme.Reverse, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix) ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] + @test Enzyme.jacobian(Enzyme.Forward, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix) ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] + @test Enzyme.jacobian(Enzyme.Reverse, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix) ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] + + # ∂ vector / ∂ matrix + @test Enzyme.gradient(Enzyme.Forward, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix) ≃ + ([matrix[1,2], 0.0], [0.0, matrix[2,2]], [matrix[1,1], 0.0], [0.0, matrix[2,1]]) + @test_broken Enzyme.gradient(Enzyme.Reverse, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix) + # again we can't use array construction syntax because of 1.6 + @test Enzyme.jacobian(Enzyme.Forward, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix) ≈ + mkarray((2,2,2), matrix[1,2], 0.0, 0.0, matrix[2,2], matrix[1,1], 0.0, 0.0, matrix[2,1]) + @test Enzyme.jacobian(Enzyme.Reverse, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix) ≈ + mkarray((2,2,2), matrix[1,2], 0.0, 0.0, matrix[2,2], matrix[1,1], 0.0, 0.0, 
matrix[2,1]) + + # ∂ tuple / ∂ matrix + @test Enzyme.gradient(Enzyme.Forward, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix) ≃ ((matrix[1,2], 0.0), (0.0, matrix[2,2]), (matrix[1,1], 0.0), (0.0, matrix[2,1])) + @test_broken Enzyme.gradient(Enzyme.Reverse, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix) + @test Enzyme.jacobian(Enzyme.Forward, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix) ≃ + [(matrix[1,2],0.0) (matrix[1,1],0.0); (0.0,matrix[2,2]) (0.0,matrix[2,1])] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix) + + mkarray3 = x -> mkarray((2,2), x[1,1]*x[1,2], exp(x[1,1])+x[2,2], x[2,1]*x[2,2], sin(x[1,2])+x[2,1]) + + # ∂ matrix / ∂ matrix + @test Enzyme.gradient(Enzyme.Forward, mkarray3, matrix) ≃ + ([matrix[1,2] 0.0; exp(matrix[1,1]) 0.0], [0.0 matrix[2,2]; 0.0 1.0], [matrix[1,1] 0.0; 0.0 cos(matrix[1,2])], [0.0 matrix[2,1]; 1.0 0.0]) + @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray3, matrix) + # array construction syntax broken on 1.6 + @test Enzyme.jacobian(Enzyme.Forward, mkarray3, matrix) ≈ + mkarray((2,2,2,2), matrix[1,2],exp(matrix[1,1]),0.0,0.0,0.0,0.0,matrix[2,2],1.0, + matrix[1,1],0.0,0.0,cos(matrix[1,2]),0.0,1.0,matrix[2,1],0.0) + @test Enzyme.jacobian(Enzyme.Reverse, mkarray3, matrix) ≈ + mkarray((2,2,2,2), matrix[1,2],exp(matrix[1,1]),0.0,0.0,0.0,0.0,matrix[2,2],1.0, + matrix[1,1],0.0,0.0,cos(matrix[1,2]),0.0,1.0,matrix[2,1],0.0) + + # ∂ tuple / ∂ matrix + @test Enzyme.gradient(Enzyme.Forward, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix) ≃ + (OutStruct(matrix[1,2], 0.0, exp(matrix[1,1])), OutStruct(0.0, matrix[2,2], 0.0), OutStruct(matrix[1,1], 0.0, 0.0), OutStruct(0.0, matrix[2,1], 1.0)) + @test_broken Enzyme.gradient(Enzyme.Reverse, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix) + @test Enzyme.jacobian(Enzyme.Forward, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix) ≃ + [OutStruct(matrix[1,2],0.0, exp(matrix[1,1])) OutStruct(matrix[1,1],0.0,0.0); 
OutStruct(0.0,matrix[2,2],0.0) OutStruct(0.0,matrix[2,1], 1.0)] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix) + + + istruct = InpStruct(2.7, 3.1, 4.7) + + # ∂ scalar / ∂ struct + @test_broken Enzyme.gradient(Enzyme.Forward, x -> x.i1 * x.i2 + x.i3, istruct) + @test Enzyme.gradient(Enzyme.Reverse, x -> x.i1 * x.i2 + x.i3, istruct) ≃ InpStruct(istruct.i2, istruct.i1, 1.0) + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> x.i1 * x.i2 + x.i3, istruct) + @test Enzyme.jacobian(Enzyme.Reverse, x -> x.i1 * x.i2 + x.i3, istruct) ≃ InpStruct(istruct.i2, istruct.i1, 1.0) + + # ∂ vector / ∂ struct + @test_broken Enzyme.gradient(Enzyme.Forward, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct) + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct) + @test Enzyme.jacobian(Enzyme.Reverse, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct) ≃ [InpStruct(istruct.i2, istruct.i1, 0.0), InpStruct(1.0, 0.0, -sin(istruct.i3))] + + # ∂ tuple / ∂ struct + @test_broken Enzyme.gradient(Enzyme.Forward, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct) + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct) + + mkarray4 = x -> mkarray((2,2), x.i1*x.i2, exp(x.i2), cos(x.i3)+x.i1, x.i1) + + # ∂ matrix / ∂ struct + @test_broken Enzyme.gradient(Enzyme.Forward, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct) + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct) + @test Enzyme.jacobian(Enzyme.Reverse, 
mkarray4, istruct) ≃ + [InpStruct(istruct.i2, istruct.i1, 0.0) InpStruct(1.0, 0.0, -sin(istruct.i3)); + InpStruct(0.0, exp(istruct.i2), 0.0) InpStruct(1.0, 0.0, 0.0)] + + # ∂ struct / ∂ struct + @test_broken Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct) + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct) +end + @testset "Simple Jacobian" begin @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0) ≈ 2.0 @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0) ≈ [1.0, 2.0] @@ -2922,12 +3184,6 @@ end @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] - struct InpStruct - i1::Float64 - i2::Float64 - i3::Float64 - end - fillinpabs2(x) = [(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3), 10*(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3), 100*(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3), 1000*(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3)] x2 = InpStruct(1.0, 2.0, 3.0) @@ -2946,12 +3202,6 @@ end @test jac[3] == InpStruct(200.0, 400.0, 600.0) @test jac[4] == InpStruct(2000.0, 4000.0, 6000.0) - struct OutStruct - i1::Float64 - i2::Float64 - i3::Float64 - end - filloutabs2(x) = OutStruct(sum(abs2, x), 10*sum(abs2, x), 100*sum(abs2, x)) jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x) @@ -2986,7 +3236,6 @@ end @test jac[1, 3] == OutStruct(10.0, 100.0, 1000.0) @test jac[2, 3] == OutStruct(12.0, 120.0, 1200.0) - end From 24c58efe5bac678aad9ba66fe97a18f9044b3e3d Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 14 Sep 2024 23:24:42 -0500 Subject: [PATCH 20/87] Ensure typeof doesn't get cached (#1826) --- src/compiler.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 
ba538364094..0d9e1471560 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3541,6 +3541,12 @@ function annotate!(mod, mode) push!(function_attributes(fn), LLVM.StringAttribute("enzyme_shouldrecompute")) end end + for fname in ("julia.typeof",) + if haskey(fns, fname) + fn = fns[fname] + push!(function_attributes(fn), LLVM.StringAttribute("enzyme_nocache")) + end + end for fname in ("jl_excstack_state","ijl_excstack_state", "ijl_field_index", "jl_field_index") if haskey(fns, fname) From afedaac9dac1fc039aa585307398247cbbb54c68 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 14 Sep 2024 23:24:54 -0500 Subject: [PATCH 21/87] Improve deferred error message (#1827) * Improve deferred error message * fix --- src/rules/llvmrules.jl | 46 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index fb93016063b..962a4f46afd 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -1329,6 +1329,46 @@ end end +@register_fwd function deferred_fwd(B, orig, gutils, normalR, shadowR) + if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + return true + end + err = emit_error(B, orig, "There is a known issue in GPUCompiler.jl which is preventing higher-order AD of this code.\nPlease see https://github.com/JuliaGPU/GPUCompiler.jl/issues/629 for more information and to alert the GPUCompiler authors of your use case and need.") + newo = new_from_original(gutils, orig) + API.moveBefore(newo, err, B) + normal = (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing + if shadowR != C_NULL && normal !== nothing + unsafe_store!(shadowR, normal.ref) + end + return false +end + +@register_aug function deferred_augfwd(B, orig, gutils, normalR, shadowR, tapeR) + if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + return true + end + err = emit_error(B, orig, "There is a known issue in GPUCompiler.jl which is preventing higher-order AD of this code.\nPlease see https://github.com/JuliaGPU/GPUCompiler.jl/issues/629 for more information and to alert the GPUCompiler authors of your use case and need.") + newo = new_from_original(gutils, orig) + API.moveBefore(newo, err, B) + normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + if shadowR != C_NULL && normal !== nothing + unsafe_store!(shadowR, normal.ref) + end + # Delete the primal code + if normal !== nothing + unsafe_store!(normalR, C_NULL) + else + ni = new_from_original(gutils, orig) + API.EnzymeGradientUtilsErase(gutils, ni) + end + return false +end + +@register_rev function deferred_rev(B, orig, gutils, tape) + return nothing +end + + function register_handler!(variants, augfwd_handler, rev_handler, fwd_handler=nothing) for variant in variants if augfwd_handler !== nothing && rev_handler !== nothing @@ -1522,6 +1562,12 @@ end @revfunc(finalizer_rev), @fwdfunc(finalizer_fwd), ) + register_handler!( + ("deferred_codegen",), + @augfunc(deferred_augfwd), + @revfunc(deferred_rev), + @fwdfunc(deferred_fwd), + ) register_handler!( ("jl_array_grow_end","ijl_array_grow_end"), @augfunc(jl_array_grow_end_augfwd), From 9ff45682cfed861054950460a8f3c4f5712e9300 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 15 Sep 2024 01:37:50 -0500 Subject: [PATCH 22/87] Fix diffuse rooting (#1829) * add comment * fix --- src/rules/customrules.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 
96286239876..005557c65ab 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -1141,6 +1141,22 @@ end if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && !has_aug_fwd_rule(orig, gutils) return (false, true) end + non_rooting_use = false + fop = called_operand(orig)::LLVM.Function + for (i, v) in enumerate(operands(orig)[1:end-1]) + if v == val + if !any(a->kind(a) == kind(StringAttribute("enzymejl_returnRoots")), collect(parameter_attributes(fop, i))) + non_rooting_use = true + break + end + end + end + + # If the operand is just rooting, we don't need it and should override defaults + if !non_rooting_use + return (false, false) + end + # don't use default and always require the arg return (true, false) end From 0b6effaa00ef511c9e8b6b3474d4b70e07f69ed2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 15 Sep 2024 02:12:00 -0500 Subject: [PATCH 23/87] Fix nightly precompile (#1830) --- src/compiler/interpreter.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 46ca95ab326..61a433af4c2 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -40,6 +40,12 @@ end function EnzymeInterpreter(cache_or_token, mt::Union{Nothing,Core.MethodTable}, world::UInt, mode::API.CDerivativeMode) @assert world <= Base.get_world_counter() + parms = @static if VERSION < v"1.12" + InferenceParams(unoptimize_throw_blocks=false), + else + InferenceParams() + end + return EnzymeInterpreter( cache_or_token, mt, @@ -51,7 +57,7 @@ function EnzymeInterpreter(cache_or_token, mt::Union{Nothing,Core.MethodTable}, world, # parameters for inference and optimization - InferenceParams(unoptimize_throw_blocks=false), + parms, OptimizationParams(), mode ) From ffcb7dd977fada76efa88c10d80f3b47d7bdc9a2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 15 Sep 2024 09:15:54 -0500 Subject: [PATCH 24/87] Runtime activity in mode (#1816) * Runtime 
activity in mode * fixup * cr * rules * fix fwd * fr * fr * fr * fr * fr * Update test_forward.jl * Update test_forward.jl * fix * fix * inv * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * improve gradient --- Project.toml | 31 ++--- docs/src/faq.md | 18 ++- examples/custom_rule.jl | 7 +- ext/EnzymeChainRulesCoreExt.jl | 2 +- lib/EnzymeCore/Project.toml | 4 +- lib/EnzymeCore/src/EnzymeCore.jl | 62 ++++++---- lib/EnzymeCore/src/rules.jl | 83 +++++++++---- lib/EnzymeCore/test/runtests.jl | 4 +- lib/EnzymeTestUtils/Project.toml | 8 +- lib/EnzymeTestUtils/src/test_forward.jl | 3 +- lib/EnzymeTestUtils/src/test_reverse.jl | 4 +- lib/EnzymeTestUtils/test/test_forward.jl | 4 +- src/Enzyme.jl | 141 +++++++++++++---------- src/api.jl | 64 ++-------- src/compiler.jl | 38 +++--- src/compiler/reflection.jl | 4 +- src/gradientutils.jl | 1 + src/internal_rules.jl | 99 ++++++++-------- src/rules/customrules.jl | 11 +- src/rules/jitrules.jl | 55 ++++----- src/rules/llvmrules.jl | 6 +- src/rules/parallelrules.jl | 15 +-- src/rules/typeunstablerules.jl | 8 +- test/Project.toml | 2 +- test/ext/.chainrulescore.jl.swp | Bin 0 -> 12288 bytes test/kwrrules.jl | 20 ++-- test/kwrules.jl | 8 +- test/mixedrrule.jl | 8 +- test/rrules.jl | 36 +++--- test/ruleinvalidation.jl | 10 +- test/rules.jl | 24 ++-- test/runtests.jl | 43 +++---- test/sc.jl | 64 ++++++++++ 33 files changed, 495 insertions(+), 392 deletions(-) create mode 100644 test/ext/.chainrulescore.jl.swp create mode 100644 test/sc.jl diff --git a/Project.toml b/Project.toml index 15890547e1b..1ea7b5c05b1 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.13.0" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" Enzyme_jll = "7cc45869-7501-5eee-bdea-0790c847d4ef" GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" @@ -16,11 
+17,25 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[weakdeps] +BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[extensions] +EnzymeBFloat16sExt = "BFloat16s" +EnzymeChainRulesCoreExt = "ChainRulesCore" +EnzymeLogExpFunctionsExt = "LogExpFunctions" +EnzymeSpecialFunctionsExt = "SpecialFunctions" +EnzymeStaticArraysExt = "StaticArrays" + [compat] BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" -EnzymeCore = "0.7.8" +EnzymeCore = "0.8" Enzyme_jll = "0.0.146, 0.0.148" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8, =9.0" @@ -31,23 +46,9 @@ SpecialFunctions = "1, 2" StaticArrays = "1" julia = "1.10" -[extensions] -EnzymeBFloat16sExt = "BFloat16s" -EnzymeChainRulesCoreExt = "ChainRulesCore" -EnzymeLogExpFunctionsExt = "LogExpFunctions" -EnzymeSpecialFunctionsExt = "SpecialFunctions" -EnzymeStaticArraysExt = "StaticArrays" - [extras] BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - -[weakdeps] -BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" diff --git a/docs/src/faq.md b/docs/src/faq.md index 5e57a8ada8a..88c0cce3b98 100644 --- a/docs/src/faq.md +++ b/docs/src/faq.md @@ -268,7 +268,7 @@ Enzyme.autodiff(Reverse, f, Active(1.2), 
Const(Vector{Float64}(undef, 1)), Const ((0.0, nothing, nothing, nothing),) ``` -Passing in a dupliacted (e.g. differentiable) variable for `tmp` now leads to the correct answer. +Passing in a duplicated (e.g. differentiable) variable for `tmp` now leads to the correct answer. ```jldoctest storage Enzyme.autodiff(Reverse, f, Active(1.2), Duplicated(Vector{Float64}(undef, 1), zeros(1)), Const(1), Const(5)) # Correct (returns 10.367999999999999 == 1.2^4 * 5) @@ -278,9 +278,11 @@ Enzyme.autodiff(Reverse, f, Active(1.2), Duplicated(Vector{Float64}(undef, 1), z ((10.367999999999999, nothing, nothing, nothing),) ``` -However, even if we ignore the semantic guarantee provided by marking `tmp` as constant, another issue arises. When computing the original function, intermediate computations (like in `f` above) can use `tmp` for temporary storage. When computing the derivative, Enzyme also needs additional temporary storage space for the corresponding derivative variables as well. If `tmp` is marked as Const, Enzyme does not have any temporary storage space for the derivatives! +## Runtime Activity -Recent versions of Enzyme will attempt to error when they detect these latter types of situations, which we will refer to as `activity unstable`. This term is chosen to mirror the Julia notion of type-unstable code (e.g. where a type is not known at compile time). If an expression is activity unstable, it could either be constant, or active, depending on data not known at compile time. For example, consider the following: +When computing the derivative of mutable variables, Enzyme also needs additional temporary storage space for the corresponding derivative variables. If an argument `tmp` is marked as Const, Enzyme does not have any temporary storage space for the derivatives! + +Enzyme will error when it detects these latter types of situations, which we will refer to as `activity unstable`. This term is chosen to mirror the Julia notion of type-unstable code (e.g.
where a type is not known at compile time). If an expression is activity unstable, it could either be constant, or active, depending on data not known at compile time. For example, consider the following: ```julia function g(cond, active_var, constant_var) @@ -293,7 +295,7 @@ end Enzyme.autodiff(Forward, g, Const(condition), Duplicated(x, dx), Const(y)) ``` -The returned value here could either by constant or duplicated, depending on the runtime-defined value of `cond`. If `cond` is true, Enzyme simply returns the shadow of `active_var` as the derivative. However, if `cond` is false, there is no derivative shadow for `constant_var` and Enzyme will throw a "Mismatched activity" error. For some simple types, e.g. a float Enzyme can circumvent this issue, for example by returning the float 0. Similarly, for some types like the Symbol type, which are never differentiable, such a shadow value will never be used, and Enzyme can return the original "primal" value as its derivative. However, for arbitrary data structures, Enzyme presently has no generic mechanism to resolve this. +The returned value here could either be constant or duplicated, depending on the runtime-defined value of `cond`. If `cond` is true, Enzyme simply returns the shadow of `active_var` as the derivative. However, if `cond` is false, there is no derivative shadow for `constant_var` and Enzyme will throw an `EnzymeRuntimeActivityError` error. For some simple types, e.g. a float, Enzyme can circumvent this issue, for example by returning the float 0. Similarly, for some types like the Symbol type, which are never differentiable, such a shadow value will never be used, and Enzyme can return the original "primal" value as its derivative. However, for arbitrary data structures, Enzyme presently has no generic mechanism to resolve this.
For example consider a third function: ```julia @@ -308,13 +310,17 @@ Enzyme provides a nice utility `Enzyme.make_zero` which takes a data structure a If one created a new zero'd copy of each return from `g`, this would mean that the derivative `dresult` would have one copy made for the first element, and a second copy made for the second element. This could lead to incorrect results, and is unfortunately not a general resolution. However, for non-mutable variables (e.g. like floats) or non-differrentiable types (e.g. like Symbols) this problem can never arise. -Instead, Enzyme has a special mode known as "Runtime Activity" which can handle these types of situations. It can come with a minor performance reduction, and is therefore off by default. It can be enabled with `Enzyme.API.runtimeActivity!(true)` right after importing Enzyme for the first time. +Instead, Enzyme has a special mode known as "Runtime Activity" which can handle these types of situations. It can come with a minor performance reduction, and is therefore off by default. It can be enabled by setting runtime activity to true in a desired differentiation mode. The way Enzyme's runtime activity resolves this issue is to return the original primal variable as the derivative whenever it needs to denote the fact that a variable is a constant. As this issue can only arise with mutable variables, they must be represented in memory via a pointer. All addtional loads and stores will now be modified to first check if the primal pointer is the same as the shadow pointer, and if so, treat it as a constant. Note that this check is not saying that the same arrays contain the same values, but rather the same backing memory represents both the primal and the shadow (e.g. `a === b` or equivalently `pointer(a) == pointer(b)`).
Enabling runtime activity does therefore, come with a sharp edge, which is that if the computed derivative of a function is mutable, one must also check to see if the primal and shadow represent the same pointer, and if so the true derivative of the function is actually zero. -Generally, the preferred solution to these type of activity unstable codes should be to make your variables all activity-stable (e.g. always containing differentiable memory or always containing non-differentiable memory). However, with care, Enzyme does support "Runtime Activity" as a way to differentiate these programs without having to modify your code. +Generally, the preferred solution to these type of activity unstable codes should be to make your variables all activity-stable (e.g. always containing differentiable memory or always containing non-differentiable memory). However, with care, Enzyme does support "Runtime Activity" as a way to differentiate these programs without having to modify your code. One can enable runtime activity for your code by changing the mode, such as + +```julia +Enzyme.autodiff(set_runtime_activity(Forward), h, Const(condition), Duplicated(x, dx), Const(y)) +``` ## Mixed activity diff --git a/examples/custom_rule.jl b/examples/custom_rule.jl index c2098006c20..86ffcf234a4 100644 --- a/examples/custom_rule.jl +++ b/examples/custom_rule.jl @@ -57,7 +57,7 @@ using .EnzymeRules # In this section, we write a simple forward rule to start out: -function forward(func::Const{typeof(f)}, ::Type{<:Duplicated}, y::Duplicated, x::Duplicated) +function forward(config::FwdConfig, func::Const{typeof(f)}, ::Type{<:Duplicated}, y::Duplicated, x::Duplicated) println("Using custom rule!") ret = func.val(y.val, x.val) y.dval .= 2 .* x.val .* x.dval @@ -65,6 +65,7 @@ function forward(func::Const{typeof(f)}, ::Type{<:Duplicated}, y::Duplicated, x: end # In the signature of our rule, we have made use of `Enzyme`'s activity annotations. 
Let's break down each one: +# - the [`FwdConfig`](@ref) configuration passes certain compile-time information about differentiation procedure (the width, and if we're using runtime activity), # - the [`Const`](@ref) annotation on `f` indicates that we accept a function `f` that does not have a derivative component, # which makes sense since `f` is not a closure with data that could be differentiated. # - the [`Duplicated`](@ref) annotation given in the second argument annotates the return value of `f`. This means that @@ -96,7 +97,7 @@ g(y, x) = f(y, x)^2 # function to differentiate # To squeeze out the last drop of performance, the below rule avoids computing the output of the original function and # just computes its derivative. -function forward(func::Const{typeof(f)}, ::Type{<:DuplicatedNoNeed}, y::Duplicated, x::Duplicated) +function forward(config, func::Const{typeof(f)}, ::Type{<:DuplicatedNoNeed}, y::Duplicated, x::Duplicated) println("Using custom rule with DuplicatedNoNeed output.") y.val .= x.val.^2 y.dval .= 2 .* x.val .* x.dval @@ -127,7 +128,7 @@ dy = [0.0, 0.0] Base.delete_method.(methods(forward, (Const{typeof(f)}, Vararg{Any}))) # delete our old rules -function forward(func::Const{typeof(f)}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, +function forward(config, func::Const{typeof(f)}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, y::Union{Const, Duplicated}, x::Union{Const, Duplicated}) println("Using our general custom rule!") y.val .= x.val.^2 diff --git a/ext/EnzymeChainRulesCoreExt.jl b/ext/EnzymeChainRulesCoreExt.jl index 81491f608ef..c6a41d77710 100644 --- a/ext/EnzymeChainRulesCoreExt.jl +++ b/ext/EnzymeChainRulesCoreExt.jl @@ -54,7 +54,7 @@ function Enzyme._import_frule(fn, tys...) end quote - function EnzymeRules.forward(fn::FA, ::Type{RetAnnotation}, $(exprs...); kwargs...) 
where {RetAnnotation, FA<:Annotation{<:$(esc(fn))}, $(anns...)} + function EnzymeRules.forward(config, fn::FA, ::Type{RetAnnotation}, $(exprs...); kwargs...) where {RetAnnotation, FA<:Annotation{<:$(esc(fn))}, $(anns...)} batchsize = same_or_one(1, $(vals...)) if batchsize == 1 dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index c0c4e0b1e61..2d39f92f452 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,11 +1,11 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.7.8" +version = "0.8.0" [compat] Adapt = "3, 4" -julia = "1.6" +julia = "1.10" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 0e67e4e3c0a..0175cb4caf5 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -216,59 +216,73 @@ const DefaultABI = FFIABI Abstract type for what differentiation mode will be used. """ -abstract type Mode{ABI, ErrIfFuncWritten} end +abstract type Mode{ABI, ErrIfFuncWritten, RuntimeActivity} end """ - struct ReverseMode{ReturnPrimal,ABI,Holomorphic} <: Mode{ABI} + struct ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity} Reverse mode differentiation. - `ReturnPrimal`: Should Enzyme return the primal return value from the augmented-forward. +- `RuntimeActivity`: Should Enzyme enable runtime activity (default off) - `ABI`: What runtime ABI to use - `Holomorphic`: Whether the complex result function is holomorphic and we should compute d/dz +- `ErrIfFuncWritten`: Should Enzyme err if the function differentiated is a closure and written to. 
""" -struct ReverseMode{ReturnPrimal,ABI,Holomorphic,ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten} end -const Reverse = ReverseMode{false,DefaultABI, false, false}() -const ReverseWithPrimal = ReverseMode{true,DefaultABI, false, false}() -const ReverseHolomorphic = ReverseMode{false,DefaultABI, true, false}() -const ReverseHolomorphicWithPrimal = ReverseMode{true,DefaultABI, true, false}() +struct ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten,RuntimeActivity} end +const Reverse = ReverseMode{false,false,DefaultABI, false, false}() +const ReverseWithPrimal = ReverseMode{true,false,DefaultABI, false, false}() +const ReverseHolomorphic = ReverseMode{false,false,DefaultABI, true, false}() +const ReverseHolomorphicWithPrimal = ReverseMode{true,false,DefaultABI, true, false}() -@inline set_err_if_func_written(::ReverseMode{ReturnPrimal,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,ABI,Holomorphic,true}() -@inline clear_err_if_func_written(::ReverseMode{ReturnPrimal,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,ABI,Holomorphic,false}() +@inline set_err_if_func_written(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,true}() +@inline clear_err_if_func_written(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,false}() -@inline set_abi(::ReverseMode{ReturnPrimal,OldABI,Holomorphic,ErrIfFuncWritten}, ::Type{NewABI}) where {ReturnPrimal,OldABI,Holomorphic,ErrIfFuncWritten,NewABI<:ABI} = ReverseMode{ReturnPrimal,NewABI,Holomorphic,ErrIfFuncWritten}() +@inline 
set_abi(::ReverseMode{ReturnPrimal,RuntimeActivity,OldABI,Holomorphic,ErrIfFuncWritten}, ::Type{NewABI}) where {ReturnPrimal,RuntimeActivity,OldABI,Holomorphic,ErrIfFuncWritten,NewABI<:ABI} = ReverseMode{ReturnPrimal,RuntimeActivity,NewABI,Holomorphic,ErrIfFuncWritten}() + +@inline set_runtime_activity(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,true,ABI,Holomorphic,ErrIfFuncWritten}() +@inline set_runtime_activity(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, rt::Bool) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,rt,ABI,Holomorphic,ErrIfFuncWritten}() +@inline clear_runtime_activity(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,false,ABI,Holomorphic,ErrIfFuncWritten}() """ - struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI} <: Mode{ABI} + struct ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI} <: Mode{ABI,ErrIfFuncWritten,RuntimeActivity} Reverse mode differentiation. - `ReturnPrimal`: Should Enzyme return the primal return value from the augmented-forward. - `ReturnShadow`: Should Enzyme return the shadow return value from the augmented-forward. +- `RuntimeActivity`: Should Enzyme differentiate with runtime activity on (default off). - `Width`: Batch Size (0 if to be automatically derived) - `ModifiedBetween`: Tuple of each argument's modified between state (true if to be automatically derived). 
""" -struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten} end -const ReverseSplitNoPrimal = ReverseModeSplit{false, true, 0, true,DefaultABI, false}() -const ReverseSplitWithPrimal = ReverseModeSplit{true, true, 0, true,DefaultABI, false}() -@inline ReverseSplitModified(::ReverseModeSplit{ReturnPrimal, ReturnShadow, Width, MBO, ABI, ErrIfFuncWritten}, ::Val{MB}) where {ReturnPrimal,ReturnShadow,Width,MB,MBO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,MB,ABI, ErrIfFuncWritten}() -@inline ReverseSplitWidth(::ReverseModeSplit{ReturnPrimal, ReturnShadow, WidthO, MB, ABI, ErrIfFuncWritten}, ::Val{Width}) where {ReturnPrimal,ReturnShadow,Width,MB,WidthO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,MB,ABI, ErrIfFuncWritten}() +struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,RuntimeActivity,ModifiedBetween,ABI, ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten,RuntimeActivity} end +const ReverseSplitNoPrimal = ReverseModeSplit{false, true, false, 0, true,DefaultABI, false}() +const ReverseSplitWithPrimal = ReverseModeSplit{true, true, false, 0, true,DefaultABI, false}() +@inline ReverseSplitModified(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, Width, MBO, ABI, ErrIfFuncWritten}, ::Val{MB}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,MBO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,MB,ABI, ErrIfFuncWritten}() +@inline ReverseSplitWidth(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, WidthO, MB, ABI, ErrIfFuncWritten}, ::Val{Width}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,WidthO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,ABI, ErrIfFuncWritten}() + +@inline set_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, 
ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, true}() +@inline clear_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, false}() -@inline set_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, true}() -@inline clear_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, false}() +@inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,true,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() +@inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}, rt::Bool) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,rt,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() +@inline clear_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = 
ReverseModeSplit{ReturnPrimal,ReturnShadow,false,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() """ - struct Forward <: Mode + struct Forward{ABI, ErrIfFuncWritten,RuntimeActivity} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity} Forward mode differentiation """ -struct ForwardMode{ABI, ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten} +struct ForwardMode{ABI, ErrIfFuncWritten,RuntimeActivity} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity} end -const Forward = ForwardMode{DefaultABI, false}() +const Forward = ForwardMode{DefaultABI, false, false}() +@inline set_err_if_func_written(::ForwardMode{ABI,ErrIfFuncWritten,RuntimeActivity}) where {ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ABI,true,RuntimeActivity}() +@inline clear_err_if_func_written(::ForwardMode{ABI,ErrIfFuncWritten,RuntimeActivity}) where {ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ABI,false,RuntimeActivity}() -@inline set_err_if_func_written(::ForwardMode{ABI,ErrIfFuncWritten}) where {ABI,ErrIfFuncWritten} = ForwardMode{ABI,true}() -@inline clear_err_if_func_written(::ForwardMode{ABI,ErrIfFuncWritten}) where {ABI,ErrIfFuncWritten} = ForwardMode{ABI,false}() +@inline set_abi(::ForwardMode{OldABI,ErrIfFuncWritten,RuntimeActivity}, ::Type{NewABI}) where {OldABI,ErrIfFuncWritten,RuntimeActivity,NewABI<:ABI} = ForwardMode{NewABI,ErrIfFuncWritten,RuntimeActivity}() -@inline set_abi(::ForwardMode{OldABI,ErrIfFuncWritten}, ::Type{NewABI}) where {OldABI,ErrIfFuncWritten,NewABI<:ABI} = ForwardMode{NewABI,ErrIfFuncWritten}() +@inline set_runtime_activity(::ForwardMode{ABI,ErrIfFuncWritten,RuntimeActivity}) where {ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ABI,ErrIfFuncWritten,true}() +@inline set_runtime_activity(::ForwardMode{ABI,ErrIfFuncWritten,RuntimeActivity}, rt::Bool) where {ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ABI,ErrIfFuncWritten,rt}() +@inline clear_runtime_activity(::ForwardMode{ABI,ErrIfFuncWritten,RuntimeActivity}) where {ABI,ErrIfFuncWritten,RuntimeActivity} 
= ForwardMode{ABI,ErrIfFuncWritten,false}() function autodiff end function autodiff_deferred end diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl index 398c7900879..27b14619e36 100644 --- a/lib/EnzymeCore/src/rules.jl +++ b/lib/EnzymeCore/src/rules.jl @@ -1,42 +1,77 @@ module EnzymeRules -import EnzymeCore: Annotation, Const, Duplicated -export Config, ConfigWidth, AugmentedReturn -export needs_primal, needs_shadow, width, overwritten +import EnzymeCore +import EnzymeCore: Annotation, Const, Duplicated, Mode +export RevConfig, RevConfigWidth +export FwdConfig, FwdConfigWidth +export AugmentedReturn +export needs_primal, needs_shadow, width, overwritten, runtime_activity export primal_type, shadow_type, tape_type import Base: unwrapva, isvarargtype, unwrap_unionall, rewrap_unionall """ - forward(func::Annotation{typeof(f)}, RT::Type{<:Annotation}, args::Annotation...) + forward(fwdconfig, func::Annotation{typeof(f)}, RT::Type{<:Annotation}, args::Annotation...) -Calculate the forward derivative. The first argument `func` is the callable -for which the rule applies to. Either wrapped in a [`Const`](@ref)), or -a [`Duplicated`](@ref) if it is a closure. -The second argument is the return type annotation, and all other arguments are -the annotated function arguments. +Calculate the forward derivative. The first argument is a [`FwdConfig`](@ref) object +describing parameters of the differentiation. +The second argument `func` is the callable for which the rule applies to. +Either wrapped in a [`Const`](@ref), or a [`Duplicated`](@ref) if it is a closure. +The third argument is the return type annotation, and all other arguments are the annotated function arguments.
""" function forward end """ - Config{NeedsPrimal, NeedsShadow, Width, Overwritten} - ConfigWidth{Width} = Config{<:Any,<:Any, Width} + FwdConfig{Width, RuntimeActivity} + FwdConfigWidth{Width} = FwdConfig{Width} + +Configuration type to dispatch on in custom forward rules (see [`forward`](@ref)). +* `Width`: an integer that specifies the number of adjoints/shadows simultaneously being propagated. +* `RuntimeActivity`: whether runtime activity is enabled. + +Getters for the type parameters are provided by `width` and `runtime_activity`. +""" +struct FwdConfig{Width, RuntimeActivity} end +const FwdConfigWidth{Width} = FwdConfig{Width} +@inline width(::FwdConfig{Width}) where Width = Width +@inline runtime_activity(::FwdConfig{<:Any, RuntimeActivity}) where RuntimeActivity = RuntimeActivity + + +""" + RevConfig{NeedsPrimal, NeedsShadow, Width, Overwritten, RuntimeActivity} + RevConfigWidth{Width} = RevConfig{<:Any,<:Any, Width} Configuration type to dispatch on in custom reverse rules (see [`augmented_primal`](@ref) and [`reverse`](@ref)). * `NeedsPrimal` and `NeedsShadow`: boolean values specifying whether the primal and shadow (resp.) should be returned. * `Width`: an integer that specifies the number of adjoints/shadows simultaneously being propagated. * `Overwritten`: a tuple of booleans of whether each argument (including the function itself) is modified between the forward and reverse pass (true if potentially modified between). +* `RuntimeActivity`: whether runtime activity is enabled. + +Getters for the type parameters are provided by `needs_primal`, `needs_shadow`, `width`, `overwritten`, and `runtime_activity`.
+""" +struct RevConfig{NeedsPrimal, NeedsShadow, Width, Overwritten, RuntimeActivity} end +const RevConfigWidth{Width} = RevConfig{<:Any,<:Any, Width} + +@inline needs_primal(::RevConfig{NeedsPrimal}) where NeedsPrimal = NeedsPrimal +@inline needs_shadow(::RevConfig{<:Any, NeedsShadow}) where NeedsShadow = NeedsShadow +@inline width(::RevConfig{<:Any, <:Any, Width}) where Width = Width +@inline overwritten(::RevConfig{<:Any, <:Any, <:Any, Overwritten}) where Overwritten = Overwritten +@inline runtime_activity(::RevConfig{<:Any, <:Any, <:Any, <:Any, RuntimeActivity}) where RuntimeActivity = RuntimeActivity + +""" + primal_type(::RevConfig, ::Type{<:Annotation{RT}}) -Getters for the four type parameters are provided by `needs_primal`, `needs_shadow`, `width`, and `overwritten`. +Compute the expected primal return type given a reverse mode config and return activity """ -struct Config{NeedsPrimal, NeedsShadow, Width, Overwritten} end -const ConfigWidth{Width} = Config{<:Any,<:Any, Width} +@inline primal_type(config::RevConfig, ::Type{<:Annotation{RT}}) where RT = needs_primal(config) ? RT : Nothing -@inline needs_primal(::Config{NeedsPrimal}) where NeedsPrimal = NeedsPrimal -@inline needs_shadow(::Config{<:Any, NeedsShadow}) where NeedsShadow = NeedsShadow -@inline width(::Config{<:Any, <:Any, Width}) where Width = Width -@inline overwritten(::Config{<:Any, <:Any, <:Any, Overwritten}) where Overwritten = Overwritten +""" + shadow_type(::RevConfig, ::Type{<:Annotation{RT}}) + +Compute the expected shadow return type given a reverse mode config and return activity +""" +@inline shadow_type(config::RevConfig, ::Type{<:Annotation{RT}}) where RT = needs_shadow(config) ? (width(config) == 1 ?
RT : NTuple{width(config), RT}) : Nothing """ AugmentedReturn(primal, shadow, tape) @@ -73,7 +108,7 @@ end @inline tape_type(::Type{AugmentedReturnFlexShadow{PrimalType,ShadowType,TapeType}}) where {PrimalType,ShadowType,TapeType} = TapeType @inline tape_type(::AugmentedReturnFlexShadow{PrimalType,ShadowType,TapeType}) where {PrimalType,ShadowType,TapeType} = TapeType """ - augmented_primal(::Config, func::Annotation{typeof(f)}, RT::Type{<:Annotation}, args::Annotation...) + augmented_primal(::RevConfig, func::Annotation{typeof(f)}, RT::Type{<:Annotation}, args::Annotation...) Must return an [`AugmentedReturn`](@ref) type. * The primal must be the same type of the original return if `needs_primal(config)`, otherwise nothing. @@ -84,8 +119,8 @@ Must return an [`AugmentedReturn`](@ref) type. function augmented_primal end """ - reverse(::Config, func::Annotation{typeof(f)}, dret::Active, tape, args::Annotation...) - reverse(::Config, func::Annotation{typeof(f)}, ::Type{<:Annotation), tape, args::Annotation...) + reverse(::RevConfig, func::Annotation{typeof(f)}, dret::Active, tape, args::Annotation...) + reverse(::RevConfig, func::Annotation{typeof(f)}, ::Type{<:Annotation), tape, args::Annotation...) Takes gradient of derivative, activity annotation, and tape. If there is an active return dret is passed as Active{T} with the derivative of the active return val. Otherwise dret is passed as Type{Duplicated{T}}, etc. 
@@ -117,7 +152,7 @@ function has_frule_from_sig(@nospecialize(TT); method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, caller::Union{Nothing,Core.MethodInstance}=nothing) ft, tt = _annotate_tt(TT) - TT = Tuple{<:Annotation{ft}, Type{<:Annotation}, tt...} + TT = Tuple{<:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...} return isapplicable(forward, TT; world, method_table, caller) end @@ -126,7 +161,7 @@ function has_rrule_from_sig(@nospecialize(TT); method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, caller::Union{Nothing,Core.MethodInstance}=nothing) ft, tt = _annotate_tt(TT) - TT = Tuple{<:Config, <:Annotation{ft}, Type{<:Annotation}, tt...} + TT = Tuple{<:RevConfig, <:Annotation{ft}, Type{<:Annotation}, tt...} return isapplicable(augmented_primal, TT; world, method_table, caller) end @@ -241,4 +276,6 @@ Mark a particular type `Ty` as always being inactive. """ inactive_type(::Type) = false +@inline EnzymeCore.set_runtime_activity(mode::M, config::Config) where {M<:Mode, Config <: Union{FwdConfig, RevConfig}} = EnzymeCore.set_runtime_activity(mode, runtime_activity(config)) + end # EnzymeRules diff --git a/lib/EnzymeCore/test/runtests.jl b/lib/EnzymeCore/test/runtests.jl index 9b76ebf56b5..d85d4dea156 100644 --- a/lib/EnzymeCore/test/runtests.jl +++ b/lib/EnzymeCore/test/runtests.jl @@ -4,7 +4,7 @@ using EnzymeCore import EnzymeCore.EnzymeRules: forward, has_frule_from_sig g(x) = x ^ 2 -function forward(::Const{typeof(g)}, ::Type{<:Const}, x::Const) +function forward(config, ::Const{typeof(g)}, ::Type{<:Const}, x::Const) return Const(g(x.val)) end @@ -12,7 +12,7 @@ end f(;kwargs) = 1.0 -function forward(::Const{typeof(f)}, ::Type{<:Const}; kwargs...) +function forward(config, ::Const{typeof(f)}, ::Type{<:Const}; kwargs...) 
return Const(f(; kwargs...)) end diff --git a/lib/EnzymeTestUtils/Project.toml b/lib/EnzymeTestUtils/Project.toml index 05e5e6b94e3..72684a97812 100644 --- a/lib/EnzymeTestUtils/Project.toml +++ b/lib/EnzymeTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeTestUtils" uuid = "12d8515a-0907-448a-8884-5fe00fdf1c5a" authors = ["Seth Axen ", "William Moses ", "Valentin Churavy "] -version = "0.1.8" +version = "0.2.0" [deps] ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" @@ -13,12 +13,12 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] ConstructionBase = "1.4.1" -Enzyme = "0.11, 0.12, 0.13" -EnzymeCore = "0.5, 0.6, 0.7" +Enzyme = "0.13" +EnzymeCore = "0.5, 0.6, 0.7, 0.8" FiniteDifferences = "0.12.12" MetaTesting = "0.1" Quaternions = "0.7" -julia = "1.6" +julia = "1.10" [extras] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/lib/EnzymeTestUtils/src/test_forward.jl b/lib/EnzymeTestUtils/src/test_forward.jl index e57a5c7e34b..fcfc987cb9a 100644 --- a/lib/EnzymeTestUtils/src/test_forward.jl +++ b/lib/EnzymeTestUtils/src/test_forward.jl @@ -61,6 +61,7 @@ function test_forward( rtol::Real=1e-9, atol::Real=1e-9, testset_name=nothing, + runtime_activity::Bool=false ) call_with_copy(f, xs...) = deepcopy(f)(deepcopy(xs)...; deepcopy(fkwargs)...) call_with_kwargs(f, xs...) = f(xs...; fkwargs...) @@ -78,7 +79,7 @@ function test_forward( # call finitedifferences, avoid mutating original arguments dy_fdm = _fd_forward(fdm, call_with_copy, ret_activity, y, activities) # call autodiff, allow mutating original arguments - y_and_dy_ad = autodiff(Forward, call_with_kwargs, ret_activity, activities...) + y_and_dy_ad = autodiff(set_runtime_activity(Forward, runtime_activity), call_with_kwargs, ret_activity, activities...) 
if ret_activity <: Union{Duplicated,BatchDuplicated} @test_msg( "For return type $ret_activity the return value and derivative must be returned", diff --git a/lib/EnzymeTestUtils/src/test_reverse.jl b/lib/EnzymeTestUtils/src/test_reverse.jl index f204b00a7b5..6c20aebb7aa 100644 --- a/lib/EnzymeTestUtils/src/test_reverse.jl +++ b/lib/EnzymeTestUtils/src/test_reverse.jl @@ -81,6 +81,7 @@ function test_reverse( rtol::Real=1e-9, atol::Real=1e-9, testset_name=nothing, + runtime_activity::Bool=false ) call_with_captured_kwargs(f, xs...) = f(xs...; fkwargs...) if testset_name === nothing @@ -108,8 +109,9 @@ function test_reverse( dx_fdm = _fd_reverse(fdm, call_with_captured_kwargs, ȳ, activities, !(ret_activity <: Const)) # call autodiff, allow mutating original arguments c_act = Const(call_with_kwargs) + mode = set_runtime_activity(ReverseSplitWithPrimal, runtime_activity) forward, reverse = autodiff_thunk( - ReverseSplitWithPrimal, typeof(c_act), ret_activity, typeof(Const(fkwargs)), map(typeof, activities)... + mode, typeof(c_act), ret_activity, typeof(Const(fkwargs)), map(typeof, activities)... ) tape, y_ad, shadow_result = forward(c_act, Const(fkwargs), activities...) 
test_approx( diff --git a/lib/EnzymeTestUtils/test/test_forward.jl b/lib/EnzymeTestUtils/test/test_forward.jl index 7f870af7bf8..57385a1dd98 100644 --- a/lib/EnzymeTestUtils/test/test_forward.jl +++ b/lib/EnzymeTestUtils/test/test_forward.jl @@ -178,7 +178,6 @@ end end @testset "mutating function" begin - Enzyme.API.runtimeActivity!(true) sz = (2, 3) @testset for Tret in (Const, Duplicated, BatchDuplicated), Tx in (Const, Duplicated, BatchDuplicated), @@ -196,10 +195,9 @@ end atol = rtol = sqrt(eps(real(T))) @test !fails() do - test_forward(f_mut_fwd!, Tret, (y, Ty), (x, Tx), (a, Ta); atol, rtol) + test_forward(f_mut_fwd!, Tret, (y, Ty), (x, Tx), (a, Ta); atol, rtol, runtime_activity=true) end skip = (VERSION < v"1.8" && T <: Complex) end - Enzyme.API.runtimeActivity!(false) end @testset "incorrect mutated argument detected" begin diff --git a/src/Enzyme.jl b/src/Enzyme.jl index a5b949c60a1..fcc12d57a82 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -5,8 +5,8 @@ import EnzymeCore import EnzymeCore: Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal export Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal -import EnzymeCore: Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, ABI, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi -export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi +import EnzymeCore: Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, ABI, DefaultABI, FFIABI, 
InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi, set_runtime_activity, clear_runtime_activity +export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi, set_runtime_activity, clear_runtime_activity import EnzymeCore: BatchDuplicatedFunc export BatchDuplicatedFunc @@ -229,7 +229,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) [`Active`](@ref) will automatically convert plain integers to floating point values, but cannot do so for integer values in tuples and structs. """ -@inline function autodiff(rmode::ReverseMode{ReturnPrimal, RABI,Holomorphic, ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, RABI<:ABI,Holomorphic, Nargs, ErrIfFuncWritten} +@inline function autodiff(rmode::ReverseMode{ReturnPrimal, RuntimeActivity,RABI,Holomorphic, ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, RuntimeActivity, RABI<:ABI,Holomorphic, Nargs, ErrIfFuncWritten} tt′ = vaTypeof(args...) width = same_or_one(1, args...) if width == 0 @@ -256,7 +256,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) if A <: Active if (!allocatedinline(rt) || rt isa Union) && rt != Union{} - forward, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true), RABI, Val(ErrIfFuncWritten)) + forward, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) res = forward(f, args...) 
tape = res[1] if ReturnPrimal @@ -286,7 +286,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) args = seed_complex_args(seen, seen2, args...) tt′ = vaTypeof(args...) - thunk = Enzyme.Compiler.thunk(opt_mi, typeof(f), A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) + thunk = Enzyme.Compiler.thunk(opt_mi, typeof(f), A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) results = thunk(f, args..., (rt(0), rt(1), rt(im))) @@ -308,7 +308,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) throw(ErrorException("Reverse-mode Active Complex return is ambiguous and requires more information to specify the desired result. See https://enzyme.mit.edu/julia/stable/faq/#Complex-numbers for more details.")) end - thunk = Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) + thunk = Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) if A <: Active args = (args..., Compiler.default_adjoint(rt)) @@ -389,7 +389,7 @@ f(x) = x*x (6.28,) ``` """ -@inline function autodiff(::ForwardMode{RABI, ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation} where {RABI <: ABI, Nargs, ErrIfFuncWritten} +@inline function autodiff(::ForwardMode{RABI, ErrIfFuncWritten, RuntimeActivity}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation} where {RABI <: ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} if any_active(args...) 
throw(ErrorException("Active arguments not allowed in forward mode")) end @@ -429,7 +429,7 @@ f(x) = x*x end thunk = Enzyme.Compiler.thunk(opt_mi, FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), - ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) + ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) thunk(f, args...) end @@ -439,7 +439,7 @@ end Same as [`autodiff`](@ref) but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. """ -@inline function autodiff_deferred(::ReverseMode{ReturnPrimal, ABI,Holomorphic,ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, Nargs, ABI,Holomorphic,ErrIfFuncWritten} +@inline function autodiff_deferred(::ReverseMode{ReturnPrimal, RuntimeActivity, ABI,Holomorphic,ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, Nargs, ABI,Holomorphic,ErrIfFuncWritten, RuntimeActivity} tt′ = vaTypeof(args...) width = same_or_one(1, args...) if width == 0 @@ -463,7 +463,7 @@ code, as well as high-order differentiation. 
ModifiedBetween = Val(falses_from_args(Nargs+1)) - adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), UnknownTapeType, Val(ErrIfFuncWritten)) + adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), UnknownTapeType, Val(ErrIfFuncWritten), Val(RuntimeActivity)) thunk = Compiler.CombinedAdjointThunk{Ptr{Cvoid}, FA, rt, tt′, width, ReturnPrimal}(adjoint_ptr) if rt <: Active @@ -480,7 +480,7 @@ end Same as `autodiff(::ForwardMode, f, Activity, args)` but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. """ -@inline function autodiff_deferred(::ForwardMode{ABI, ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, Nargs, ABI, ErrIfFuncWritten} +@inline function autodiff_deferred(::ForwardMode{ABI, ErrIfFuncWritten, RuntimeActivity}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, Nargs, ABI, ErrIfFuncWritten, RuntimeActivity} if any_active(args...) throw(ErrorException("Active arguments not allowed in forward mode")) end @@ -527,7 +527,7 @@ code, as well as high-order differentiation. 
ReturnPrimal = RT <: Duplicated || RT <: BatchDuplicated ModifiedBetween = Val(falses_from_args(Nargs+1)) - adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), UnknownTapeType, Val(ErrIfFuncWritten)) + adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), UnknownTapeType, Val(ErrIfFuncWritten), Val(RuntimeActivity)) thunk = Compiler.ForwardModeThunk{Ptr{Cvoid}, FA, rt, tt′, width, ReturnPrimal}(adjoint_ptr) thunk(f, args...) end @@ -608,7 +608,7 @@ result, ∂v, ∂A (7.26, 2.2, [3.3]) ``` """ -@inline function autodiff_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI, ErrIfFuncWritten}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI<:ABI, Nargs, ErrIfFuncWritten} +@inline function autodiff_thunk(rs::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,ModifiedBetweenT,RABI, ErrIfFuncWritten}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI<:ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} width = if Width == 0 w = same_or_one(1, args...) 
if w == 0 @@ -636,7 +636,7 @@ result, ∂v, ∂A else Val(codegen_world_age(eltype(FA), tt)) end - Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) + Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) end """ @@ -683,7 +683,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated (6.28,) ``` """ -@inline function autodiff_thunk(::ForwardMode{RABI, ErrIfFuncWritten}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, RABI<:ABI, Nargs, ErrIfFuncWritten} +@inline function autodiff_thunk(::ForwardMode{RABI, ErrIfFuncWritten, RuntimeActivity}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, RABI<:ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} width = same_or_one(1, A, args...) 
if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) @@ -702,10 +702,10 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated else Val(codegen_world_age(eltype(FA), tt)) end - Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) + Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) end -@inline function tape_type(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI, ErrIfFuncWritten}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs, ErrIfFuncWritten} +@inline function tape_type(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,ModifiedBetweenT, RABI, ErrIfFuncWritten}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} width = if Width == 0 w = same_or_one(1, args...) 
if w == 0 @@ -731,7 +731,7 @@ end else Val(codegen_world_age(eltype(FA), primal_tt)) end - nondef = Enzyme.Compiler.thunk(opt_mi, FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) + nondef = Enzyme.Compiler.thunk(opt_mi, FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) if nondef[1] isa Enzyme.Compiler.PrimalErrorThunk return Nothing else @@ -747,9 +747,9 @@ const tape_cache_lock = ReentrantLock() import .Compiler: fspec, remove_innerty, UnknownTapeType @inline function tape_type( - parent_job::Union{GPUCompiler.CompilerJob,Nothing}, ::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, + parent_job::Union{GPUCompiler.CompilerJob,Nothing}, ::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,ModifiedBetweenT, RABI}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs} -) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs} +) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs, RuntimeActivity} width = if Width == 0 w = same_or_one(1, args...) 
if w == 0 @@ -779,7 +779,8 @@ import .Compiler: fspec, remove_innerty, UnknownTapeType params = Compiler.EnzymeCompilerParams( Tuple{FA, TT.parameters...}, API.DEM_ReverseModeGradient, width, Compiler.remove_innerty(A), true, #=abiwrap=#false, ModifiedBetweenT, - ReturnPrimal, #=ShadowInit=#false, Compiler.UnknownTapeType, RABI, #=errifwritte=#false + ReturnPrimal, #=ShadowInit=#false, Compiler.UnknownTapeType, RABI, #=errifwritte=#false, + RuntimeActivity ) job = Compiler.CompilerJob(mi, Compiler.CompilerConfig(target, params; kernel=false)) @@ -849,7 +850,7 @@ result, ∂v, ∂A (7.26, 2.2, [3.3]) ``` """ -@inline function autodiff_deferred_thunk(mode::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI, ErrIfFuncWritten}, tt::Type{TapeType}, fa::Type{FA}, a2::Type{A2}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A2<:Annotation, TapeType, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs, ErrIfFuncWritten} +@inline function autodiff_deferred_thunk(mode::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,ModifiedBetweenT, RABI, ErrIfFuncWritten}, tt::Type{TapeType}, fa::Type{FA}, a2::Type{A2}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A2<:Annotation, TapeType, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} @assert RABI == FFIABI width = if Width == 0 w = same_or_one(1, args...) 
@@ -873,8 +874,8 @@ result, ∂v, ∂A primal_tt = Tuple{map(eltype, args)...} world = codegen_world_age(eltype(FA), primal_tt) - primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(Compiler.remove_innerty(A2)), Val(API.DEM_ReverseModePrimal), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType, Val(ErrIfFuncWritten)) - adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(Compiler.remove_innerty(A2)), Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType, Val(ErrIfFuncWritten)) + primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(Compiler.remove_innerty(A2)), Val(API.DEM_ReverseModePrimal), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(Compiler.remove_innerty(A2)), Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType, Val(ErrIfFuncWritten), Val(RuntimeActivity)) RT = if A2 <: Duplicated && width != 1 if A2 isa UnionAll @@ -1031,15 +1032,23 @@ grad = gradient(Reverse, only ∘ f, (a = 2.0, b = [3.0], c = "str")) (a = 3.0, b = [2.0], c = "str") ``` """ -@inline function gradient(rm::ReverseMode, f::F, x::X) where {F, X} +@inline function gradient(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, f::F, x::X) where {F, X, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten} if Compiler.active_reg_inner(X, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState dx = Ref(make_zero(x)) - autodiff(rm, f, Active, MixedDuplicated(x, dx)) - return only(dx) + res = autodiff(rm, f, Active, MixedDuplicated(x, dx)) + if ReturnPrimal + (res[2], only(dx)) + else + only(dx) + end else dx = make_zero(x) - autodiff(rm, f, Active, Duplicated(x, dx)) - return dx + res = 
autodiff(rm, f, Active, Duplicated(x, dx)) + if ReturnPrimal + (res[2], dx) + else + dx + end end end @@ -1048,15 +1057,23 @@ end Like [`gradient`](@ref), except it using deferred mode. """ -@inline function gradient_deferred(rm::ReverseMode, f::F, x::X) where {F, X} +@inline function gradient_deferred(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, f::F, x::X) where {F, X, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten} if Compiler.active_reg_inner(X, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState dx = Ref(make_zero(x)) - autodiff_deferred(rm, f, Active, MixedDuplicated(x, dx)) - return only(dx) + res = autodiff_deferred(rm, f, Active, MixedDuplicated(x, dx)) + if ReturnPrimal + return (res[2], only(dx)) + else + return only(dx) + end else dx = make_zero(x) - autodiff_deferred(rm, f, Active, Duplicated(x, dx)) - return dx + res = autodiff_deferred(rm, f, Active, Duplicated(x, dx)) + if ReturnPrimal + (res[2], dx) + else + dx + end end end @@ -1082,10 +1099,14 @@ gradient!(Reverse, dx, f, [2.0, 3.0]) 2.0 ``` """ -@inline function gradient!(::ReverseMode, dx::X, f::F, x::X) where {X<:Array, F} +@inline function gradient!(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, dx::X, f::F, x::X) where {X<:Array, F, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten} make_zero!(dx) - autodiff(Reverse, f, Active, Duplicated(x, dx)) - dx + res = autodiff(rm, f, Active, Duplicated(x, dx)) + return if ReturnPrimal + (res[2], dx) + else + dx + end end @@ -1094,10 +1115,14 @@ end Like [`gradient!`](@ref), except it using deferred mode. 
""" -@inline function gradient_deferred!(::ReverseMode, dx::X, f::F, x::X) where {X<:Array, F} +@inline function gradient_deferred!(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, dx::X, f::F, x::X) where {X<:Array, F, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten} make_zero!(dx) - autodiff_deferred(Reverse, f, Active, Duplicated(x, dx)) - dx + res = autodiff_deferred(rm, f, Active, Duplicated(x, dx)) + return if ReturnPrimal + (res[2], dx) + else + dx + end end """ @@ -1121,11 +1146,11 @@ grad = gradient(Forward, f, [2.0, 3.0]) (3.0, 2.0) ``` """ -@inline function gradient(::ForwardMode, f, x; shadow=onehot(x)) +@inline function gradient(fm::ForwardMode, f, x; shadow=onehot(x)) if length(shadow) == 0 return () end - res = values(only(autodiff(Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow)))) + res = values(only(autodiff(fm, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow)))) if x isa AbstractFloat res[1] else @@ -1169,12 +1194,12 @@ grad = gradient(Forward, f, [2.0, 3.0], Val(2)) (3.0, 2.0) ``` """ -@inline function gradient(::ForwardMode, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X, chunk} +@inline function gradient(fm::ForwardMode, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X, chunk} if chunk == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end tmp = ntuple(length(shadow)) do i - values(autodiff(Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow[i]))[1]) + values(autodiff(fm, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow[i]))[1]) end res = tupleconcat(tmp...) 
if x isa AbstractFloat @@ -1184,9 +1209,9 @@ grad = gradient(Forward, f, [2.0, 3.0], Val(2)) end end -@inline function gradient(::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F, X} +@inline function gradient(fm::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F, X} res = ntuple(length(shadow)) do i - autodiff(Forward, f, DuplicatedNoNeed, Duplicated(x, shadow[i]))[1] + autodiff(fm, f, DuplicatedNoNeed, Duplicated(x, shadow[i]))[1] end if x isa AbstractFloat res[1] @@ -1223,11 +1248,11 @@ whose shape is `(size(output)..., size(input)...)` For functions who return other types, this function will retun an array or tuple of shape `size(input)` of values of the output type. """ -@inline function jacobian(::ForwardMode, f, x; shadow=onehot(x)) +@inline function jacobian(fm::ForwardMode, f, x; shadow=onehot(x)) cols = if length(shadow) == 0 () else - values(only(autodiff(Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow)))) + values(only(autodiff(fm, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow)))) end if x isa AbstractFloat cols[1] @@ -1252,13 +1277,13 @@ of shape `size(input)` of values of the output type. end end -@inline function jacobian(::ForwardMode, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X, chunk} +@inline function jacobian(fm::ForwardMode, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X, chunk} if chunk == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end tmp = ntuple(length(shadow)) do i Base.@_inline_meta - values(autodiff(Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow[i]))[1]) + values(autodiff(fm, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow[i]))[1]) end cols = tupleconcat(tmp...) 
if x isa AbstractFloat @@ -1284,10 +1309,10 @@ end end end -@inline function jacobian(::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F,X} +@inline function jacobian(fm::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F,X} cols = ntuple(length(shadow)) do i Base.@_inline_meta - autodiff(Forward, f, DuplicatedNoNeed, Duplicated(x, shadow[i]))[1] + autodiff(fm, f, DuplicatedNoNeed, Duplicated(x, shadow[i]))[1] end if x isa AbstractFloat cols[1] @@ -1341,7 +1366,7 @@ For functions who return other types, this function will retun an array or tuple of shape `size(output)` of values of the input type. ``` """ -@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RABI, #=Holomorphic=#false, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{chunk}) where {F, X, chunk, n_out_val, RABI<:ABI, ErrIfFuncWritten} +@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RuntimeActivity, RABI, #=Holomorphic=#false, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{chunk}) where {F, X, chunk, n_out_val, RABI<:ABI, ErrIfFuncWritten, RuntimeActivity} num = ((n_out_val + chunk - 1) ÷ chunk) if chunk == 0 @@ -1360,7 +1385,7 @@ of shape `size(output)` of values of the input type. else Val(codegen_world_age(Core.Typeof(f), tt)) end - primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(chunk), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) + primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(chunk), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) if num * chunk == n_out_val last_size = chunk @@ -1368,7 +1393,7 @@ of shape `size(output)` of values of the input type. 
else last_size = n_out_val - (num-1)*chunk tt′ = Tuple{BatchDuplicated{Core.Typeof(x), last_size}} - primal2, adjoint2 = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) + primal2, adjoint2 = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) end tmp = ntuple(num) do i @@ -1417,7 +1442,7 @@ of shape `size(output)` of values of the input type. end end -@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RABI, #=Holomorphic=#false, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{1} = Val(1)) where {F, X, n_out_val,RABI<:ABI, ErrIfFuncWritten} +@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RuntimeActivity,RABI, #=Holomorphic=#false, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{1} = Val(1)) where {F, X, n_out_val,RuntimeActivity,RABI<:ABI, ErrIfFuncWritten} XT = Core.Typeof(x) MD = Compiler.active_reg_inner(XT, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState tt′ = MD ? 
Tuple{MixedDuplicated{XT}} : Tuple{Duplicated{XT}} @@ -1430,7 +1455,7 @@ end else Val(codegen_world_age(Core.Typeof(f), tt)) end - primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten)) + primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) tmp = ntuple(n_outs) do i Base.@_inline_meta z = make_zero(x) @@ -1466,12 +1491,12 @@ end end end -@inline function jacobian(::ReverseMode{ReturnPrimal,RABI,Holomorphic,ErrIfFuncWritten}, f::F, x::X) where {ReturnPrimal, F, X, RABI<:ABI, Holomorphic, ErrIfFuncWritten} +@inline function jacobian(::ReverseMode{ReturnPrimal,RuntimeActivity, RABI, Holomorphic, ErrIfFuncWritten}, f::F, x::X) where {ReturnPrimal, F, X, RABI<:ABI, ErrIfFuncWritten, RuntimeActivity, Holomorphic} res = f(x) jac = if res isa AbstractArray - jacobian(ReverseMode{false,RABI, Holomorphic, ErrIfFuncWritten}(), f, x, Val(length(res))) + jacobian(ReverseMode{false,RuntimeActivity,RABI, Holomorphic, ErrIfFuncWritten}(), f, x, Val(length(res))) elseif res isa AbstractFloat - gradient(ReverseMode{false,RABI, Holomorphic, ErrIfFuncWritten}(), f, x) + gradient(ReverseMode{false,RuntimeActivity,RABI, Holomorphic, ErrIfFuncWritten}(), f, x) else throw(AssertionError("Unsupported return type of function for reverse-mode jacobian, $(Core.Typeof(res))")) end diff --git a/src/api.jl b/src/api.jl index 6de95beb55a..9e446dcf579 100644 --- a/src/api.jl +++ b/src/api.jl @@ -156,30 +156,30 @@ end # \p AtomicAdd is whether to perform all adjoint updates to memory in an atomic way # \p PostOpt is whether to perform basic optimization of the function after synthesis function 
EnzymeCreatePrimalAndGradient(logic, todiff, retType, constant_args, TA, - returnValue, dretUsed, mode, width, additionalArg, + returnValue, dretUsed, mode, runtimeActivity, width, additionalArg, forceAnonymousTape, typeInfo, uncacheable_args, augmented, atomicAdd) freeMemory = true ccall((:EnzymeCreatePrimalAndGradient, libEnzyme), LLVMValueRef, (EnzymeLogicRef, LLVMValueRef, LLVM.API.LLVMBuilderRef, LLVMValueRef, CDIFFE_TYPE, Ptr{CDIFFE_TYPE}, Csize_t, - EnzymeTypeAnalysisRef, UInt8, UInt8, CDerivativeMode, Cuint, UInt8, LLVMTypeRef, UInt8, CFnTypeInfo, + EnzymeTypeAnalysisRef, UInt8, UInt8, CDerivativeMode, UInt8, Cuint, UInt8, LLVMTypeRef, UInt8, CFnTypeInfo, Ptr{UInt8}, Csize_t, EnzymeAugmentedReturnPtr, UInt8), logic, C_NULL, C_NULL, todiff, retType, constant_args, length(constant_args), TA, returnValue, - dretUsed, mode, width, freeMemory, additionalArg, forceAnonymousTape, typeInfo, uncacheable_args, length(uncacheable_args), + dretUsed, mode, runtimeActivity, width, freeMemory, additionalArg, forceAnonymousTape, typeInfo, uncacheable_args, length(uncacheable_args), augmented, atomicAdd) end function EnzymeCreateForwardDiff(logic, todiff, retType, constant_args, TA, - returnValue, mode, width, additionalArg, typeInfo, + returnValue, mode, runtimeActivity, width, additionalArg, typeInfo, uncacheable_args) freeMemory = true aug = C_NULL ccall((:EnzymeCreateForwardDiff, libEnzyme), LLVMValueRef, (EnzymeLogicRef, LLVMValueRef, LLVM.API.LLVMBuilderRef, LLVMValueRef, CDIFFE_TYPE, Ptr{CDIFFE_TYPE}, Csize_t, - EnzymeTypeAnalysisRef, UInt8, CDerivativeMode, UInt8, Cuint, LLVMTypeRef, CFnTypeInfo, + EnzymeTypeAnalysisRef, UInt8, CDerivativeMode, UInt8, UInt8, Cuint, LLVMTypeRef, CFnTypeInfo, Ptr{UInt8}, Csize_t, EnzymeAugmentedReturnPtr), logic, C_NULL, C_NULL, todiff, retType, constant_args, length(constant_args), TA, returnValue, - mode, freeMemory, width, additionalArg, typeInfo, uncacheable_args, length(uncacheable_args), aug) + mode, freeMemory, runtimeActivity, 
width, additionalArg, typeInfo, uncacheable_args, length(uncacheable_args), aug) end # Create an augmented forward pass. @@ -195,14 +195,14 @@ end # \p PostOpt is whether to perform basic optimization of the function after synthesis function EnzymeCreateAugmentedPrimal(logic, todiff, retType, constant_args, TA, returnUsed, shadowReturnUsed, - typeInfo, uncacheable_args, forceAnonymousTape, width, atomicAdd) + typeInfo, uncacheable_args, forceAnonymousTape, runtimeActivity, width, atomicAdd) ccall((:EnzymeCreateAugmentedPrimal, libEnzyme), EnzymeAugmentedReturnPtr, (EnzymeLogicRef, LLVMValueRef, LLVM.API.LLVMBuilderRef, LLVMValueRef, CDIFFE_TYPE, Ptr{CDIFFE_TYPE}, Csize_t, EnzymeTypeAnalysisRef, UInt8, UInt8, - CFnTypeInfo, Ptr{UInt8}, Csize_t, UInt8, Cuint, UInt8), + CFnTypeInfo, Ptr{UInt8}, Csize_t, UInt8, UInt8, Cuint, UInt8), logic, C_NULL, C_NULL, todiff, retType, constant_args, length(constant_args), TA, returnUsed, shadowReturnUsed, - typeInfo, uncacheable_args, length(uncacheable_args), forceAnonymousTape, width, atomicAdd) + typeInfo, uncacheable_args, length(uncacheable_args), forceAnonymousTape, runtimeActivity, width, atomicAdd) end # typedef uint8_t (*CustomRuleType)(int /*direction*/, CTypeTreeRef /*return*/, @@ -252,6 +252,7 @@ EnzymeGradientUtilsErase(gutils, a) = ccall((:EnzymeGradientUtilsErase, libEnzym EnzymeGradientUtilsEraseWithPlaceholder(gutils, a, orig, erase) = ccall((:EnzymeGradientUtilsEraseWithPlaceholder, libEnzyme), Cvoid, (EnzymeGradientUtilsRef,LLVMValueRef, LLVMValueRef, UInt8), gutils, a, orig, erase) EnzymeGradientUtilsGetMode(gutils) = ccall((:EnzymeGradientUtilsGetMode, libEnzyme), CDerivativeMode, (EnzymeGradientUtilsRef,), gutils) EnzymeGradientUtilsGetWidth(gutils) = ccall((:EnzymeGradientUtilsGetWidth, libEnzyme), UInt64, (EnzymeGradientUtilsRef,), gutils) +EnzymeGradientUtilsGetRuntimeActivity(gutils) = ccall((:EnzymeGradientUtilsGetRuntimeActivity, libEnzyme), UInt8, (EnzymeGradientUtilsRef,), gutils) != 0 
EnzymeGradientUtilsNewFromOriginal(gutils, val) = ccall((:EnzymeGradientUtilsNewFromOriginal, libEnzyme), LLVMValueRef, (EnzymeGradientUtilsRef, LLVMValueRef), gutils, val) EnzymeGradientUtilsSetDebugLocFromOriginal(gutils, val, orig) = ccall((:EnzymeGradientUtilsSetDebugLocFromOriginal, libEnzyme), Cvoid, (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef), gutils, val, orig) EnzymeGradientUtilsLookup(gutils, val, B) = ccall((:EnzymeGradientUtilsLookup, libEnzyme), LLVMValueRef, (EnzymeGradientUtilsRef, LLVMValueRef, LLVM.API.LLVMBuilderRef), gutils, val, B) @@ -556,51 +557,6 @@ function strong_zero!(val) ccall((:EnzymeSetCLInteger, libEnzyme), Cvoid, (Ptr{Cvoid}, UInt8), ptr, val) end -""" - runtimeActivity!(val::Bool) - -Enzyme runs an activity analysis which deduces which values, instructions, etc -are necessary to be differentiated and therefore involved in the differentiation -procedure. This runs at compile time. However, there may be implementation flaws -in this analysis that means that Enzyme cannot deduce that an inactive (const) -value is actually const. Alternatively, there may be some data which is conditionally -active, depending on which runtime branch is taken. In these cases Enzyme conservatively -presumes the value is active. - -However, in certain cases, an insufficiently aggressive activity analysis may result -in derivative errors -- for example by mistakenly using the primal (const) argument -and mistaking it for the duplicated shadow. As a result this may result in incorrect -results, or accidental updates to the primal. - -This flag enables runntime activity which tells all load/stores to check at runtime -whether the value they are updating is indeed active (in addition to the compile-time -activity analysis). This will remedy these such errors, but at a performance penalty -of performing such checks. 
- -It is on the Enzyme roadmap to add a PotentiallyDuplicated style activity, in addition -to the current Const and Duplicated styles that will disable the need for this, -which does not require the check when a value is guaranteed active, but still supports -runtime-based activity information. - -This function takes an argument to set the runtime activity value, true means it is on, -and false means off. By default it is off. -""" -function runtimeActivity!(val::Bool) - ptr = cglobal((:EnzymeRuntimeActivityCheck, libEnzyme)) - ccall((:EnzymeSetCLInteger, libEnzyme), Cvoid, (Ptr{Cvoid}, UInt8), ptr, val) -end - -""" - runtimeActivity() - -Gets the current value of the runtime activity. See [`runtimeActivity!`](@ref) for -more information. -""" -function runtimeActivity() - ptr = cglobal((:EnzymeRuntimeActivityCheck, libEnzyme)) - return EnzymeGetCLBool(ptr) != 0 -end - """ typeWarning!(val::Bool) diff --git a/src/compiler.jl b/src/compiler.jl index 0d9e1471560..fb7ce8d8bf8 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -954,11 +954,11 @@ end function Base.showerror(io::IO, ece::EnzymeRuntimeActivityError) println(io, "Constant memory is stored (or returned) to a differentiable variable.") println(io, "As a result, Enzyme cannot provably ensure correctness and throws this error.") - println(io, "This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Activity-of-temporary-storage).") + println(io, "This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).") println(io, "If Enzyme should be able to prove this use non-differentable, open an issue!"); println(io, "To work around this issue, either:"); println(io, " a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or") - println(io, " b) set Enzyme.API.runtimeActivity!(true) immediately after loading 
Enzyme (which maintains correctness, but may slightly reduce performance).") + println(io, " b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.") msg = Base.unsafe_string(ece.msg) print(io, msg, '\n') end @@ -3324,6 +3324,9 @@ struct EnzymeCompilerParams <: AbstractEnzymeCompilerParams ABI::Type{<:ABI} # Whether to error if the function is written to err_if_func_written::Bool + + # Whether runtime activity is enabled + runtimeActivity::Bool end struct UnknownTapeType end @@ -3843,6 +3846,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr world = job.world interp = GPUCompiler.get_interpreter(job) rt = job.config.params.rt + runtimeActivity = job.config.params.runtimeActivity @assert eltype(rt) != Union{} shadow_init = job.config.params.shadowInit @@ -3960,7 +3964,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr augmented = API.EnzymeCreateAugmentedPrimal( logic, primalf, retType, args_activity, TA, #=returnUsed=# returnUsed, #=shadowReturnUsed=#shadowReturnUsed, - typeInfo, uncacheable_args, #=forceAnonymousTape=# false, width, #=atomicAdd=# parallel) + typeInfo, uncacheable_args, #=forceAnonymousTape=# false, runtimeActivity, width, #=atomicAdd=# parallel) # 2. 
get new_primalf and tape augmented_primalf = LLVM.Function(API.EnzymeExtractFunctionFromAugmentation(augmented)) @@ -3988,7 +3992,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr adjointf = LLVM.Function(API.EnzymeCreatePrimalAndGradient( logic, primalf, retType, args_activity, TA, - #=returnValue=#false, #=dretUsed=#false, #=mode=#API.DEM_ReverseModeGradient, width, + #=returnValue=#false, #=dretUsed=#false, #=mode=#API.DEM_ReverseModeGradient, runtimeActivity, width, #=additionalArg=#tape, #=forceAnonymousTape=#false, typeInfo, uncacheable_args, augmented, #=atomicAdd=# parallel)) if wrap @@ -3999,7 +4003,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr returnUsed &= returnPrimal adjointf = LLVM.Function(API.EnzymeCreatePrimalAndGradient( logic, primalf, retType, args_activity, TA, - #=returnValue=#returnUsed, #=dretUsed=#false, #=mode=#API.DEM_ReverseModeCombined, width, + #=returnValue=#returnUsed, #=dretUsed=#false, #=mode=#API.DEM_ReverseModeCombined, runtimeActivity, width, #=additionalArg=#C_NULL, #=forceAnonymousTape=#false, typeInfo, uncacheable_args, #=augmented=#C_NULL, #=atomicAdd=# parallel)) augmented_primalf = nothing @@ -4011,7 +4015,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr returnUsed &= returnPrimal adjointf = LLVM.Function(API.EnzymeCreateForwardDiff( logic, primalf, retType, args_activity, TA, - #=returnValue=#returnUsed, #=mode=#API.DEM_ForwardMode, width, + #=returnValue=#returnUsed, #=mode=#API.DEM_ForwardMode, runtimeActivity, width, #=additionalArg=#C_NULL, typeInfo, uncacheable_args)) augmented_primalf = nothing @@ -5495,7 +5499,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; ForwardModeTypes = ("s", "d", "c", "z") ReverseModeTypes = ("s", "d") # Tablegen BLAS does not support forward mode yet - if !(mode == API.DEM_ForwardMode && Enzyme.API.runtimeActivity()) + if !(mode == 
API.DEM_ForwardMode && params.runtimeActivity) for ty in (mode == API.DEM_ForwardMode ? ForwardModeTypes : ReverseModeTypes) for func in (mode == API.DEM_ForwardMode ? ForwardModeDerivatives : ReverseModeDerivatives) for prefix in ("", "cblas_") @@ -7124,9 +7128,9 @@ end @inline remove_innerty(::Type{<:MixedDuplicated}) = MixedDuplicated @inline remove_innerty(::Type{<:BatchMixedDuplicated}) = MixedDuplicated -@inline function thunkbase(ctx, mi::Core.MethodInstance, ::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI, ErrIfFuncWritten} +@inline function thunkbase(ctx, mi::Core.MethodInstance, ::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}, ::Val{RuntimeActivity}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI, ErrIfFuncWritten, RuntimeActivity} target = Compiler.EnzymeTarget() - params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI, ErrIfFuncWritten) + params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI, ErrIfFuncWritten, RuntimeActivity) tmp_job = if World isa Nothing Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false)) else @@ -7165,7 +7169,7 @@ end A2 end - params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, run_enzyme, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI, ErrIfFuncWritten) + params 
= Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, run_enzyme, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI, ErrIfFuncWritten, RuntimeActivity) job = if World isa Nothing Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false)) else @@ -7202,25 +7206,25 @@ end end end -@inline function thunk(mi::Core.MethodInstance, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, ABI, ErrIfFuncWritten} +@inline function thunk(mi::Core.MethodInstance, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}, ::Val{RuntimeActivity}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, ABI, ErrIfFuncWritten, RuntimeActivity} ts_ctx = JuliaContext() ctx = context(ts_ctx) activate(ctx) try - return thunkbase(ctx, mi, Val(#=World=#nothing), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI, Val(ErrIfFuncWritten)) + return thunkbase(ctx, mi, Val(#=World=#nothing), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) finally deactivate(ctx) dispose(ts_ctx) end end -@inline @generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI, ErrIfFuncWritten} +@inline @generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, 
::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}, ::Val{RuntimeActivity}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI, ErrIfFuncWritten, RuntimeActivity} mi = fspec(eltype(FA), TT, World) ts_ctx = JuliaContext() ctx = context(ts_ctx) activate(ctx) res = try - thunkbase(ctx, mi, Val(World), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI, Val(ErrIfFuncWritten)) + thunkbase(ctx, mi, Val(World), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) finally deactivate(ctx) dispose(ts_ctx) @@ -7234,14 +7238,14 @@ end import GPUCompiler: deferred_codegen_jobs @generated function deferred_codegen(::Val{World}, ::Type{FA}, ::Val{TT}, ::Val{A},::Val{Mode}, - ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal},::Val{ShadowInit},::Type{ExpectedTapeType}, ::Val{ErrIfFuncWritten}) where {World, FA<:Annotation,TT, A, Mode, width, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, ErrIfFuncWritten} + ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal},::Val{ShadowInit},::Type{ExpectedTapeType}, ::Val{ErrIfFuncWritten}, ::Val{RuntimeActivity}) where {World, FA<:Annotation,TT, A, Mode, width, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, ErrIfFuncWritten, RuntimeActivity} JuliaContext() do ctx Base.@_inline_meta mi = fspec(eltype(FA), TT, World) target = EnzymeTarget() rt2 = if A isa UnionAll - params = EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, FFIABI, ErrIfFuncWritten) + params = EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, FFIABI, ErrIfFuncWritten, 
RuntimeActivity) tmp_job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) interp = GPUCompiler.get_interpreter(tmp_job) @@ -7265,7 +7269,7 @@ import GPUCompiler: deferred_codegen_jobs A end - params = EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, FFIABI, ErrIfFuncWritten) + params = EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, FFIABI, ErrIfFuncWritten, RuntimeActivity) job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) addr = get_trampoline(job) diff --git a/src/compiler/reflection.jl b/src/compiler/reflection.jl index 944b0b24989..583a6f2f68c 100644 --- a/src/compiler/reflection.jl +++ b/src/compiler/reflection.jl @@ -1,5 +1,5 @@ function get_job(@nospecialize(func), @nospecialize(A), @nospecialize(types); - run_enzyme::Bool=true, mode::API.CDerivativeMode=API.DEM_ReverseModeCombined, dupClosure::Bool=false, argwrap::Bool=true, width::Int=1, modifiedBetween=nothing, returnPrimal::Bool=false, augmentedInit=false, world=nothing, ABI=DefaultABI, ErrIfFuncWritten=false, kwargs...) + run_enzyme::Bool=true, mode::API.CDerivativeMode=API.DEM_ReverseModeCombined, dupClosure::Bool=false, argwrap::Bool=true, width::Int=1, modifiedBetween=nothing, returnPrimal::Bool=false, augmentedInit=false, world=nothing, ABI=DefaultABI, ErrIfFuncWritten=false, RuntimeActivity=true, kwargs...) tt = Tuple{map(eltype, types.parameters)...} if world === nothing @@ -15,7 +15,7 @@ function get_job(@nospecialize(func), @nospecialize(A), @nospecialize(types); defaultMod = mode != API.DEM_ReverseModeCombined && mode != API.DEM_ForwardMode modifiedBetween = (defaultMod, (defaultMod for _ in types.parameters)...) end - params = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? 
Duplicated : Const){Core.Typeof(func)}, types.parameters...}, mode, width, rt, run_enzyme, argwrap, modifiedBetween, returnPrimal, augmentedInit, Compiler.UnknownTapeType, ABI, ErrIfFuncWritten) + params = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){Core.Typeof(func)}, types.parameters...}, mode, width, rt, run_enzyme, argwrap, modifiedBetween, returnPrimal, augmentedInit, Compiler.UnknownTapeType, ABI, ErrIfFuncWritten, RuntimeActivity) return Compiler.CompilerJob(primal, CompilerConfig(target, params; kernel=false), world) end diff --git a/src/gradientutils.jl b/src/gradientutils.jl index f7f80fd396e..ff83f60c1d7 100644 --- a/src/gradientutils.jl +++ b/src/gradientutils.jl @@ -13,6 +13,7 @@ end get_width(gutils::GradientUtils) = API.EnzymeGradientUtilsGetWidth(gutils) get_mode(gutils::GradientUtils) = API.EnzymeGradientUtilsGetMode(gutils) +get_runtime_activity(gutils::GradientUtils) = API.EnzymeGradientUtilsGetRuntimeActivity(gutils) function get_shadow_type(gutils::GradientUtils, T::LLVM.LLVMType) w = get_width(gutils) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 8b772e8e24d..96f774f69ed 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -121,23 +121,13 @@ Enzyme.EnzymeRules.inactive_noinl(::typeof(Core._compute_sparams), args...) 
= no @inline EnzymeRules.inactive_type(v::Type{Core.Compiler.WorldRange}) = true @inline EnzymeRules.inactive_type(v::Type{Core.MethodInstance}) = true -@inline width(::Duplicated) = 1 -@inline width(::BatchDuplicated{T, N}) where {T, N} = N -@inline width(::DuplicatedNoNeed) = 1 -@inline width(::BatchDuplicatedNoNeed{T, N}) where {T, N} = N - -@inline width(::Type{Duplicated{T}}) where T = 1 -@inline width(::Type{BatchDuplicated{T, N}}) where {T, N} = N -@inline width(::Type{DuplicatedNoNeed{T}}) where T = 1 -@inline width(::Type{BatchDuplicatedNoNeed{T, N}}) where {T, N} = N - # Note all of these forward mode definitions do not support runtime activity as # the do not keep the primal if shadow(x.y) == primal(x.y) -function EnzymeRules.forward(::Const{typeof(Base.deepcopy)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, ::Const{typeof(Base.deepcopy)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) return deepcopy(x.dval) end -function EnzymeRules.forward(::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicatedNoNeed}, x::BatchDuplicated{T, N}) where {T, N} +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, ::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicatedNoNeed}, x::BatchDuplicated{T, N}) where {T, N} ntuple(Val(N)) do _ deepcopy(x.dval) end @@ -164,19 +154,19 @@ end return seen[shadow] end -function EnzymeRules.forward(func::Const{typeof(Base.deepcopy)}, ::Type{<:Duplicated}, x::Duplicated) +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{typeof(Base.deepcopy)}, ::Type{<:Duplicated}, x::Duplicated) primal = func.val(x.val) return Duplicated(primal, deepcopy_rtact(primal, x.val, IdDict(), x.dval)) end -function EnzymeRules.forward(func::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicated}, x::BatchDuplicated{T, N}) where {T,N} +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicated}, 
x::BatchDuplicated{T, N}) where {T,N} primal = func.val(x.val) return BatchDuplicated(primal, ntuple(Val(N)) do i deepcopy_rtact(primal, x.val, IdDict(), x.dval[i]) end) end -function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, x::Annotation{Ty}) where {RT, Ty} +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, x::Annotation{Ty}) where {RT, Ty} primal = if EnzymeRules.needs_primal(config) func.val(x.val) else @@ -244,7 +234,7 @@ end return seen[into] end -function EnzymeRules.reverse(config, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, shadow, x::Annotation{Ty}) where {RT, Ty} +function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, shadow, x::Annotation{Ty}) where {RT, Ty} if EnzymeRules.needs_shadow(config) if EnzymeRules.width(config) == 1 accumulate_into(x.dval, IdDict(), shadow) @@ -266,9 +256,9 @@ end unsafe_store!(tapes, thunk(f, Const(idx), fargs...)[1], idx) end -function EnzymeRules.augmented_primal(config, func::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, body::BodyTy, count, args::Vararg{Annotation, N}) where {BodyTy, N} +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, body::BodyTy, count, args::Vararg{Annotation, N}) where {BodyTy, N} - config2 = ReverseModeSplit{false, false, EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI, false}() + config2 = ReverseModeSplit{false, false, EnzymeRules.runtime_activity(config), EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI, false}() fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) 
TapeType = EnzymeRules.tape_type(fwd_thunk) @@ -291,9 +281,9 @@ end thunk(f, Const(idx), fargs..., unsafe_load(tapes, idx)) end -function EnzymeRules.reverse(config, func::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, tapes, body::BodyTy, count, args::Vararg{Annotation, N}) where {BodyTy, N} +function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, tapes, body::BodyTy, count, args::Vararg{Annotation, N}) where {BodyTy, N} - config2 = ReverseModeSplit{false, false, EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI, false}() + config2 = ReverseModeSplit{false, false, EnzymeRules.runtime_activity(config), EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI, false}() fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) Enzyme.pmap(pmap_rev, count.val, tapes, rev_thunk, body, args...) @@ -338,7 +328,7 @@ end # y=inv(A) B # dA −= z y^T # dB += z, where z = inv(A^T) dy -function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT}, A::Annotation{AT}, b::Annotation{BT}) where {RT, AT <: Array, BT <: Array} +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const{typeof(\)}, ::Type{RT}, A::Annotation{AT}, b::Annotation{BT}) where {RT, AT <: Array, BT <: Array} cache_A = if EnzymeRules.overwritten(config)[2] copy(A.val) @@ -395,13 +385,13 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT} ) return EnzymeRules.AugmentedReturn{ - EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing, - EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? 
eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT)}) : Nothing, + EnzymeRules.primal_type(config, RT), + EnzymeRules.shadow_type(config, RT), typeof(cache) }(retres, dres, cache) end -function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache, A::Annotation{<:Array}, b::Annotation{<:Array}) where RT +function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(\)}, ::Type{RT}, cache, A::Annotation{<:Array}, b::Annotation{<:Array}) where RT y, dys, cache_A, cache_b = cache @@ -469,7 +459,7 @@ const EnzymeTriangulars = Union{ } function EnzymeRules.augmented_primal( - config, + config::EnzymeRules.RevConfig, func::Const{typeof(ldiv!)}, ::Type{RT}, Y::Annotation{YT}, @@ -483,12 +473,17 @@ function EnzymeRules.augmented_primal( primal = EnzymeRules.needs_primal(config) ? Y.val : nothing shadow = EnzymeRules.needs_shadow(config) ? Y.dval : nothing func.val(Y.val, A.val, B.val) - return EnzymeRules.AugmentedReturn{typeof(primal), typeof(shadow), Any}( - primal, shadow, (cache_Y, cache_A, cache_B)) + return EnzymeRules.AugmentedReturn{ + EnzymeRules.primal_type(config, RT), + EnzymeRules.shadow_type(config, RT), + Tuple{typeof(cache_Y), typeof(cache_A), typeof(cache_B)} + }( + primal, shadow, (cache_Y, cache_A, cache_B) + ) end function EnzymeRules.reverse( - config, + config::EnzymeRules.RevConfig, func::Const{typeof(ldiv!)}, ::Type{RT}, cache, @@ -521,7 +516,7 @@ _zero_unused_elements!(X, ::UnitUpperTriangular) = triu!(X, 1) _zero_unused_elements!(X, ::UnitLowerTriangular) = tril!(X, -1) # Force a rule around hvcat_fill as it is type unstable if the tuple is not of the same type (e.g., int, float, int, float) -function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, out::Annotation{AT}, 
inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} primal = if EnzymeRules.needs_primal(config) out.val else @@ -536,7 +531,7 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.hvcat_fill return EnzymeRules.AugmentedReturn(primal, shadow, nothing) end -function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, _, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} +function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, _, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} nr, nc = size(out.val,1), size(out.val,2) for b in 1:EnzymeRules.width(config) da = if EnzymeRules.width(config) == 1 @@ -569,7 +564,7 @@ function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Ty return (nothing, nothing) end -function EnzymeRules.forward( +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, ::Const{typeof(sort!)}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, xs::Duplicated{T}; @@ -587,7 +582,7 @@ function EnzymeRules.forward( end end -function EnzymeRules.forward( +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, ::Const{typeof(sort!)}, RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}}, xs::BatchDuplicated{T, N}; @@ -609,7 +604,7 @@ end function EnzymeRules.augmented_primal( - config::EnzymeRules.ConfigWidth{1}, + config::EnzymeRules.RevConfigWidth{1}, ::Const{typeof(sort!)}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, xs::Duplicated{T}; @@ -632,7 +627,7 @@ function EnzymeRules.augmented_primal( end function EnzymeRules.reverse( - config::EnzymeRules.ConfigWidth{1}, + config::EnzymeRules.RevConfigWidth{1}, ::Const{typeof(sort!)}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, tape, @@ -645,7 +640,7 @@ function EnzymeRules.reverse( return (nothing,) end -function EnzymeRules.forward( +function 
EnzymeRules.forward(config::EnzymeRules.FwdConfig, ::Const{typeof(partialsort!)}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, xs::Duplicated{T}, @@ -670,7 +665,7 @@ function EnzymeRules.forward( end end -function EnzymeRules.forward( +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, ::Const{typeof(partialsort!)}, RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}}, xs::BatchDuplicated{T, N}, @@ -702,7 +697,7 @@ function EnzymeRules.forward( end function EnzymeRules.augmented_primal( - config::EnzymeRules.ConfigWidth{1}, + config::EnzymeRules.RevConfigWidth{1}, ::Const{typeof(partialsort!)}, RT::Type{<:Union{Const, Active, DuplicatedNoNeed, Duplicated}}, xs::Duplicated{T}, @@ -728,7 +723,7 @@ function EnzymeRules.augmented_primal( end function EnzymeRules.reverse( - config::EnzymeRules.ConfigWidth{1}, + config::EnzymeRules.RevConfigWidth{1}, ::Const{typeof(partialsort!)}, dret::Union{Active, Type{<:Union{Const, Active, DuplicatedNoNeed, Duplicated}}}, tape, @@ -755,7 +750,7 @@ end # -> # B(out) = inv(A) B(in) # dB(out) = inv(A) [ dB(in) - dA B(out) ] -function EnzymeRules.forward(func::Const{typeof(ldiv!)}, +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{typeof(ldiv!)}, RT::Type{<:Union{Const,Duplicated,BatchDuplicated}}, fact::Annotation{<:Cholesky}, B::Annotation{<:AbstractVecOrMat}; @@ -763,7 +758,7 @@ function EnzymeRules.forward(func::Const{typeof(ldiv!)}, if B isa Const return func.val(fact.val, B.val; kwargs...) 
else - N = width(B) + N = EnzymeRules.width(config) retval = B.val L = fact.val.L @@ -810,7 +805,7 @@ end # Float64 ranges in Julia use bitwise `&` with higher precision # to correct for numerical error, thus we put rules over the # operations as this is not directly differentiable -function EnzymeRules.forward(func::Const{Colon}, +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{Colon}, RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated, BatchDuplicated,BatchDuplicatedNoNeed}}, start::Annotation{<:AbstractFloat}, step::Annotation{<:AbstractFloat}, stop::Annotation{<:AbstractFloat}) @@ -820,7 +815,7 @@ function EnzymeRules.forward(func::Const{Colon}, elseif start isa Duplicated || start isa DuplicatedNoNeed start.dval elseif start isa BatchDuplicated || start isa BatchDuplicatedNoNeed - ntuple(i -> start.dval[i], Val(width(RT))) + ntuple(i -> start.dval[i], Val(EnzymeRules.width(config))) else error("Annotation type $(typeof(start)) not supported for range start. Please open an issue") end @@ -830,7 +825,7 @@ function EnzymeRules.forward(func::Const{Colon}, elseif step isa Duplicated || step isa DuplicatedNoNeed step.dval elseif step isa BatchDuplicated || step isa BatchDuplicatedNoNeed - ntuple(i -> step.dval[i], Val(width(RT))) + ntuple(i -> step.dval[i], Val(EnzymeRules.width(config))) else error("Annotation type $(typeof(start)) not supported for range step. Please open an issue") end @@ -845,11 +840,11 @@ function EnzymeRules.forward(func::Const{Colon}, BatchDuplicated(ret, ntuple(i -> range(dstart isa Number ? dstart : dstart[i]; step=dstep isa Number ? dstep : dstep[i], - length=length(ret)), Val(width(RT)))) + length=length(ret)), Val(EnzymeRules.width(config)))) elseif RT <: BatchDuplicatedNoNeed ntuple(i -> range(dstart isa Number ? dstart : dstart[i]; step=dstep isa Number ? dstep : dstep[i], - length=length(ret)), Val(width(RT))) + length=length(ret)), Val(EnzymeRules.width(config))) else error("This should not be possible. 
Please report.") end @@ -857,7 +852,7 @@ end -function EnzymeRules.augmented_primal(config, func::Const{Colon}, ::Type{<:Active}, +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const{Colon}, ::Type{<:Active}, start::Annotation{<:AbstractFloat}, step::Annotation{<:AbstractFloat}, stop::Annotation{<:AbstractFloat}) if EnzymeRules.needs_primal(config) @@ -868,7 +863,7 @@ function EnzymeRules.augmented_primal(config, func::Const{Colon}, ::Type{<:Activ return EnzymeRules.AugmentedReturn(primal, nothing, nothing) end -function EnzymeRules.reverse(config, func::Const{Colon}, dret, tape::Nothing, +function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{Colon}, dret, tape::Nothing, start::Annotation{T1}, step::Annotation{T2}, stop::Annotation{T3}) where {T1<:AbstractFloat, T2<:AbstractFloat, T3<:AbstractFloat} dstart = if start isa Const @@ -908,7 +903,7 @@ function EnzymeRules.reverse(config, func::Const{Colon}, dret, tape::Nothing, end -function EnzymeRules.forward( +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, Ty::Const{Type{BigFloat}}, RT::Type{<:Union{DuplicatedNoNeed, Duplicated, BatchDuplicated, BatchDuplicatedNoNeed}}; kwargs... @@ -920,13 +915,13 @@ function EnzymeRules.forward( elseif RT <: Duplicated return RT(Ty.val(; kwargs...), Ty.val(; kwargs...)) elseif RT <: BatchDuplicatedNoNeed - ntuple(Val(width(RT))) do i + ntuple(Val(EnzymeRules.width(config))) do i Base.@_inline_meta Ty.val(; kwargs...) end else @assert RT <: BatchDuplicated - tup = ntuple(Val(width(RT))) do i + tup = ntuple(Val(EnzymeRules.width(config))) do i Base.@_inline_meta Ty.val(; kwargs...) end @@ -935,7 +930,7 @@ function EnzymeRules.forward( end function EnzymeRules.augmented_primal( - config, + config::EnzymeRules.RevConfig, Ty::Const{Type{BigFloat}}, RT::Type{<:Union{DuplicatedNoNeed, Duplicated, BatchDuplicated, BatchDuplicatedNoNeed}}, kwargs... 
@@ -961,7 +956,7 @@ function EnzymeRules.augmented_primal( end function EnzymeRules.reverse( - config, + config::EnzymeRules.RevConfig, Ty::Const{Type{BigFloat}}, RT::Type{<:Union{DuplicatedNoNeed, Duplicated, BatchDuplicated, BatchDuplicatedNoNeed}}, tape, diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 005557c65ab..75e36370d8c 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -352,6 +352,7 @@ end end width = get_width(gutils) + C = EnzymeRules.FwdConfig{Int(width), get_runtime_activity(gutils)} if shadowR != C_NULL unsafe_store!(shadowR,UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))).ref) @@ -384,10 +385,12 @@ end @assert kwtup !== nothing insert!(tt, 1, kwtup) insert!(tt, 2, Core.typeof(EnzymeRules.forward)) - insert!(tt, 4, Type{RT}) + insert!(tt, 3, C) + insert!(tt, 5, Type{RT}) else @assert kwtup === nothing - insert!(tt, 2, Type{RT}) + insert!(tt, 1, C) + insert!(tt, 3, Type{RT}) end TT = Tuple{tt...} @@ -595,7 +598,7 @@ end fn = LLVM.parent(LLVM.parent(orig)) world = enzyme_extract_world(fn) - C = EnzymeRules.Config{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten} + C = EnzymeRules.RevConfig{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten, get_runtime_activity(gutils)} mode = get_mode(gutils) @@ -673,7 +676,7 @@ end needsShadow end - C = EnzymeRules.Config{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten} + C = EnzymeRules.RevConfig{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten, get_runtime_activity(gutils)} alloctx = LLVM.IRBuilder() position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 49622ada901..f3f05087c00 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -218,7 +218,7 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) world = codegen_world_age(FT, tt) opt_mi = Val(world) - forward = 
thunk(opt_mi, (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val(($(ModifiedBetween...),)), #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) + forward = thunk(opt_mi, (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val(($(ModifiedBetween...),)), #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) res = forward(dupClosure ? Duplicated(f, df) : Const(f), args...) @@ -238,13 +238,13 @@ function func_runtime_generic_fwd(N, Width) body = body_runtime_generic_fwd(N, Width, wrapped, primtypes) quote - function runtime_generic_fwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, ReturnType, F, DF, $(typeargs...)} + function runtime_generic_fwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{$Width}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, RuntimeActivity, ReturnType, F, DF, $(typeargs...)} $body end end end -@generated function runtime_generic_fwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, Width, ReturnType, F, DF} +@generated function runtime_generic_fwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, RuntimeActivity, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 _, _, primtypes, _, _, wrapped, _, _, _ = setup_macro_wraps(true, N, Width, :allargs) return body_runtime_generic_fwd(N, Width, wrapped, primtypes) @@ -333,7 +333,7 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) opt_mi = Val(world) forward, adjoint = thunk(opt_mi, dupClosure0 ? 
Duplicated{FT} : Const{FT}, annotationA, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) internal_tape, origRet, initShadow = forward(dupClosure0 ? Duplicated(f, df) : Const(f), args...) annotation = annotationA @@ -358,13 +358,13 @@ function func_runtime_generic_augfwd(N, Width) body = body_runtime_generic_augfwd(N, Width, wrapped, primtypes, active_refs) quote - function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, F, DF, $(typeargs...)} + function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, RuntimeActivity, F, DF, $(typeargs...)} $body end end end -@generated function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...)::ReturnType where {ActivityTup, MB, Width, ReturnType, F, DF} +@generated function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...)::ReturnType where {ActivityTup, MB, RuntimeActivity, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 _, _, primtypes, _, _, wrapped, _, _, active_refs = setup_macro_wraps(false, N, Width, :allargs) return body_runtime_generic_augfwd(N, Width, wrapped, primtypes, active_refs) @@ -462,7 +462,7 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act 
opt_mi = Val(world) _, adjoint = thunk(opt_mi, dupClosure0 ? Duplicated{FT} : Const{FT}, annotation, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) tup = if annotation0 <: Active || annotation0 <: MixedDuplicated || annotation0 <: BatchMixedDuplicated adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)[1] @@ -480,13 +480,13 @@ function func_runtime_generic_rev(N, Width) body = body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs, active_refs) quote - function runtime_generic_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, TapeType, F, DF, $(typeargs...)} + function runtime_generic_rev(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, $(allargs...)) where {ActivityTup, RuntimeActivity, MB, TapeType, F, DF, $(typeargs...)} $body end end end -@generated function runtime_generic_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, TapeType, F, DF} +@generated function runtime_generic_rev(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) 
where {ActivityTup, MB, RuntimeActivity, Width, TapeType, F, DF} N = div(length(allargs)+2, Width+1)-1 _, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs) return body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs, active_refs) @@ -676,7 +676,7 @@ end end # This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] -function fwddiff_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, args::Vararg{Annotation, Nargs})::ReturnType where {width, dupClosure0, ReturnType, FT, tt′, DF, Nargs} +function fwddiff_with_return(runtimeActivity::Val{RuntimeActivity}, ::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, args::Vararg{Annotation, Nargs})::ReturnType where {RuntimeActivity, width, dupClosure0, ReturnType, FT, tt′, DF, Nargs} ReturnPrimal = Val(true) ModifiedBetween = Val(Enzyme.falses_from_args(Nargs+1)) @@ -714,7 +714,7 @@ function fwddiff_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType end opt_mi = Val(world) res = thunk(opt_mi, FA, annotation, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), - ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false))(fa, args...) + ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity)(fa, args...) return if annotation <: Const ReturnType(allFirst(Val(width+1), res)) else @@ -736,7 +736,7 @@ function body_runtime_iterate_fwd(N, Width, wrapped, primtypes, active_refs) args = ($(wrappedexexpand...),) tt′ = Enzyme.vaTypeof(args...) 
FT = Core.Typeof(f) - fwddiff_with_return(Val($Width), Val(ActivityTup[1]), ReturnType, FT, tt′, f, df, args...)::ReturnType + fwddiff_with_return(runtimeActivity, Val($Width), Val(ActivityTup[1]), ReturnType, FT, tt′, f, df, args...)::ReturnType end end @@ -745,13 +745,13 @@ function func_runtime_iterate_fwd(N, Width) body = body_runtime_iterate_fwd(N, Width, wrapped, primtypes, active_refs) quote - function runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, ReturnType, F, DF, $(typeargs...)} + function runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{$Width}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, RuntimeActivity, ReturnType, F, DF, $(typeargs...)} $body end end end -@generated function runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, Width, ReturnType, F, DF} +@generated function runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) 
where {ActivityTup, RuntimeActivity, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 _, _, primtypes, _, _, wrapped, _, _, active_refs = setup_macro_wraps(true, N, Width, :allargs, #=iterate=#true) return body_runtime_iterate_fwd(N, Width, wrapped, primtypes, active_refs) @@ -822,7 +822,7 @@ end end # This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] -function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType}, ::Val{ModifiedBetween0}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, args::Vararg{Annotation, Nargs})::ReturnType where {width, dupClosure0, ReturnType, ModifiedBetween0, FT, tt′, DF, Nargs} +function augfwd_with_return(runtimeActivity::Val{RuntimeActivity}, ::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType}, ::Val{ModifiedBetween0}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, args::Vararg{Annotation, Nargs})::ReturnType where {RuntimeActivity, width, dupClosure0, ReturnType, ModifiedBetween0, FT, tt′, DF, Nargs} ReturnPrimal = Val(true) ModifiedBetween = Val(ModifiedBetween0) @@ -869,7 +869,7 @@ function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType} opt_mi = Val(world) forward, adjoint = thunk(opt_mi, FA, annotation, tt′, Val(API.DEM_ReverseModePrimal), Val(width), - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) forward(fa, args...) else nothing, primal_tuple(args...), annotation <: Active ? nothing : shadow_tuple(annotation, Val(width), args...) @@ -933,7 +933,7 @@ function body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes, a args = ($(wrappedexexpand...),) tt′ = Enzyme.vaTypeof(args...) 
FT = Core.Typeof(f) - tmpvals = augfwd_with_return(Val($Width), Val(ActivityTup[1]), ReturnType, Val(concat($(modbetween...))), FT, tt′, f, df, args...)::ReturnType + tmpvals = augfwd_with_return(runtimeActivity, Val($Width), Val(ActivityTup[1]), ReturnType, Val(concat($(modbetween...))), FT, tt′, f, df, args...)::ReturnType ReturnType(($(results...), (tmpvals[$(Width+2)], refs))) end end @@ -943,13 +943,13 @@ function func_runtime_iterate_augfwd(N, Width) body = body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes, active_refs) quote - function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, ReturnType, F, DF, $(typeargs...)} + function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, RuntimeActivity, MB, ReturnType, F, DF, $(typeargs...)} $body end end end -@generated function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, ReturnType, F, DF} +@generated function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...) 
where {ActivityTup, RuntimeActivity, MB, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 _, _, primtypes, _, _, wrapped, _ , modbetween, active_refs = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true) return body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes, active_refs) @@ -970,7 +970,7 @@ function add_into_vec!(val::T, expr, vec, idx_in_vec) where T end # This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] -@generated function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween0}, ::Val{lengths}, ::Type{FT}, ::Type{ttp}, f::FT, df::DF, tape, shadowargs, args::Vararg{Annotation, Nargs})::Nothing where {width, dupClosure0, ModifiedBetween0, lengths, FT, ttp, DF, Nargs} +@generated function rev_with_return(runtimeActivity::Val{RuntimeActivity}, ::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween0}, ::Val{lengths}, ::Type{FT}, ::Type{ttp}, f::FT, df::DF, tape, shadowargs, args::Vararg{Annotation, Nargs})::Nothing where {RuntimeActivity, width, dupClosure0, ModifiedBetween0, lengths, FT, ttp, DF, Nargs} nontupexprs = Vector{Expr}(undef, Nargs) for i in 1:Nargs @@ -1092,7 +1092,7 @@ end opt_mi = Val(world) forward, adjoint = thunk(opt_mi, FA, annotation, $ttp, Val(API.DEM_ReverseModePrimal), Val($width), - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) tup = if tape.shadow_return !== nothing $shadadj @@ -1155,7 +1155,7 @@ function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shado args = ($(wrappedexexpand...),) tt′ = Enzyme.vaTypeof(args...) FT = Core.Typeof(f) - rev_with_return(Val($Width), Val(ActivityTup[1]), Val(concat($(modbetween...))), Val(concat($(lengths...))), FT, tt′, f, df, tape0, ($(shadowsplat...),), args...) 
+ rev_with_return(runtimeActivity, Val($Width), Val(ActivityTup[1]), Val(concat($(modbetween...))), Val(concat($(lengths...))), FT, tt′, f, df, tape0, ($(shadowsplat...),), args...) return nothing end end @@ -1165,13 +1165,13 @@ function func_runtime_iterate_rev(N, Width) body = body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs, active_refs) quote - function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, TapeType, F, DF, $(typeargs...)} + function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, $(allargs...)) where {ActivityTup, RuntimeActivity, MB, TapeType, F, DF, $(typeargs...)} $body end end end -@generated function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, TapeType, F, DF} +@generated function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) 
where {ActivityTup, RuntimeActivity, MB, Width, TapeType, F, DF} N = div(length(allargs)+2, Width+1)-1 primargs, _, primtypes, _, _, wrapped, batchshadowargs, modbetween, active_refs = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true; reverse=true) return body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs, active_refs) @@ -1187,7 +1187,7 @@ for (N, Width) in Iterators.product(0:30, 1:10) eval(func_runtime_iterate_rev(N, Width)) end -function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, lookup; sret=nothing, tape=nothing, firstconst=false, endcast=true, firstconst_after_tape=true) +function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, lookup; sret=nothing, tape=nothing, firstconst=false, endcast=true, firstconst_after_tape=true, runtime_activity=true) width = get_width(gutils) mode = get_mode(gutils) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) @@ -1234,7 +1234,7 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, if lookup inverted = lookup_value(gutils, inverted, B) end - if API.runtimeActivity() + if get_runtime_activity(gutils) inv_0 = if width == 1 inverted else @@ -1295,6 +1295,9 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, end pushfirst!(vals, unsafe_to_llvm(B, Val(Int(width)))) + if runtime_activity + pushfirst!(vals, unsafe_to_llvm(B, Val(get_runtime_activity(gutils)))) + end etup0 = emit_tuple!(B, ActivityList) etup = emit_apply_type!(B, Base.Val, [etup0]) if isa(etup, LLVM.Instruction) diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 962a4f46afd..965114447cf 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -335,7 +335,7 @@ end GPUCompiler.@safe_warn "TODO forward zero-set of arraycopy used memset rather than runtime type $btstr" LLVM.memset!(B, get_array_data(B, shadowres), LLVM.ConstantInt(i8, 0, false), length, algn) end - if API.runtimeActivity() 
+ if get_runtime_activity(gutils) prev = new_from_original(gutils, orig) shadowres = LLVM.select!(B, LLVM.icmp!(B, LLVM.API.LLVMIntNE, shadowin, new_from_original(gutils, origops[1])), shadowres, prev) API.moveBefore(prev, shadowres, B) @@ -358,7 +358,7 @@ end GPUCompiler.@safe_warn "TODO forward zero-set of arraycopy used memset rather than runtime type $btstr" LLVM.memset!(B, get_array_data(B, callv), LLVM.ConstantInt(i8, 0, false), length, algn) end - if API.runtimeActivity() + if get_runtime_activity(gutils) prev = new_from_original(gutils, orig) callv = LLVM.select!(B, LLVM.icmp!(B, LLVM.API.LLVMIntNE, ev, new_from_original(gutils, origops[1])), callv, prev) if idx == 1 @@ -1094,7 +1094,7 @@ end else extract_value!(B, shadowin, idx-1) end - if API.runtimeActivity() + if get_runtime_activity(gutils) emit_error(B, orig, "Enzyme: Not yet implemented runtime activity for reverse of jl_array_del_end") end args = LLVM.Value[anti, offset] diff --git a/src/rules/parallelrules.jl b/src/rules/parallelrules.jl index 54208fe21cc..29648389474 100644 --- a/src/rules/parallelrules.jl +++ b/src/rules/parallelrules.jl @@ -1,9 +1,9 @@ -function runtime_newtask_fwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ssize::Int, ::Val{width}) where {FT1, FT2, World, width} +function runtime_newtask_fwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ssize::Int, runtimeActivity::Val{RuntimeActivity}, ::Val{width}) where {FT1, FT2, World, width, RuntimeActivity} FT = Core.Typeof(fn) ghos = guaranteed_const(FT) opt_mi = world - forward = thunk(opt_mi, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ForwardMode), Val(width), Val((false,)), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) + forward = thunk(opt_mi, (ghos ? 
Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ForwardMode), Val(width), Val((false,)), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) ft = ghos ? Const(fn) : Duplicated(fn, dfn) function fclosure() res = forward(ft) @@ -13,12 +13,12 @@ function runtime_newtask_fwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ss return ccall(:jl_new_task, Ref{Task}, (Any, Any, Int), fclosure, post, ssize) end -function runtime_newtask_augfwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ssize::Int, ::Val{width}, ::Val{ModifiedBetween}) where {FT1, FT2, World, width, ModifiedBetween} +function runtime_newtask_augfwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ssize::Int, runtimeActivity::Val{RuntimeActivity}, ::Val{width}, ::Val{ModifiedBetween}) where {FT1, FT2, World, width, ModifiedBetween, RuntimeActivity} # TODO make this AD subcall type stable FT = Core.Typeof(fn) ghos = guaranteed_const(FT) opt_mi = world - forward, adjoint = thunk(opt_mi, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ReverseModePrimal), Val(width), Val(ModifiedBetween), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false)) + forward, adjoint = thunk(opt_mi, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ReverseModePrimal), Val(width), Val(ModifiedBetween), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) ft = ghos ? Const(fn) : Duplicated(fn, dfn) taperef = Ref{Any}() @@ -189,7 +189,7 @@ end if mode == API.DEM_ForwardMode if fwdmodenm === nothing etarget = Compiler.EnzymeTarget() - eparams = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? 
Duplicated : Const){funcT}, e_tt.parameters...}, API.DEM_ForwardMode, width, Const{Nothing}, #=runEnzyme=#true, #=abiwrap=#true, modifiedBetween, #=returnPrimal=#false, #=shadowInit=#false, UnknownTapeType, FFIABI, #=ErrIfFuncWritten=#false) + eparams = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){funcT}, e_tt.parameters...}, API.DEM_ForwardMode, width, Const{Nothing}, #=runEnzyme=#true, #=abiwrap=#true, modifiedBetween, #=returnPrimal=#false, #=shadowInit=#false, UnknownTapeType, FFIABI, #=ErrIfFuncWritten=#false, get_runtime_activity(gutils)) ejob = Compiler.CompilerJob(mi2, CompilerConfig(etarget, eparams; kernel=false), world) cmod, fwdmodenm, _, _ = _thunk(ejob, #=postopt=#false) @@ -220,7 +220,7 @@ end if augfwdnm === nothing || adjointnm === nothing etarget = Compiler.EnzymeTarget() # TODO modifiedBetween - eparams = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){funcT}, e_tt.parameters...}, API.DEM_ReverseModePrimal, width, Const{Nothing}, #=runEnzyme=#true, #=abiwrap=#true, modifiedBetween, #=returnPrimal=#false, #=shadowInit=#false, UnknownTapeType, FFIABI, #=ErrIfFuncWritten=#false) + eparams = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){funcT}, e_tt.parameters...}, API.DEM_ReverseModePrimal, width, Const{Nothing}, #=runEnzyme=#true, #=abiwrap=#true, modifiedBetween, #=returnPrimal=#false, #=shadowInit=#false, UnknownTapeType, FFIABI, #=ErrIfFuncWritten=#false, get_runtime_activity(gutils)) ejob = Compiler.CompilerJob(mi2, CompilerConfig(etarget, eparams; kernel=false), world) cmod, adjointnm, augfwdnm, TapeType = _thunk(ejob, #=postopt=#false) @@ -505,6 +505,7 @@ end invert_pointer(gutils, ops[1], B), new_from_original(gutils, ops[2]), (sizeof(Int) == sizeof(Int64) ? emit_box_int64! 
: emit_box_int32!)(B, new_from_original(gutils, ops[3])), + unsafe_to_llvm(B, Val(get_runtime_activity(gutils))), unsafe_to_llvm(B, Val(width)), ] @@ -555,7 +556,7 @@ end invert_pointer(gutils, ops[1], B), new_from_original(gutils, ops[2]), (sizeof(Int) == sizeof(Int64) ? emit_box_int64! : emit_box_int32!)(B, new_from_original(gutils, ops[3])), - unsafe_to_llvm(B, Val(width)), + unsafe_to_llvm(B, Val(get_runtime_activity(gutils))), unsafe_to_llvm(B, Val(width)), unsafe_to_llvm(B, Val(ModifiedBetween)), ] diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 5372a677265..6117e464d81 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -326,7 +326,7 @@ function common_newstructv_augfwd(offset, B, orig, gutils, normalR, shadowR, tap width = get_width(gutils) - sret = generic_setup(orig, runtime_newstruct_augfwd, width == 1 ? Any : AnyArray(Int(width)), gutils, #=start=#offset, B, false; firstconst=true, endcast = false, firstconst_after_tape=true) + sret = generic_setup(orig, runtime_newstruct_augfwd, width == 1 ? 
Any : AnyArray(Int(width)), gutils, #=start=#offset, B, false; firstconst=true, endcast = false, firstconst_after_tape=true, runtime_activity=false) if width == 1 shadow = sret @@ -370,7 +370,7 @@ function common_newstructv_rev(offset, B, orig, gutils, tape) if !newstruct_common(#=fwd=#false, #=run=#false, offset, B, orig, gutils, #=normalR=#nothing, #=shadowR=#nothing) @assert tape !== C_NULL width = get_width(gutils) - generic_setup(orig, runtime_newstruct_rev, Nothing, gutils, #=start=#offset, B, true; firstconst=true, tape, firstconst_after_tape=true) + generic_setup(orig, runtime_newstruct_rev, Nothing, gutils, #=start=#offset, B, true; firstconst=true, tape, firstconst_after_tape=true, runtime_activity=false) end return nothing @@ -399,7 +399,7 @@ function common_f_tuple_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) width = get_width(gutils) - sret = generic_setup(orig, runtime_tuple_augfwd, width == 1 ? Any : AnyArray(Int(width)), gutils, #=start=#offset+1, B, false; endcast = false) + sret = generic_setup(orig, runtime_tuple_augfwd, width == 1 ? 
Any : AnyArray(Int(width)), gutils, #=start=#offset+1, B, false; endcast = false, runtime_activity=false) if width == 1 shadow = sret @@ -465,7 +465,7 @@ function common_f_tuple_rev(offset, B, orig, gutils, tape) else tape end - generic_setup(orig, runtime_tuple_rev, Nothing, gutils, #=start=#offset+1, B, true; tape=tape2) + generic_setup(orig, runtime_tuple_rev, Nothing, gutils, #=start=#offset+1, B, true; tape=tape2, runtime_activity=false) end return nothing end diff --git a/test/Project.toml b/test/Project.toml index a3f84527121..3ce8fc645c2 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -24,4 +24,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Aqua = "0.8" -EnzymeTestUtils = "0.1.4" +EnzymeTestUtils = "0.1.4, 0.2" diff --git a/test/ext/.chainrulescore.jl.swp b/test/ext/.chainrulescore.jl.swp new file mode 100644 index 0000000000000000000000000000000000000000..94b31875e128106c0c92d4e28639ebc06d4e3451 GIT binary patch literal 12288 zcmeI2&u<$=6vrpMP@ts<4i&dYe85{A{|b?4R99(3f{;oMsNqtoXl&2A3-+$NJ67x- zhy;H@4^W`M0d55$aYLe9;D)$!;m83YAt51gtoY8%I=gn_G;J?LdMkZ)*R$`vdGEV# zY(<%7_088_r?cg%!11&YU*B5xfBJjp*r^k9V!Ib5DlrcZKAhb`Pqvy-^iDsH>g70+ zy>bw06^F7r_I}qHRyu*Mtc5p5Jym|YThS`f6*bhA)-@S~@t~`cRur2@V?VyK0<6Fz zDlpLg!pkQ&7wa{<)C=>^)3eWPK4MUIX9ZXRR)7^?1y})AfE8c`Sb?LYfbP$U9VC1# zP5NAVotwI*ANj%xumY?AE5Hh{0;~WlzzVPetN<&(3a|o4PyyK%;^t#Qe0&_q6AN##eJ+nEJH$IoX27 zKI7{leT^9ov$Jr^kK%0^w~I*>>k?g#70$&W?uMQxtQk%PdjnGpQxH<|l|jE7w4_$; z5?z$KbuAmndsz>J+~bp(Z$ukPwKJ#OoD=z7l!+?4B&Q~H`L0G`q9;;5q`OPtL4#xk z?C&{i^33m`d1rY~Rh|zq^(d#rtYG`6l8qiB#MGotimAavgvsV~5wg#4&Nakjo1328 zJ7Pv>Rs1mB?fOcXu;wVxL-cgInRg}V^|8vfe6xv{^r){Qzqs4i9?^wxBgH4}SFesAS(_D8WP+tN-!$uw+Pf=qo|>wIMe3BZcI3AM-!G16;XL_^MH_bH znb00J8Vxf=7Bn>;rpEX#fPKz*DK$Q*Du2;4)mV_W4ZdpjFTY{-FRz;Y%PUDsc4fR} z#XwS+X9Nw3OJ|}`6^EE`%NCnc{5y`kS=&4X+8#d;le6@bvWj?b) zYv5O%1WUe?;J~Wt9<6sHsTb?_uc(e1^5vw(*kN0V4N;(^uUdg0v?vKWq11y|nL_p{ zrS)aj!ib- Op22j$L`rX6Xz?G6bcs*^ literal 0 HcmV?d00001 diff --git a/test/kwrrules.jl b/test/kwrrules.jl 
index f5b9d2338a1..34749e9baad 100644 --- a/test/kwrrules.jl +++ b/test/kwrrules.jl @@ -11,7 +11,7 @@ end import .EnzymeRules: augmented_primal, reverse using .EnzymeRules -function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f_kw)}, ::Type{<:Active}, x::Active; kwargs...) +function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(f_kw)}, ::Type{<:Active}, x::Active; kwargs...) @show kwargs @assert length(overwritten(config)) == 2 if needs_primal(config) @@ -21,7 +21,7 @@ function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f_kw)}, ::T end end -function reverse(config::ConfigWidth{1}, ::Const{typeof(f_kw)}, dret::Active, tape, x::Active; kwargs...) +function reverse(config::RevConfigWidth{1}, ::Const{typeof(f_kw)}, dret::Active, tape, x::Active; kwargs...) @show kwargs # TODO do we want them here? @assert length(overwritten(config)) == 2 if needs_primal(config) @@ -43,7 +43,7 @@ function f_kw2(x; kwargs...) x^2 end -function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f_kw2)}, ::Type{<:Active}, x::Active) +function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(f_kw2)}, ::Type{<:Active}, x::Active) if needs_primal(config) return AugmentedReturn(func.val(x.val), nothing, nothing) else @@ -51,7 +51,7 @@ function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f_kw2)}, :: end end -function reverse(config::ConfigWidth{1}, ::Const{typeof(f_kw2)}, dret::Active, tape, x::Active) +function reverse(config::RevConfigWidth{1}, ::Const{typeof(f_kw2)}, dret::Active, tape, x::Active) if needs_primal(config) return (10+2*x.val*dret.val,) else @@ -68,7 +68,7 @@ function f_kw3(x; val=nothing) x^2 end -function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f_kw3)}, ::Type{<:Active}, x::Active; dval=nothing) +function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(f_kw3)}, ::Type{<:Active}, x::Active; dval=nothing) if needs_primal(config) return 
AugmentedReturn(func.val(x.val), nothing, nothing) else @@ -76,7 +76,7 @@ function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f_kw3)}, :: end end -function reverse(config::ConfigWidth{1}, ::Const{typeof(f_kw3)}, dret::Active, tape, x::Active; dval=nothing) +function reverse(config::RevConfigWidth{1}, ::Const{typeof(f_kw3)}, dret::Active, tape, x::Active; dval=nothing) if needs_primal(config) return (10+2*x.val*dret.val,) else @@ -92,7 +92,7 @@ function f_kw4(x; y=2.0) x*y end -function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f_kw4)}, ::Type{<:Active}, x::Active; y) +function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(f_kw4)}, ::Type{<:Active}, x::Active; y) @assert length(overwritten(config)) == 2 if needs_primal(config) return AugmentedReturn(func.val(x.val), nothing, nothing) @@ -101,7 +101,7 @@ function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f_kw4)}, :: end end -function reverse(config::ConfigWidth{1}, ::Const{typeof(f_kw4)}, dret::Active, tape, x::Active; y) +function reverse(config::RevConfigWidth{1}, ::Const{typeof(f_kw4)}, dret::Active, tape, x::Active; y) @assert length(overwritten(config)) == 2 return (1000*y+2*x.val*dret.val,) end @@ -126,7 +126,7 @@ function wrapclos(cl, x) cl(x; width=9) end -function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{Closure2}, +function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Const{Closure2}, ::Type{<:Active}, args::Vararg{Active,N}; width=7) where {N} vec = copy(func.val.v) pval = func.val(args[1].val) @@ -138,7 +138,7 @@ function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{Closur return AugmentedReturn(primal, nothing, vec) end -function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{Closure2}, +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{Closure2}, dret::Active, tape, args::Vararg{Active,N}; width=7) where {N} dargs = ntuple(Val(N)) do i 7 * 
args[1].val * dret.val + tape[1] * 1000 + width * 100000 diff --git a/test/kwrules.jl b/test/kwrules.jl index 91d3dc859dc..9761c235102 100644 --- a/test/kwrules.jl +++ b/test/kwrules.jl @@ -10,7 +10,7 @@ function f_kw(x; kwargs...) x^2 end -function forward(::Const{typeof(f_kw)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated; kwargs...) +function forward(config, ::Const{typeof(f_kw)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated; kwargs...) return 10+2*x.val*x.dval end @@ -25,7 +25,7 @@ function f_kw2(x; kwargs...) x^2 end -function forward(::Const{typeof(f_kw2)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) +function forward(config, ::Const{typeof(f_kw2)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) return 10+2*x.val*x.dval end @@ -37,7 +37,7 @@ function f_kw3(x; val=nothing) x^2 end -function forward(::Const{typeof(f_kw3)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated; dval=nothing) +function forward(config, ::Const{typeof(f_kw3)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated; dval=nothing) return 10+2*x.val*x.dval end @@ -49,7 +49,7 @@ function f_kw4(x; y=2.0) x*y end -function forward(::Const{typeof(f_kw4)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated; y) +function forward(config, ::Const{typeof(f_kw4)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated; y) return 1000*y+2*x.val*x.dval end diff --git a/test/mixedrrule.jl b/test/mixedrrule.jl index 32407f3c12a..db1d3ab251f 100644 --- a/test/mixedrrule.jl +++ b/test/mixedrrule.jl @@ -17,7 +17,7 @@ function mixouter(x, y) return res end -function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof(mixfnc)}, +function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(mixfnc)}, ::Type{<:Active}, tup::MixedDuplicated{Tuple{Float64, Vector{Float64}}}) pval = func.val(tup.val) vec = copy(tup.val[2]) @@ -29,7 +29,7 @@ function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof return AugmentedReturn(primal, nothing, vec) end -function EnzymeRules.reverse(config::ConfigWidth{1}, 
func::Const{typeof(mixfnc)}, +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(mixfnc)}, dret::Active, tape, tup::MixedDuplicated{Tuple{Float64, Vector{Float64}}}) prev = tup.dval[] tup.dval[] = (7 * tape[1] * dret.val, prev[2]) @@ -57,7 +57,7 @@ function recmixouter(x, y, z) return res end -function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof(recmixfnc)}, +function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(recmixfnc)}, ::Type{<:Active}, tup) pval = func.val(tup.val) vec = copy(tup.val[2]) @@ -76,7 +76,7 @@ end return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState end -function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{typeof(recmixfnc)}, +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(recmixfnc)}, dret::Active, tape, tup) prev = tup.dval[] dRT = typeof(prev) diff --git a/test/rrules.jl b/test/rrules.jl index 6c2a965b0e0..cd41b497166 100644 --- a/test/rrules.jl +++ b/test/rrules.jl @@ -15,7 +15,7 @@ end import .EnzymeRules: augmented_primal, reverse, Annotation, has_rrule_from_sig using .EnzymeRules -function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f)}, ::Type{<:Active}, x::Active) +function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(f)}, ::Type{<:Active}, x::Active) if needs_primal(config) return AugmentedReturn(func.val(x.val), nothing, nothing) else @@ -23,7 +23,7 @@ function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f)}, ::Type end end -function reverse(config::ConfigWidth{1}, ::Const{typeof(f)}, dret::Active, tape, x::Active) +function reverse(config::RevConfigWidth{1}, ::Const{typeof(f)}, dret::Active, tape, x::Active) if needs_primal(config) return (10+2*x.val*dret.val,) else @@ -31,13 +31,13 @@ function reverse(config::ConfigWidth{1}, ::Const{typeof(f)}, dret::Active, tape, end end -function augmented_primal(::Config{false, false, 1}, 
func::Const{typeof(f_ip)}, ::Type{<:Const}, x::Duplicated) +function augmented_primal(::RevConfig{false, false, 1}, func::Const{typeof(f_ip)}, ::Type{<:Const}, x::Duplicated) v = x.val[1] x.val[1] *= v return AugmentedReturn(nothing, nothing, v) end -function reverse(::Config{false, false, 1}, ::Const{typeof(f_ip)}, ::Type{<:Const}, tape, x::Duplicated) +function reverse(::RevConfig{false, false, 1}, ::Const{typeof(f_ip)}, ::Type{<:Const}, tape, x::Duplicated) x.dval[1] = 100 + x.dval[1] * tape return (nothing,) end @@ -107,7 +107,7 @@ end end q(x) = x^2 -function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(q)}, ::Type{<:Active}, x::Active) +function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(q)}, ::Type{<:Active}, x::Active) tape = (Ref(2.0), Ref(3.4)) if needs_primal(config) return AugmentedReturn(func.val(x.val), nothing, tape) @@ -116,7 +116,7 @@ function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(q)}, ::Type end end -function reverse(config::ConfigWidth{1}, ::Const{typeof(q)}, dret::Active, tape, x::Active) +function reverse(config::RevConfigWidth{1}, ::Const{typeof(q)}, dret::Active, tape, x::Active) @test tape[1][] == 2.0 @test tape[2][] == 3.4 if needs_primal(config) @@ -133,7 +133,7 @@ end foo(x::Complex) = 2x function EnzymeRules.augmented_primal( - config::EnzymeRules.ConfigWidth{1}, + config::EnzymeRules.RevConfigWidth{1}, func::Const{typeof(foo)}, ::Type{<:Active}, x @@ -154,7 +154,7 @@ function EnzymeRules.augmented_primal( end function EnzymeRules.reverse( - config::EnzymeRules.ConfigWidth{1}, + config::EnzymeRules.RevConfigWidth{1}, func::Const{typeof(foo)}, dret, tape, @@ -177,7 +177,7 @@ function _dot(X::StridedArray{T}, Y::StridedArray{T}) where {T<:Union{Real,Compl end function augmented_primal( - config::ConfigWidth{1}, + config::RevConfigWidth{1}, func::Const{typeof(_dot)}, ::Type{<:Union{Const,Active}}, X::Duplicated{<:StridedArray{T}}, @@ -191,7 +191,7 @@ function augmented_primal( end 
function reverse( - ::ConfigWidth{1}, + ::RevConfigWidth{1}, ::Const{typeof(_dot)}, dret::Union{Active,Type{<:Const}}, tape, @@ -235,7 +235,7 @@ function cprimal(x0, y0) return @inbounds x[1] end -function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof(cmyfunc!)}, ::Type{<:Const}, +function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(cmyfunc!)}, ::Type{<:Const}, y::Duplicated, x::Duplicated) cmyfunc!(y.val, x.val) tape = (copy(x.val), 3) @@ -243,7 +243,7 @@ function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof end const seen = Set() -function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{typeof(cmyfunc!)}, ::Type{<:Const}, tape, +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(cmyfunc!)}, ::Type{<:Const}, tape, y::Duplicated, x::Duplicated) xval = tape[1] p = pointer(xval) @@ -265,7 +265,7 @@ function remultr(arg) arg * arg end -function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof(remultr)}, +function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(remultr)}, ::Type{<:Active}, args::Vararg{Active,N}) where {N} primal = if EnzymeRules.needs_primal(config) func.val(args[1].val) @@ -275,7 +275,7 @@ function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof return AugmentedReturn(primal, nothing, nothing) end -function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{typeof(remultr)}, +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(remultr)}, dret::Active, tape, args::Vararg{Active,N}) where {N} dargs = ntuple(Val(N)) do i @@ -315,7 +315,7 @@ function (cl::Closure)(x) end -function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{Closure}, +function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Const{Closure}, ::Type{<:Active}, args::Vararg{Active,N}) where {N} vec = copy(func.val.v) pval = 
func.val(args[1].val) @@ -327,7 +327,7 @@ function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{Closur return AugmentedReturn(primal, nothing, vec) end -function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{Closure}, +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{Closure}, dret::Active, tape, args::Vararg{Active,N}) where {N} dargs = ntuple(Val(N)) do i 7 * args[1].val * dret.val + tape[1] * 1000 @@ -377,7 +377,7 @@ end unstabletape(x) = x^2 -function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(unstabletape)}, ::Type{<:Active}, x::Active) +function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(unstabletape)}, ::Type{<:Active}, x::Active) tape = if x.val < 3 400 else @@ -390,7 +390,7 @@ function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(unstabletap end end -function reverse(config::ConfigWidth{1}, ::Const{typeof(unstabletape)}, dret, tape, x::Active{T}) where T +function reverse(config::RevConfigWidth{1}, ::Const{typeof(unstabletape)}, dret, tape, x::Active{T}) where T return (T(tape)::T,) end diff --git a/test/ruleinvalidation.jl b/test/ruleinvalidation.jl index 87c52861cc3..62579e2415f 100644 --- a/test/ruleinvalidation.jl +++ b/test/ruleinvalidation.jl @@ -11,25 +11,25 @@ call_issue696(args...) = issue696(args...) 
@test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 # should invalidate cache for the previous result -forward(::Const{typeof(issue696)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) = +forward(config, ::Const{typeof(issue696)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) = 10+2*x.val*x.dval -forward(func::Const{typeof(issue696)}, ::Type{<:Duplicated}, x::Duplicated) = +forward(config, func::Const{typeof(issue696)}, ::Type{<:Duplicated}, x::Duplicated) = Duplicated(func.val(x.val), 10+2*x.val*x.dval) @test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 12.0 @test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 12.0 # should invalidate cache for the previous result again -forward(::Const{typeof(issue696)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) = +forward(config, ::Const{typeof(issue696)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) = 20+2*x.val*x.dval -forward(func::Const{typeof(issue696)}, ::Type{<:Duplicated}, x::Duplicated) = +forward(config, func::Const{typeof(issue696)}, ::Type{<:Duplicated}, x::Duplicated) = Duplicated(func.val(x.val), 20+2*x.val*x.dval) @test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 22.0 @test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 22.0 # check that `Base.delete_method` works as expected -for m in methods(forward, Tuple{Const{typeof(issue696)},Vararg{Any}}) +for m in methods(forward, Tuple{Any,Const{typeof(issue696)},Vararg{Any}}) Base.delete_method(m) end @test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 diff --git a/test/rules.jl b/test/rules.jl index b6644d8c55a..0ef2e0fe8e8 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -4,7 +4,7 @@ using Enzyme using Enzyme: EnzymeRules using Test -import .EnzymeRules: forward, Annotation, has_frule_from_sig +import .EnzymeRules: forward, Annotation, has_frule_from_sig, FwdConfig f(x) = x^2 @@ -13,23 +13,23 @@ function f_ip(x) return nothing end -function forward(::Const{typeof(f)}, ::Type{<:DuplicatedNoNeed}, 
x::Duplicated) +function forward(config, ::Const{typeof(f)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) return 10+2*x.val*x.dval end -function forward(::Const{typeof(f)}, ::Type{<:BatchDuplicatedNoNeed}, x::BatchDuplicated{T, N}) where {T, N} +function forward(config, ::Const{typeof(f)}, ::Type{<:BatchDuplicatedNoNeed}, x::BatchDuplicated{T, N}) where {T, N} return NTuple{N, T}(1000+2*x.val*dv for dv in x.dval) end -function forward(func::Const{typeof(f)}, ::Type{<:Duplicated}, x::Duplicated) +function forward(config, func::Const{typeof(f)}, ::Type{<:Duplicated}, x::Duplicated) return Duplicated(func.val(x.val), 100+2*x.val*x.dval) end -function forward(func::Const{typeof(f)}, ::Type{<:BatchDuplicated}, x::BatchDuplicated{T, N}) where {T,N} +function forward(config, func::Const{typeof(f)}, ::Type{<:BatchDuplicated}, x::BatchDuplicated{T, N}) where {T,N} return BatchDuplicated(func.val(x.val), NTuple{N, T}(10000+2*x.val*dv for dv in x.dval)) end -function forward(::Const{Core.typeof(f_ip)}, ::Type{<:Const}, x::Duplicated) +function forward(config, ::Const{Core.typeof(f_ip)}, ::Type{<:Const}, x::Duplicated) ld = x.val[1] x.val[1] *= ld x.dval[1] *= 2 * ld + 10 @@ -38,7 +38,7 @@ end function has_frule(f, @nospecialize(RT), @nospecialize(TT::Type{<:Tuple}); world=Base.get_world_counter()) TT = Base.unwrap_unionall(TT) - TT = Tuple{<:Annotation{Core.typeof(f)}, Type{<:RT}, TT.parameters...} + TT = Tuple{<:FwdConfig, <:Annotation{Core.typeof(f)}, Type{<:RT}, TT.parameters...} EnzymeRules.isapplicable(forward, TT; world) end @@ -82,7 +82,7 @@ end end g(x) = x ^ 2 -function forward(func::Const{typeof(g)}, ::Type{<:Const}, x::Const) +function forward(config, func::Const{typeof(g)}, ::Type{<:Const}, x::Const) return Const(g(x.val)) end @@ -107,11 +107,11 @@ function h2(x) y * y end -function forward(func::Const{typeof(alloc_sq)}, ::Type{<:Duplicated}, x::Duplicated) +function forward(config, func::Const{typeof(alloc_sq)}, ::Type{<:Duplicated}, x::Duplicated) return 
Duplicated(Ref(x.val*x.val), Ref(10*2*x.val*x.dval)) end -function forward(func::Const{typeof(alloc_sq)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) +function forward(config, func::Const{typeof(alloc_sq)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) return Ref(1000*2*x.val*x.dval) end @@ -123,7 +123,7 @@ function h3(x) alloc_sq2(x)[] end -function forward(func::Const{typeof(alloc_sq2)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) +function forward(config, func::Const{typeof(alloc_sq2)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) return Duplicated(Ref(0.0), Ref(1000*2*x.val*x.dval)) end @@ -136,7 +136,7 @@ end foo(x) = 2x; -function EnzymeRules.forward( +function EnzymeRules.forward(config, func::Const{typeof(foo)}, RT::Type{<:Union{Duplicated,BatchDuplicated}}, x::Union{Duplicated,BatchDuplicated}, diff --git a/test/runtests.jl b/test/runtests.jl index 6ffd3dd09c5..bdda7604bf2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -141,10 +141,10 @@ end @test Enzyme.Compiler.active_reg_inner(Tuple, (), nothing, #=justactive=#Val(false), #=unionsret=#Val(false), #=abstractismixed=#Val(true)) == Enzyme.Compiler.MixedState @test Enzyme.Compiler.active_reg_inner(Tuple{A,A} where A, (), nothing, #=justactive=#Val(false), #=unionsret=#Val(false), #=abstractismixed=#Val(true)) == Enzyme.Compiler.MixedState world = codegen_world_age(typeof(f0), Tuple{Float64}) - thunk_a = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false)) - thunk_b = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Const, Tuple{Const{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false)) - thunk_c = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active{Float64}, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, 
Val(false)) - thunk_d = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active{Float64}, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false)) + thunk_a = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) + thunk_b = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Const, Tuple{Const{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) + thunk_c = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active{Float64}, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) + thunk_d = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active{Float64}, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) @test thunk_a.adjoint !== thunk_b.adjoint @test thunk_c.adjoint === thunk_a.adjoint @test thunk_c.adjoint === thunk_d.adjoint @@ -153,7 +153,7 @@ end @test thunk_a(Const(f0), Active(2.0), 2.0) == ((2.0,),) @test thunk_b(Const(f0), Const(2.0)) === ((nothing,),) - forward, pullback = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active, Tuple{Active{Float64}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false)) + forward, pullback = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active, Tuple{Active{Float64}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) @test forward(Const(f0), Active(2.0)) == (nothing,nothing,nothing) @test pullback(Const(f0), Active(2.0), 1.0, nothing) == ((1.0,),) @@ -164,7 +164,7 
@@ end d = Duplicated([3.0, 5.0], [0.0, 0.0]) world = codegen_world_age(typeof(mul2), Tuple{Vector{Float64}}) - forward, pullback = Enzyme.Compiler.thunk(Val(world), Const{typeof(mul2)}, Active, Tuple{Duplicated{Vector{Float64}}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, true)), Val(false), Val(false), DefaultABI, Val(false)) + forward, pullback = Enzyme.Compiler.thunk(Val(world), Const{typeof(mul2)}, Active, Tuple{Duplicated{Vector{Float64}}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, true)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) res = forward(Const(mul2), d) @test typeof(res[1]) == Tuple{Float64, Float64} pullback(Const(mul2), d, 1.0, res[1]) @@ -173,7 +173,7 @@ end d = Duplicated([3.0, 5.0], [0.0, 0.0]) world = codegen_world_age(typeof(vrec), Tuple{Int, Vector{Float64}}) - forward, pullback = Enzyme.Compiler.thunk(Val(world), Const{typeof(vrec)}, Active, Tuple{Const{Int}, Duplicated{Vector{Float64}}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, false, true)), Val(false), Val(false), DefaultABI, Val(false)) + forward, pullback = Enzyme.Compiler.thunk(Val(world), Const{typeof(vrec)}, Active, Tuple{Const{Int}, Duplicated{Vector{Float64}}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, false, true)), Val(false), Val(false), DefaultABI, Val(false), Val(false)) res = forward(Const(vrec), Const(Int(1)), d) pullback(Const(vrec), Const(1), d, 1.0, res[1]) @test d.dval[1] ≈ 5.0 @@ -1225,8 +1225,7 @@ end # Technically this test doesn't need runtimeactivity since the closure combo of active itr1 and const data # doesn't use any of the const data values, but now that we error for activity confusion, we need to # mark runtimeActivity to let this pass - Enzyme.API.runtimeActivity!(true) - Enzyme.autodiff(Enzyme.Reverse, Const(smallrf), Enzyme.Duplicated(weights, dweights), Enzyme.Const(data)) + Enzyme.autodiff(set_runtime_activity(Enzyme.Reverse), Const(smallrf), 
Enzyme.Duplicated(weights, dweights), Enzyme.Const(data)) @test dweights[1] ≈ 1. function invokesum(weights::Vector{Float64}, data::Vector{Float64})::Float64 @@ -1244,8 +1243,7 @@ end weights = [0.2, 0.8] dweights = [0.0, 0.0] - Enzyme.autodiff(Enzyme.Reverse, invokesum, Enzyme.Duplicated(weights, dweights), Enzyme.Const(data)) - Enzyme.API.runtimeActivity!(false) + Enzyme.autodiff(set_runtime_activity(Enzyme.Reverse), invokesum, Enzyme.Duplicated(weights, dweights), Enzyme.Const(data)) @test dweights[1] ≈ 20. @test dweights[2] ≈ 20. end @@ -1388,9 +1386,7 @@ end @testset "AbstractType calling convention" begin # TODO get rid of runtime activity - Enzyme.API.runtimeActivity!(true) - @test 1.0 ≈ Enzyme.autodiff(Reverse, dxdt_pred, Active(1.0))[1][1] - Enzyme.API.runtimeActivity!(false) + @test 1.0 ≈ Enzyme.autodiff(set_runtime_activity(Reverse), dxdt_pred, Active(1.0))[1][1] end function fillsum(x) @@ -1424,11 +1420,9 @@ function rtg_f(V,@nospecialize(cv)) end @testset "RuntimeActivity generic call" begin - Enzyme.API.runtimeActivity!(true) - res = autodiff(Forward, rtg_f, Duplicated, Duplicated([0.2], [1.0]), Const(RTGData(3.14))) + res = autodiff(set_runtime_activity(Forward), rtg_f, Duplicated, Duplicated([0.2], [1.0]), Const(RTGData(3.14))) @test 3.14 ≈ res[1] @test 0.0 ≈ res[2] - Enzyme.API.runtimeActivity!(false) end @inline function myquantile(v::AbstractVector, p::Real; alpha) @@ -2523,14 +2517,11 @@ end @testset "Getfield with reference" begin - Enzyme.API.runtimeActivity!(true) - d = GFNamedDist((;a = GFNormal(0.0, 1.0), b = GFProductDist([GFUniform(0.0, 1.0), GFUniform(0.0, 1.0)]))) p = (a = 1.0, b = [0.5, 0.5]) dp = Enzyme.make_zero(p) GFlogpdf(d, p) - autodiff(Reverse, GFlogpdf, Active, Const(d), Duplicated(p, dp)) - Enzyme.API.runtimeActivity!(false) + autodiff(set_runtime_activity(Reverse), GFlogpdf, Active, Const(d), Duplicated(p, dp)) end @testset "BLAS" begin @@ -2630,6 +2621,7 @@ end @testset "Union i8" begin args = ( Val{(false, false, false)}, + 
Val(false), Val(1), Val((true, true, true)), Base.Val(NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3")), Tuple{Any, Any, Any}}), @@ -2647,6 +2639,7 @@ end args2 = ( Val{(false, false, false)}, + Val(false), Val(1), Val((true, true, true)), Base.Val(NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3")), Tuple{Any, Any, Any}}), @@ -2664,13 +2657,13 @@ end end @testset "Batched inactive" begin - augres = Enzyme.Compiler.runtime_generic_augfwd(Val{(false, false, false)}, Val(2), Val((true, true, true)), + augres = Enzyme.Compiler.runtime_generic_augfwd(Val{(false, false, false)}, Val(false), Val(2), Val((true, true, true)), Val(Enzyme.Compiler.AnyArray(2+Int(2))), ==, nothing, nothing, :foo, nothing, nothing, :bar, nothing, nothing) - Enzyme.Compiler.runtime_generic_rev(Val{(false, false, false)}, Val(2), Val((true, true, true)), augres[end], + Enzyme.Compiler.runtime_generic_rev(Val{(false, false, false)}, Val(false), Val(2), Val((true, true, true)), augres[end], ==, nothing, nothing, :foo, nothing, nothing, :bar, nothing, nothing) @@ -3566,9 +3559,7 @@ end fn(0.0) end - Enzyme.API.runtimeActivity!(true) - res = autodiff(Forward, Const(f2), Duplicated, Duplicated(0.2, 1.0)) - Enzyme.API.runtimeActivity!(false) + res = autodiff(set_runtime_activity(Forward), Const(f2), Duplicated, Duplicated(0.2, 1.0)) @test res[1] ≈ 0.2 # broken as the return of an apply generic is {primal, primal} # but since the return is abstractfloat doing the diff --git a/test/sc.jl b/test/sc.jl new file mode 100644 index 00000000000..69789082013 --- /dev/null +++ b/test/sc.jl @@ -0,0 +1,64 @@ +module ReverseRules + +using Enzyme +using Enzyme: EnzymeRules +using LinearAlgebra +using Test + +f(x) = x^2 + +function f_ip(x) + x[1] *= x[1] + return nothing +end + +import .EnzymeRules: augmented_primal, reverse, Annotation, has_rrule_from_sig +using .EnzymeRules + +Enzyme.API.printall!(true) + +struct Closure + v::Vector{Float64} +end + +function (cl::Closure)(x) + val = cl.v[1] * x + cl.v[1] = 0.0 + 
return val +end + + +function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{Closure}, + ::Type{<:Active}, args::Vararg{Active,N}) where {N} + vec = copy(func.val.v) + pval = func.val(args[1].val) + primal = if EnzymeRules.needs_primal(config) + pval + else + nothing + end + return AugmentedReturn(primal, nothing, vec) +end + +function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{Closure}, + dret::Active, tape, args::Vararg{Active,N}) where {N} + + @show tape + @show dret + @show args + dargs = ntuple(Val(N)) do i + fval = 7 * args[1].val * dret.val + tape[1] * 1000 + @show fval + fval + end + return dargs +end + +@testset "Closure rule" begin + cl = Closure([3.14]) + res = autodiff(Reverse, cl, Active, Active(2.7))[1][1] + @test res ≈ 7 * 2.7 + 3.14 * 1000 + @test cl[1] ≈ 0.0 +end + +end # ReverseRules From e10ad8ca364026e82f68f07a7d15c178161fbeea Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 15 Sep 2024 11:32:17 -0500 Subject: [PATCH 25/87] Update Project.toml (#1831) * Update Project.toml * fix * fix --- Project.toml | 3 +-- src/compiler/interpreter.jl | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 1ea7b5c05b1..e7baec9129b 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "0.13.0" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" -EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" Enzyme_jll = "7cc45869-7501-5eee-bdea-0790c847d4ef" GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" @@ -36,7 +35,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8" -Enzyme_jll = "0.0.146, 0.0.148" +Enzyme_jll = "0.0.149" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8, =9.0" LogExpFunctions = "0.3" diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 
61a433af4c2..2ef66a15713 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -41,7 +41,7 @@ function EnzymeInterpreter(cache_or_token, mt::Union{Nothing,Core.MethodTable}, @assert world <= Base.get_world_counter() parms = @static if VERSION < v"1.12" - InferenceParams(unoptimize_throw_blocks=false), + InferenceParams(unoptimize_throw_blocks=false) else InferenceParams() end From 7faa4108eaebc5c01e99e22bd09ef8f74ad74fb5 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 15 Sep 2024 13:43:02 -0500 Subject: [PATCH 26/87] Fix rand set (#1833) --- src/internal_rules.jl | 62 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 96f774f69ed..238f7f7b033 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -964,3 +964,65 @@ function EnzymeRules.reverse( ) return () end + +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, + Ty::Const{typeof(Random.rand!)}, + RT::Type, + rng::Annotation{rngty}, + dst::Annotation{<:Array{FT}}, + smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, + ) where {rngty <: Union{TaskLocalRNG, Xoshiro}, FT <: Union{Float32, Float64}} + Ty.val(rng.val, dst.val, smpl.val) + if RT <: Duplicated + fill!(dst.dval, 0) + Duplicated(dst.val, dst.dval) + elseif RT <: Const + dst.val + elseif RT <: DuplicatedNoNeed + fill!(dst.dval, 0) + dst.dval + else + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + fill!(dst.dval[i], 0) + nothing + end + if RT <: BatchDuplicated + BatchDuplicated(dst.val, dst.dval) + else + dst.dval + end + end +end + +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, + Ty::Const{typeof(Random.rand!)}, + RT::Type, + rng::Annotation{rngty}, + dst::Annotation{<:Array{FT}}, + smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, + ) where {rngty <: Union{TaskLocalRNG, Xoshiro}, FT <: Union{Float32, Float64}} + Ty.val(rng.val, dst.val, 
smpl.val) + if RT <: Duplicated || RT <: DuplicatedNoNeed + fill!(dst.dval, 0) + dst.dval + elseif RT <: BatchDuplicated || RT <: BatchDuplicatedNoNeed + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + fill!(dst.dval[i], 0) + nothing + end + end + return EnzymeRules.AugmentedReturn(EnzymeRules.needs_primal(config) ? dst.val : nothing, EnzymeRules.needs_shadow(config) ? dst.dval : nothing, nothing) +end + +function EnzymeRules.reverse(config::EnzymeRules.RevConfig, + Ty::Const{typeof(Random.rand!)}, + RT::Type, + tape, + rng::Annotation{rngty}, + dst::Annotation{<:Array{FT}}, + smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, + ) where {rngty <: Union{TaskLocalRNG, Xoshiro}, FT <: Union{Float32, Float64}} + return (nothing, nothing, nothing) +end From d8b09f75dbeef3fb931974ac3781a7ecb55548ff Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 15 Sep 2024 19:57:43 -0500 Subject: [PATCH 27/87] Jitrules batched fn (#1835) --- src/rules/jitrules.jl | 59 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 52 insertions(+), 7 deletions(-) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index f3f05087c00..d5818ecf2ff 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -191,6 +191,21 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) else :(return ReturnType((res[1], res[2]...))) end + dup = if Width == 1 + :(Duplicated(f, df)) + else + fargs = [:df] + for i in 2:Width + push!(fargs, Symbol("df_$i")) + end + :(BatchDuplicated(f, ($(fargs...),))) + end + dupty = if Width == 1 + :(Duplicated{FT}) + else + :(BatchDuplicated{FT, $Width}) + end + return quote args = ($(wrapped...),) @@ -218,9 +233,9 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) world = codegen_world_age(FT, tt) opt_mi = Val(world) - forward = thunk(opt_mi, (dupClosure ? 
Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val(($(ModifiedBetween...),)), #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) + forward = thunk(opt_mi, dupClosure ? $dupty : Const{FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val(($(ModifiedBetween...),)), #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) - res = forward(dupClosure ? Duplicated(f, df) : Const(f), args...) + res = forward(dupClosure ? $dup : Const(f), args...) if length(res) == 0 return ReturnType(($(nnothing...),)) @@ -304,6 +319,21 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) else :(return ReturnType((origRet, shadow_return..., tape))) end + + dup = if Width == 1 + :(Duplicated(f, df)) + else + fargs = [:df] + for i in 2:Width + push!(fargs, Symbol("df_$i")) + end + :(BatchDuplicated(f, ($(fargs...),))) + end + dupty = if Width == 1 + :(Duplicated{FT}) + else + :(BatchDuplicated{FT, $Width}) + end return quote $(active_refs...) @@ -331,11 +361,11 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) world = codegen_world_age(FT, tt) opt_mi = Val(world) - forward, adjoint = thunk(opt_mi, dupClosure0 ? Duplicated{FT} : Const{FT}, + forward, adjoint = thunk(opt_mi, dupClosure0 ? $dupty : Const{FT}, annotationA, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) - internal_tape, origRet, initShadow = forward(dupClosure0 ? Duplicated(f, df) : Const(f), args...) + internal_tape, origRet, initShadow = forward(dupClosure0 ? $dup : Const(f), args...) 
annotation = annotationA resT = typeof(origRet) @@ -435,6 +465,21 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act @inbounds Types[i] = Symbol("type_$i") end + dup = if Width == 1 + :(Duplicated(f, df)) + else + fargs = [:df] + for i in 2:Width + push!(fargs, Symbol("df_$i")) + end + :(BatchDuplicated(f, ($(fargs...),))) + end + dupty = if Width == 1 + :(Duplicated{FT}) + else + :(BatchDuplicated{FT, $Width}) + end + quote $(active_refs...) args = ($(wrapped...),) @@ -460,14 +505,14 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act world = codegen_world_age(FT, tt) opt_mi = Val(world) - _, adjoint = thunk(opt_mi, dupClosure0 ? Duplicated{FT} : Const{FT}, + _, adjoint = thunk(opt_mi, dupClosure0 ? $dupty : Const{FT}, annotation, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) tup = if annotation0 <: Active || annotation0 <: MixedDuplicated || annotation0 <: BatchMixedDuplicated - adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)[1] + adjoint(dupClosure0 ? $dup : Const(f), args..., $shadowret, tape.internal_tape)[1] else - adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1] + adjoint(dupClosure0 ? $dup : Const(f), args..., tape.internal_tape)[1] end $(outs...) From b02cb6dc26bacb1f9f94e0817b3abbe6b5ae1c55 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 15 Sep 2024 22:11:55 -0500 Subject: [PATCH 28/87] Force usage of full typetree on copy/memset (#1838) * Force usage of full typetree on copy/memset * fix * fix * fix * fix * fix * fix * fix * hopefully final fix? 
* Update Project.toml --- src/compiler.jl | 21 ++++++- src/typetree.jl | 163 ++++++++++++++++++++++++++++++++++++------------ 2 files changed, 144 insertions(+), 40 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index fb7ce8d8bf8..df3db3c086b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5,7 +5,7 @@ import Enzyme: Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, BatchDuplicatedFunc, Annotation, guess_activity, eltype, - API, TypeTree, typetree, TypeTreeTable, only!, shift!, data0!, merge!, to_md, + API, TypeTree, typetree, TypeTreeTable, only!, shift!, data0!, merge!, to_md, to_fullmd, TypeAnalysis, FnTypeInfo, Logic, allocatedinline, ismutabletype using Enzyme @@ -6123,7 +6123,26 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if length(blocks(fn)) != 0 continue end + + intr = LLVM.API.LLVMGetIntrinsicID(fn) + + if intr == LLVM.Intrinsic("llvm.memcpy").id || intr == LLVM.Intrinsic("llvm.memmove").id || intr == LLVM.Intrinsic("llvm.memset").id + legal, jTy = abs_typeof(operands(inst)[1]) + sz = if intr == LLVM.Intrinsic("llvm.memcpy").id || intr == LLVM.Intrinsic("llvm.memmove").id + operands(inst)[3] + else + operands(inst)[3] + end + if legal && Base.isconcretetype(jTy) + if !(jTy isa UnionAll || jTy isa Union || jTy == Union{} || jTy === Tuple || (is_concrete_tuple(jTy) && any(T2 isa Core.TypeofVararg for T2 in jTy.parameters))) + if isa(sz, LLVM.ConstantInt) && sizeof(jTy) == convert(Int, sz) + metadata(inst)["enzyme_truetype"] = to_fullmd(jTy) + end + end + end + end end + ty = value_type(inst) if ty == LLVM.VoidType() continue diff --git a/src/typetree.jl b/src/typetree.jl index 40b01edcce3..89e5a040f36 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -59,6 +59,119 @@ function merge!(dst::TypeTree, src::TypeTree; consume=true) return nothing end +@inline function typetree_primitive(t) + return nothing +end +@inline function typetree_primitive(::Type{T}) where 
{T<:Integer} + return API.DT_Integer +end +@inline function typetree_primitive(::Type{Char}) + return API.DT_Integer +end +@inline function typetree_primitive(::Type{Float16}) + return API.DT_Half +end +@inline function typetree_primitive(::Type{Float32}) + return API.DT_Float +end +@inline function typetree_primitive(::Type{Float64}) + return API.DT_Double +end + + +@static if VERSION >= v"1.11-" +const TypeTreePrimitives = ( + Char, + Float16, + Float32, + Float64, + Core.BFloat16 +) +else +const TypeTreePrimitives = ( + Char, + Float16, + Float32, + Float64 +) +end + +const TypeTreeEmptyPointers = ( + BigFloat, + Any, + Symbol, + Union{}, +) + +function get_offsets(@nospecialize(T::Type)) + for sT in (Integer, TypeTreePrimitives...) + if T <: sT + return ((typetree_primitive(T), 0),) + end + end + for sT in (DataType, AbstractString, TypeTreeEmptyPointers...) + if T <: sT + return ((API.DT_Pointer, 0),) + end + end + +@static if VERSION < v"1.11-" + TypeTreePtrs = (Core.SimpleVector, Ptr, Core.LLVMPtr, Array) +else + TypeTreePtrs = (Core.SimpleVector, Ptr, Core.LLVMPtr, Array, GenericMemory) +end + for sT in TypeTreeEmptyPointers + if T <: sT + return ((API.DT_Pointer, 0),) + end + end + + @assert !(T <: AbstractFloat) + + if fieldcount(T) == 0 + return () + end + + results = Tuple{API.CConcreteType, Int}[] + for f in 1:fieldcount(T) + offset = fieldoffset(T, f) + subT = fieldtype(T, f) + + if !allocatedinline(subT) || subT isa UnionAll || subT isa Union || subT == Union{} + push!(results, (API.DT_Pointer, offset)) + continue + end + + for (sT, sO) in get_offsets(subT) + push!(results, (sT, sO+offset)) + end + end + return results +end + +function to_fullmd(@nospecialize(T::Type)) + mds = LLVM.Metadata[] + for (sT, sO) in get_offsets(T) + if sT == API.DT_Pointer + push!(mds, LLVM.MDString("Pointer")) + elseif sT == API.DT_Integer + push!(mds, LLVM.MDString("Integer")) + elseif sT == API.DT_Half + push!(mds, LLVM.MDString("Float@half")) + elseif sT == 
API.DT_Float + push!(mds, LLVM.MDString("Float@float")) + elseif sT == API.DT_BFloat16 + push!(mds, LLVM.MDString("Float@bfloat16")) + elseif sT == API.DT_Double + push!(mds, LLVM.MDString("Float@double")) + else + @assert false + end + push!(mds, LLVM.Metadata(LLVM.ConstantInt(sO))) + end + return LLVM.MDNode(mds) +end + function to_md(tt::TypeTree, ctx) return LLVM.Metadata(LLVM.MetadataAsValue(ccall((:EnzymeTypeTreeToMD, API.libEnzyme), LLVM.API.LLVMValueRef, @@ -91,48 +204,28 @@ function typetree(@nospecialize(T::Type), ctx, dl, seen=TypeTreeTable()) return tree::TypeTree end -function typetree_inner(::Type{T}, ctx, dl, seen::TypeTreeTable) where {T<:Integer} - return TypeTree(API.DT_Integer, -1, ctx) -end - -function typetree_inner(::Type{Char}, ctx, dl, seen::TypeTreeTable) +function typetree_inner(::Type{<:Integer}, ctx, dl, seen::TypeTreeTable) return TypeTree(API.DT_Integer, -1, ctx) end - -function typetree_inner(::Type{Float16}, ctx, dl, seen::TypeTreeTable) - return TypeTree(API.DT_Half, -1, ctx) -end - -function typetree_inner(::Type{Float32}, ctx, dl, seen::TypeTreeTable) - return TypeTree(API.DT_Float, -1, ctx) -end - -function typetree_inner(::Type{Float64}, ctx, dl, seen::TypeTreeTable) - return TypeTree(API.DT_Double, -1, ctx) -end - -@static if VERSION >= v"1.11-" -function typetree_inner(::Type{Core.BFloat16}, ctx, dl, seen::TypeTreeTable) - return TypeTree(API.DT_BFloat16, -1, ctx) -end -end - -function typetree_inner(::Type{BigFloat}, ctx, dl, seen::TypeTreeTable) - return TypeTree() +for sT in TypeTreePrimitives + @eval function typetree_inner(::Type{$sT}, ctx, dl, seen::TypeTreeTable) + return TypeTree($(typetree_primitive(sT)), -1, ctx) + end end function typetree_inner(::Type{<:DataType}, ctx, dl, seen::TypeTreeTable) return TypeTree() end - -function typetree_inner(::Type{Any}, ctx, dl, seen::TypeTreeTable) +function typetree_inner(::Type{<:AbstractString}, ctx, dl, seen::TypeTreeTable) return TypeTree() end - -function 
typetree_inner(::Type{Symbol}, ctx, dl, seen::TypeTreeTable) - return TypeTree() +for sT in TypeTreeEmptyPointers + @eval function typetree_inner(::Type{$sT}, ctx, dl, seen::TypeTreeTable) + return TypeTree() + end end + function typetree_inner(::Type{Core.SimpleVector}, ctx, dl, seen::TypeTreeTable) tt = TypeTree() for i in 0:(sizeof(Csize_t) - 1) @@ -141,14 +234,6 @@ function typetree_inner(::Type{Core.SimpleVector}, ctx, dl, seen::TypeTreeTable) return tt end -function typetree_inner(::Type{Union{}}, ctx, dl, seen::TypeTreeTable) - return TypeTree() -end - -function typetree_inner(::Type{<:AbstractString}, ctx, dl, seen::TypeTreeTable) - return TypeTree() -end - function typetree_inner(::Type{<:Union{Ptr{T},Core.LLVMPtr{T}}}, ctx, dl, seen::TypeTreeTable) where {T} tt = copy(typetree(T, ctx, dl, seen)) From 1992f33f940bc9822b32b70164d1c3e9c368ba53 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 16 Sep 2024 00:04:50 -0500 Subject: [PATCH 29/87] Autodiff with do blocks (#1840) --- src/Enzyme.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index fcc12d57a82..4ad8a4b061d 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -639,6 +639,28 @@ result, ∂v, ∂A Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) end +""" + autodiff(::Function, ::Mode, args...) + +Specialization of [`autodiff`](@ref) to handle do argument closures. + +```jldoctest + +autodiff(Reverse, Active(3.1)) do x + return x*x +end + +# output +((6.2,),) +``` +""" +@inline function autodiff(f::Function, m::MMode, ::Type{A}, args::Vararg{Annotation, Nargs}) where {A<:Annotation, Nargs, MMode<:Mode} + autodiff(m, f, A, args...) +end +@inline function autodiff(f::Function, m::MMode, args::Vararg{Annotation, Nargs}) where {Nargs, MMode<:Mode} + autodiff(m, f, args...) 
+end + """ autodiff_thunk(::ForwardMode, ftype, Activity, argtypes::Vararg{Type{<:Annotation}, Nargs}) From 8cd61a529192373bcc1445cbd1bcd6e40ab2ca51 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Sep 2024 01:32:17 -0400 Subject: [PATCH 30/87] Fixup docs --- examples/custom_rule.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/custom_rule.jl b/examples/custom_rule.jl index 86ffcf234a4..2b3f226fb01 100644 --- a/examples/custom_rule.jl +++ b/examples/custom_rule.jl @@ -168,7 +168,7 @@ g(y, x) = f(y, x)^2 # function to differentiate # Let's look at how to write a simple reverse-mode rule! # First, we write a method for [`EnzymeRules.augmented_primal`](@ref): -function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f)}, ::Type{<:Active}, +function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(f)}, ::Type{<:Active}, y::Duplicated, x::Duplicated) println("In custom augmented primal rule.") ## Compute primal @@ -203,7 +203,7 @@ end # Now, we write a method for [`EnzymeRules.reverse`](@ref): -function reverse(config::ConfigWidth{1}, func::Const{typeof(f)}, dret::Active, tape, +function reverse(config::RevConfigWidth{1}, func::Const{typeof(f)}, dret::Active, tape, y::Duplicated, x::Duplicated) println("In custom reverse rule.") ## retrieve x value, either from original x or from tape if x may have been overwritten. 
From b94e3f490ff5d1ae830417d72b09d8ee91808e98 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 16 Sep 2024 12:17:41 -0500 Subject: [PATCH 31/87] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e7baec9129b..19315d01dca 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8" -Enzyme_jll = "0.0.149" +Enzyme_jll = "0.0.150" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8, =9.0" LogExpFunctions = "0.3" From 6dc7c8f2a46c975bd76610a6b79c969d7161746e Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 16 Sep 2024 14:33:31 -0500 Subject: [PATCH 32/87] Move return primal into forward mode (#1832) * Move return primal into forward mode * fix * fix * more fixups * fix * fix * fix etu * fix * fix * fix * fixup * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * Update Project.toml * fix * Update internal_rules.jl * fix * fix * fix * fix --- lib/EnzymeCore/src/EnzymeCore.jl | 23 +-- lib/EnzymeCore/src/rules.jl | 21 +- lib/EnzymeTestUtils/src/test_forward.jl | 18 +- src/Enzyme.jl | 244 +++++++++++++----------- src/compiler.jl | 16 +- src/internal_rules.jl | 182 +++++++++++------- src/rules/customrules.jl | 6 +- src/rules/jitrules.jl | 8 +- test/abi.jl | 56 +++--- test/applyiter.jl | 14 +- test/ext/chainrulescore.jl | 13 +- test/rules.jl | 14 +- test/runtests.jl | 40 ++-- 13 files changed, 373 insertions(+), 282 deletions(-) diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 0175cb4caf5..cc71f0f9c6b 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -1,8 +1,8 @@ module EnzymeCore -export Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal +export Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, 
ReverseSplitWithPrimal export ReverseSplitModified, ReverseSplitWidth, ReverseHolomorphic, ReverseHolomorphicWithPrimal -export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed +export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, Annotation export MixedDuplicated, BatchMixedDuplicated export DefaultABI, FFIABI, InlineABI, NonGenABI export BatchDuplicatedFunc @@ -267,22 +267,23 @@ const ReverseSplitWithPrimal = ReverseModeSplit{true, true, false, 0, true,Defau @inline clear_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,false,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() """ - struct Forward{ABI, ErrIfFuncWritten,RuntimeActivity} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity} + struct Forward{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity} Forward mode differentiation """ -struct ForwardMode{ABI, ErrIfFuncWritten,RuntimeActivity} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity} +struct ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity} end -const Forward = ForwardMode{DefaultABI, false, false}() +const Forward = ForwardMode{false, DefaultABI, false, false}() +const ForwardWithPrimal = ForwardMode{true, DefaultABI, false, false}() -@inline set_err_if_func_written(::ForwardMode{ABI,ErrIfFuncWritten,RuntimeActivity}) where {ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ABI,true,RuntimeActivity}() -@inline clear_err_if_func_written(::ForwardMode{ABI,ErrIfFuncWritten,RuntimeActivity}) where {ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ABI,false,RuntimeActivity}() +@inline 
set_err_if_func_written(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ReturnPrimal,ABI,true,RuntimeActivity}() +@inline clear_err_if_func_written(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ReturnPrimal,ABI,false,RuntimeActivity}() -@inline set_abi(::ForwardMode{OldABI,ErrIfFuncWritten,RuntimeActivity}, ::Type{NewABI}) where {OldABI,ErrIfFuncWritten,RuntimeActivity,NewABI<:ABI} = ForwardMode{NewABI,ErrIfFuncWritten,RuntimeActivity}() +@inline set_abi(::ForwardMode{ReturnPrimal,OldABI,ErrIfFuncWritten,RuntimeActivity}, ::Type{NewABI}) where {ReturnPrimal,OldABI,ErrIfFuncWritten,RuntimeActivity,NewABI<:ABI} = ForwardMode{ReturnPrimal,NewABI,ErrIfFuncWritten,RuntimeActivity}() -@inline set_runtime_activity(::ForwardMode{ABI,ErrIfFuncWritten,RuntimeActivity}) where {ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ABI,ErrIfFuncWritten,true}() -@inline set_runtime_activity(::ForwardMode{ABI,ErrIfFuncWritten,RuntimeActivity}, rt::Bool) where {ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ABI,ErrIfFuncWritten,rt}() -@inline clear_runtime_activity(::ForwardMode{ABI,ErrIfFuncWritten,RuntimeActivity}) where {ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ABI,ErrIfFuncWritten,false}() +@inline set_runtime_activity(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,true}() +@inline set_runtime_activity(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}, rt::Bool) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,rt}() +@inline clear_runtime_activity(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = 
ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,false}() function autodiff end function autodiff_deferred end diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl index 27b14619e36..8d01d321da8 100644 --- a/lib/EnzymeCore/src/rules.jl +++ b/lib/EnzymeCore/src/rules.jl @@ -22,24 +22,29 @@ The third argument is the return type annotation, and all other arguments are th function forward end """ - FwdConfig{Width, RuntimeActivity} - FwdConfigWidth{Width} = FwdConfig{Width} + FwdConfig{NeedsPrimal, NeedsShadow, Width, RuntimeActivity} + FwdConfigWidth{Width} = FwdConfig{<:Any, <:Any, Width} Configuration type to dispatch on in custom forward rules (see [`forward`](@ref). +* `NeedsPrimal` and `NeedsShadow`: boolean values specifying whether the primal and shadow (resp.) should be returned. * `Width`: an integer that specifies the number of adjoints/shadows simultaneously being propagated. * `RuntimeActivity`: whether runtime activity is enabled. -Getters for the type parameters are provided by `width` and `runtime_activity`. +Getters for the type parameters are provided by `needs_primal`, `needs_shadow`, `width` and `runtime_activity`. 
""" -struct FwdConfig{Width, RuntimeActivity} end -const FwdConfigWidth{Width} = FwdConfig{Width} -@inline width(::FwdConfig{Width}) where Width = Width -@inline runtime_activity(::FwdConfig{<:Any, RuntimeActivity}) where RuntimeActivity = RuntimeActivity +struct FwdConfig{NeedsPrimal, NeedsShadow, Width, RuntimeActivity} end +const FwdConfigWidth{Width} = FwdConfig{<:Any,<:Any,Width} + +@inline needs_primal(::FwdConfig{NeedsPrimal}) where NeedsPrimal = NeedsPrimal +@inline needs_shadow(::FwdConfig{<:Any, NeedsShadow}) where NeedsShadow = NeedsShadow + +@inline width(::FwdConfig{<:Any, <:Any, Width}) where Width = Width +@inline runtime_activity(::FwdConfig{<:Any, <:Any, <:Any, RuntimeActivity}) where RuntimeActivity = RuntimeActivity """ RevConfig{NeedsPrimal, NeedsShadow, Width, Overwritten, RuntimeActivity} - RevConfigWidth{Width} = RevConfig{<:Any,<:Any, Width} + RevConfigWidth{Width} = RevConfig{<:Any, <:Any, Width} Configuration type to dispatch on in custom reverse rules (see [`augmented_primal`](@ref) and [`reverse`](@ref)). * `NeedsPrimal` and `NeedsShadow`: boolean values specifying whether the primal and shadow (resp.) should be returned. diff --git a/lib/EnzymeTestUtils/src/test_forward.jl b/lib/EnzymeTestUtils/src/test_forward.jl index fcfc987cb9a..8830ce67848 100644 --- a/lib/EnzymeTestUtils/src/test_forward.jl +++ b/lib/EnzymeTestUtils/src/test_forward.jl @@ -79,13 +79,27 @@ function test_forward( # call finitedifferences, avoid mutating original arguments dy_fdm = _fd_forward(fdm, call_with_copy, ret_activity, y, activities) # call autodiff, allow mutating original arguments - y_and_dy_ad = autodiff(set_runtime_activity(Forward, runtime_activity), call_with_kwargs, ret_activity, activities...) 
+ mode = if ret_activity <: Union{DuplicatedNoNeed,BatchDuplicatedNoNeed, Const} + Forward + else + ForwardWithPrimal + end + mode = set_runtime_activity(mode, runtime_activity) + + ret_activity2 = if ret_activity <: DuplicatedNoNeed + Duplicated + elseif ret_activity <: BatchDuplicatedNoNeed + BatchDuplicated + else + ret_activity + end + y_and_dy_ad = autodiff(mode, call_with_kwargs, ret_activity2, activities...) if ret_activity <: Union{Duplicated,BatchDuplicated} @test_msg( "For return type $ret_activity the return value and derivative must be returned", length(y_and_dy_ad) == 2, ) - y_ad, dy_ad = y_and_dy_ad + dy_ad, y_ad = y_and_dy_ad test_approx( y_ad, y, "The return value of the rule and function must agree"; atol, rtol ) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 4ad8a4b061d..bb86a33fc79 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -2,8 +2,8 @@ module Enzyme import EnzymeCore -import EnzymeCore: Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal -export Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal +import EnzymeCore: Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal +export Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal import EnzymeCore: Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, ABI, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, 
clear_err_if_func_written, set_abi, set_runtime_activity, clear_runtime_activity export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi, set_runtime_activity, clear_runtime_activity @@ -358,38 +358,33 @@ instead use [`Duplicated`](@ref) or variants like [`DuplicatedNoNeed`](@ref). `Activity` is the Activity of the return value, it may be: * `Const` if the return is not to be differentiated with respect to -* `Duplicated`, if the return is being differentiated with respect to and - both the original value and the derivative return are desired -* `DuplicatedNoNeed`, if the return is being differentiated with respect to - and only the derivative return is desired. +* `Duplicated`, if the return is being differentiated with respect to * `BatchDuplicated`, like `Duplicated`, but computing multiple derivatives at once. All batch sizes must be the same for all arguments. -* `BatchDuplicatedNoNeed`, like `DuplicatedNoNeed`, but computing multiple - derivatives at one. All batch sizes must be the same for all arguments. 
Example returning both original return and derivative: ```jldoctest f(x) = x*x -res, ∂f_∂x = autodiff(Forward, f, Duplicated, Duplicated(3.14, 1.0)) +res, ∂f_∂x = autodiff(ForwardWithPrimal, f, Duplicated, Duplicated(3.14, 1.0)) # output -(9.8596, 6.28) +(6.28, 9.8596) ``` Example returning just the derivative: ```jldoctest f(x) = x*x -∂f_∂x = autodiff(Forward, f, DuplicatedNoNeed, Duplicated(3.14, 1.0)) +∂f_∂x = autodiff(Forward, f, Duplicated, Duplicated(3.14, 1.0)) # output (6.28,) ``` """ -@inline function autodiff(::ForwardMode{RABI, ErrIfFuncWritten, RuntimeActivity}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation} where {RABI <: ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} +@inline function autodiff(::ForwardMode{ReturnPrimal, RABI, ErrIfFuncWritten, RuntimeActivity}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation} where {ReturnPrimal, RABI <: ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} if any_active(args...) throw(ErrorException("Active arguments not allowed in forward mode")) end @@ -401,7 +396,9 @@ f(x) = x*x if A <: Active throw(ErrorException("Active Returns not allowed in forward mode")) end - ReturnPrimal = Val(A <: Duplicated || A <: BatchDuplicated) + if A <: DuplicatedNoNeed || A <: BatchDuplicatedNoNeed + throw(ErrorException("Return activity `DuplicatedNoNeed` is no longer now returning or avoiding the primal is passed in for Forward Mode AD.\nPlease use autodiff(Forward, ...) 
or autodiff(ForwardWithPrimal, ...)")) + end RT = if A <: Duplicated && width != 1 if A isa UnionAll BatchDuplicated{T, width} where T @@ -429,7 +426,7 @@ f(x) = x*x end thunk = Enzyme.Compiler.thunk(opt_mi, FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), - ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) thunk(f, args...) end @@ -480,7 +477,7 @@ end Same as `autodiff(::ForwardMode, f, Activity, args)` but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. """ -@inline function autodiff_deferred(::ForwardMode{ABI, ErrIfFuncWritten, RuntimeActivity}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, Nargs, ABI, ErrIfFuncWritten, RuntimeActivity} +@inline function autodiff_deferred(::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten, RuntimeActivity}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {ReturnPrimal, FA<:Annotation, A<:Annotation, Nargs, ABI, ErrIfFuncWritten, RuntimeActivity} if any_active(args...) throw(ErrorException("Active arguments not allowed in forward mode")) end @@ -489,6 +486,9 @@ code, as well as high-order differentiation. if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end + if A <: DuplicatedNoNeed || A <: BatchDuplicatedNoNeed + throw(ErrorException("Return activity `DuplicatedNoNeed` is no longer now returning or avoiding the primal is passed in for Forward Mode AD.\nPlease use autodiff(Forward, ...) or autodiff(ForwardWithPrimal, ...)")) + end RT = if A <: Duplicated && width != 1 if A isa UnionAll BatchDuplicated{T, width} where T @@ -524,7 +524,6 @@ code, as well as high-order differentiation. 
throw(ErrorException("Active Returns not allowed in forward mode")) end - ReturnPrimal = RT <: Duplicated || RT <: BatchDuplicated ModifiedBetween = Val(falses_from_args(Nargs+1)) adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), UnknownTapeType, Val(ErrIfFuncWritten), Val(RuntimeActivity)) @@ -673,7 +672,7 @@ ftype when called with args of type `argtypes`. The forward function will return the primal (if requested) and the shadow (or nothing if not a `Duplicated` variant). -Example returning both original return and derivative: +Example returning both the return derivative and original return: ```jldoctest a = 4.2 @@ -681,12 +680,12 @@ b = [2.2, 3.3]; ∂f_∂b = zero(b) c = 55; d = 9 f(x) = x*x -forward = autodiff_thunk(Forward, Const{typeof(f)}, Duplicated, Duplicated{Float64}) +forward = autodiff_thunk(ForwardWithPrimal, Const{typeof(f)}, Duplicated, Duplicated{Float64}) res, ∂f_∂x = forward(Const(f), Duplicated(3.14, 1.0)) # output -(9.8596, 6.28) +(6.28, 9.8596) ``` Example returning just the derivative: @@ -697,7 +696,7 @@ b = [2.2, 3.3]; ∂f_∂b = zero(b) c = 55; d = 9 f(x) = x*x -forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated{Float64}) +forward = autodiff_thunk(Forward, Const{typeof(f)}, Duplicated, Duplicated{Float64}) ∂f_∂x = forward(Const(f), Duplicated(3.14, 1.0)) # output @@ -705,7 +704,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated (6.28,) ``` """ -@inline function autodiff_thunk(::ForwardMode{RABI, ErrIfFuncWritten, RuntimeActivity}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, RABI<:ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} +@inline function autodiff_thunk(::ForwardMode{ReturnPrimal, RABI, ErrIfFuncWritten, RuntimeActivity}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where 
{ReturnPrimal, FA<:Annotation, A<:Annotation, RABI<:ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} width = same_or_one(1, A, args...) if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) @@ -713,7 +712,10 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated if A <: Active throw(ErrorException("Active Returns not allowed in forward mode")) end - ReturnPrimal = Val(A <: Duplicated || A <: BatchDuplicated) + if A <: DuplicatedNoNeed || A <: BatchDuplicatedNoNeed + throw(ErrorException("Return activity `DuplicatedNoNeed` is no longer now returning or avoiding the primal is passed in for Forward Mode AD.\nPlease use autodiff(Forward, ...) or autodiff(ForwardWithPrimal, ...)")) + end + ModifiedBetween = Val(falses_from_args(Nargs+1)) tt = Tuple{map(eltype, args)...} @@ -724,7 +726,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated else Val(codegen_world_age(eltype(FA), tt)) end - Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + results = Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) end @inline function tape_type(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,ModifiedBetweenT, RABI, ErrIfFuncWritten}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} @@ -1046,6 +1048,14 @@ grad = gradient(Reverse, f, [2.0, 3.0]) 2.0 ``` +```jldoctest gradient + +grad = gradient(ReverseWithPrimal, f, [2.0, 3.0]) + +# output +([3.0, 2.0], 6.0) +``` + ```jldoctest gradient grad = gradient(Reverse, only ∘ f, (a = 2.0, b = 
[3.0], c = "str")) @@ -1059,7 +1069,7 @@ grad = gradient(Reverse, only ∘ f, (a = 2.0, b = [3.0], c = "str")) dx = Ref(make_zero(x)) res = autodiff(rm, f, Active, MixedDuplicated(x, dx)) if ReturnPrimal - (res[2], only(dx)) + (only(dx), res[2]) else only(dx) end @@ -1067,7 +1077,7 @@ grad = gradient(Reverse, only ∘ f, (a = 2.0, b = [3.0], c = "str")) dx = make_zero(x) res = autodiff(rm, f, Active, Duplicated(x, dx)) if ReturnPrimal - (res[2], dx) + (dx, res[2]) else dx end @@ -1084,7 +1094,7 @@ Like [`gradient`](@ref), except it using deferred mode. dx = Ref(make_zero(x)) autodiff_deferred(rm, f, Active, MixedDuplicated(x, dx)) if ReturnPrimal - return (res[2], only(dx)) + return (only(dx), res[2]) else return only(dx) end @@ -1092,7 +1102,7 @@ Like [`gradient`](@ref), except it using deferred mode. dx = make_zero(x) autodiff_deferred(rm, f, Active, Duplicated(x, dx)) if ReturnPrimal - (res[2], dx) + (dx, res[2]) else dx end @@ -1108,7 +1118,7 @@ Both `x` and `dx` must be `Array`s of the same type. Example: -```jldoctest +```jldoctest gradip f(x) = x[1]*x[2] dx = [0.0, 0.0] @@ -1120,12 +1130,20 @@ gradient!(Reverse, dx, f, [2.0, 3.0]) 3.0 2.0 ``` + +```jldoctest gradip +dx = [0.0, 0.0] +gradient!(ReverseWithPrimal, dx, f, [2.0, 3.0]) + +# output +([3.0, 2.0], 6.0) +``` """ @inline function gradient!(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, dx::X, f::F, x::X) where {X<:Array, F, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten} make_zero!(dx) res = autodiff(rm, f, Active, Duplicated(x, dx)) return if ReturnPrimal - (res[2], dx) + (dx, res[2]) else dx end @@ -1141,7 +1159,7 @@ Like [`gradient!`](@ref), except it using deferred mode. make_zero!(dx) autodiff_deferred(rm, f, Active, Duplicated(x, dx)) return if ReturnPrimal - (res[2], dx) + (dx, res[2]) else dx end @@ -1158,7 +1176,7 @@ within this call. 
Example: -```jldoctest +```jldoctest gradfwd f(x) = x[1]*x[2] grad = gradient(Forward, f, [2.0, 3.0]) @@ -1167,17 +1185,35 @@ grad = gradient(Forward, f, [2.0, 3.0]) (3.0, 2.0) ``` + +```jldoctest gradfwd +gradient(ForwardWithPrimal, f, [2.0, 3.0]) + +# output +((3.0, 2.0), 6.0) +``` """ -@inline function gradient(fm::ForwardMode, f, x; shadow=onehot(x)) +@inline function gradient(fm::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity}, f, x; shadow=onehot(x)) where {ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} if length(shadow) == 0 - return () + if ReturnPrimal + ((), f(x.val)) + else + return () + end end - res = values(only(autodiff(fm, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow)))) - if x isa AbstractFloat + resp = autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadow)) + + res = values(resp[1]) + dres = if x isa AbstractFloat res[1] else res end + if ReturnPrimal + (dres, resp[2]) + else + dres + end end @inline function chunkedonehot(x, ::Val{chunk}) where chunk @@ -1216,29 +1252,64 @@ grad = gradient(Forward, f, [2.0, 3.0], Val(2)) (3.0, 2.0) ``` """ -@inline function gradient(fm::ForwardMode, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X, chunk} +@inline function gradient(fm::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity}, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X, chunk, ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} if chunk == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end - tmp = ntuple(length(shadow)) do i - values(autodiff(fm, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow[i]))[1]) - end - res = tupleconcat(tmp...) 
- if x isa AbstractFloat - res[1] + if ReturnPrimal + rp = autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadow[1]))[1] + dres1 = if chunk == 1 + (rp[1],) + else + values(rp[1]) + end + gres = if x isa AbstractFloat + dres1 + else + fm2 = ForwardMode{#=ReturnPrimal=#false, ABI, ErrIfFuncWritten,RuntimeActivity}() + tmp = ntuple(length(shadow)-1) do i + values(autodiff(fm2, f, BatchDuplicated, BatchDuplicated(x, shadow[i+1]))[1]) + end + tupleconcat(dres1, tmp...) + end + (gres, rp[2]) else - res + tmp = ntuple(length(shadow)) do i + values(autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadow[i]))[1]) + end + res = tupleconcat(tmp...) + if x isa AbstractFloat + res[1] + else + res + end end end -@inline function gradient(fm::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F, X} - res = ntuple(length(shadow)) do i - autodiff(fm, f, DuplicatedNoNeed, Duplicated(x, shadow[i]))[1] - end - if x isa AbstractFloat - res[1] +@inline function gradient(fm::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity}, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F, X, ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} + if ReturnPrimal + rp = autodiff(fm, f, Duplicated, Duplicated(x, shadow[1])) + dres1 = rp[1] + fm2 = ForwardMode{#=ReturnPrimal=#false, ABI, ErrIfFuncWritten,RuntimeActivity}() + + res = ntuple(length(shadow)-1) do i + autodiff(fm2, f, Duplicated, Duplicated(x, shadow[i+1]))[1] + end + gres = if x isa AbstractFloat + dres1 + else + (dres1, res...) + end + (gres, rp[2]) else - res + res = ntuple(length(shadow)) do i + autodiff(fm, f, Duplicated, Duplicated(x, shadow[i]))[1] + end + if x isa AbstractFloat + res[1] + else + res + end end end @@ -1270,46 +1341,16 @@ whose shape is `(size(output)..., size(input)...)` For functions who return other types, this function will retun an array or tuple of shape `size(input)` of values of the output type. 
""" -@inline function jacobian(fm::ForwardMode, f, x; shadow=onehot(x)) - cols = if length(shadow) == 0 - () +@inline function jacobian(fm::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity}, args...; kwargs...) where {ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} + gradtup = gradient(fm, args...; kwargs...) + cols = if ReturnPrimal + gradtup[1] else - values(only(autodiff(fm, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow)))) + gradtup end - if x isa AbstractFloat - cols[1] - elseif length(cols) > 0 && cols[1] isa AbstractArray - inshape = size(x) - outshape = size(cols[1]) - # st : outshape x total inputs - st = Base.stack(cols) - - st3 = if length(inshape) <= 1 - st - else - reshape(st, (outshape..., inshape...)) - end - - st3 - elseif x isa AbstractArray - inshape = size(x) - reshape(collect(cols), inshape) - else + x = args[2] + res = if x isa AbstractFloat cols - end -end - -@inline function jacobian(fm::ForwardMode, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X, chunk} - if chunk == 0 - throw(ErrorException("Cannot differentiate with a batch size of 0")) - end - tmp = ntuple(length(shadow)) do i - Base.@_inline_meta - values(autodiff(fm, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow[i]))[1]) - end - cols = tupleconcat(tmp...) 
- if x isa AbstractFloat - cols[1] elseif length(cols) > 0 && cols[1] isa AbstractArray inshape = size(x) outshape = size(cols[1]) @@ -1329,33 +1370,10 @@ end else cols end -end - -@inline function jacobian(fm::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F,X} - cols = ntuple(length(shadow)) do i - Base.@_inline_meta - autodiff(fm, f, DuplicatedNoNeed, Duplicated(x, shadow[i]))[1] - end - if x isa AbstractFloat - cols[1] - elseif length(cols) > 0 && cols[1] isa AbstractArray - inshape = size(x) - outshape = size(cols[1]) - # st : outshape x total inputs - st = Base.stack(cols) - - st3 = if length(inshape) <= 1 - st - else - reshape(st, (outshape..., inshape...)) - end - - st3 - elseif x isa AbstractArray - inshape = size(x) - reshape(collect(cols), inshape) + if ReturnPrimal + (res, gradtup[2]) else - cols + res end end diff --git a/src/compiler.jl b/src/compiler.jl index df3db3c086b..1d21fb99a1d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -637,7 +637,7 @@ end return Const{T} end if Mode == API.DEM_ForwardMode - return DuplicatedNoNeed{T} + return Duplicated{T} else if ActReg == ActiveState return Active{T} @@ -4216,9 +4216,6 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end end if Mode == API.DEM_ForwardMode - if returnPrimal - push!(sret_types, literal_rt) - end if !(rettype <: Const) if width == 1 push!(sret_types, literal_rt) @@ -4226,6 +4223,9 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, push!(sret_types, AnonymousStruct(NTuple{width, literal_rt})) end end + if returnPrimal + push!(sret_types, literal_rt) + end end combinedReturn = if any(any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types) @@ -4562,7 +4562,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, val else @assert count_llvm_Sret > 1 - extract_value!(builder, val, returnNum) + extract_value!(builder, val, 1-returnNum) end) ptr = 
inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), returnNum)]) ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval))) @@ -6906,7 +6906,7 @@ end push!(sret_types, TapeType) end - if returnPrimal + if returnPrimal && !(CC <: ForwardModeThunk) push!(sret_types, jlRT) end if is_forward @@ -6930,6 +6930,10 @@ end end end + if returnPrimal && (CC <: ForwardModeThunk) + push!(sret_types, jlRT) + end + # calls fptr llvmtys = LLVMType[convert(LLVMType, x; allow_boxed=true) for x in types] diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 238f7f7b033..f29ed0d9776 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -573,12 +573,14 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] xs.dval .= xs.dval[inds] - if RT <: Const - return xs.val - elseif RT <: DuplicatedNoNeed + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return xs + elseif EnzymeRules.needs_shadow(config) return xs.dval + elseif EnzymeRules.needs_primal(config) + return xs.val else - return xs + return nothing end end @@ -593,12 +595,14 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, for i in 1:N xs.dval[i] .= xs.dval[i][inds] end - if RT <: Const - return xs.val - elseif RT <: BatchDuplicatedNoNeed + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return xs + elseif EnzymeRules.needs_shadow(config) return xs.dval + elseif EnzymeRules.needs_primal(config) + return xs.val else - return xs + return nothing end end @@ -652,16 +656,19 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, partialsortperm!(inds, xs.val, kv; kwargs...) xs.val .= xs.val[inds] xs.dval .= xs.dval[inds] - if RT <: Const - return kv isa Integer ? xs.val[kv] : view(xs.val, kv) - elseif RT <: DuplicatedNoNeed - return kv isa Integer ? 
xs.dval[kv] : view(xs.dval, kv) - else + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) if kv isa Integer return Duplicated(xs.val[kv], xs.dval[kv]) else return Duplicated(view(xs.val, kv), view(xs.dval, kv)) end + elseif EnzymeRules.needs_shadow(config) + return kv isa Integer ? xs.dval[kv] : view(xs.dval, kv) + elseif EnzymeRules.needs_primal(config) + return kv isa Integer ? xs.val[kv] : view(xs.val, kv) + else + return nothing end end @@ -679,20 +686,23 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, for i in 1:N xs.dval[i] .= xs.dval[i][inds] end - if RT <: Const - return kv isa Integer ? xs.val[kv] : view(xs.val, kv) - elseif RT <: BatchDuplicatedNoNeed + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) if kv isa Integer - return ntuple(i -> xs.dval[i][kv], N) + return BatchDuplicated(xs.val[kv], ntuple(i -> xs.dval[i][kv], N)) else - return ntuple(i -> view(xs.dval[i], kv), N) + return BatchDuplicated(view(xs.val, kv), ntuple(i -> view(xs.dval[i], kv), N)) end - else + elseif EnzymeRules.needs_shadow(config) if kv isa Integer - return BatchDuplicated(xs.val[kv], ntuple(i -> xs.dval[i][kv], N)) + return ntuple(i -> xs.dval[i][kv], N) else - return BatchDuplicated(view(xs.val, kv), ntuple(i -> view(xs.dval[i], kv), N)) + return ntuple(i -> view(xs.dval[i], kv), N) end + elseif EnzymeRules.needs_primal(config) + return kv isa Integer ? xs.val[kv] : view(xs.val, kv) + else + return nothing end end @@ -756,7 +766,12 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{typeof(l B::Annotation{<:AbstractVecOrMat}; kwargs...) if B isa Const - return func.val(fact.val, B.val; kwargs...) + retval = func.val(fact.val, B.val; kwargs...) 
+ if EnzymeRules.needs_primal(config) + retval + else + return nothing + end else N = EnzymeRules.width(config) retval = B.val @@ -787,16 +802,23 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{typeof(l return dB end - if RT <: Const + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + return Duplicated(retval, dretvals[1]) + else + return BatchDuplicated(retval, dretvals) + end + elseif EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + return dretvals[1] + else + return dretvals + end + elseif EnzymeRules.needs_primal(config) return retval - elseif RT <: DuplicatedNoNeed - return dretvals[1] - elseif RT <: Duplicated - return Duplicated(retval, dretvals[1]) - elseif RT <: BatchDuplicatedNoNeed - return dretvals else - return BatchDuplicated(retval, dretvals) + return nothing end end end @@ -830,23 +852,27 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{Colon}, error("Annotation type $(typeof(start)) not supported for range step. Please open an issue") end - if RT <: Duplicated - Duplicated(ret, range(dstart; step=dstep, length=length(ret))) - elseif RT <: Const - ret - elseif RT <: DuplicatedNoNeed - range(dstart; step=dstep, length=length(ret)) - elseif RT <: BatchDuplicated - BatchDuplicated(ret, + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + return Duplicated(ret, range(dstart; step=dstep, length=length(ret))) + else + return BatchDuplicated(ret, ntuple(i -> range(dstart isa Number ? dstart : dstart[i]; step=dstep isa Number ? dstep : dstep[i], length=length(ret)), Val(EnzymeRules.width(config)))) - elseif RT <: BatchDuplicatedNoNeed - ntuple(i -> range(dstart isa Number ? 
dstart : dstart[i]; + end + elseif EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + return range(dstart; step=dstep, length=length(ret)) + else + return ntuple(i -> range(dstart isa Number ? dstart : dstart[i]; step=dstep isa Number ? dstep : dstep[i], length=length(ret)), Val(EnzymeRules.width(config))) + end + elseif EnzymeRules.needs_primal(config) + return ret else - error("This should not be possible. Please report.") + return nothing end end @@ -908,24 +934,30 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, RT::Type{<:Union{DuplicatedNoNeed, Duplicated, BatchDuplicated, BatchDuplicatedNoNeed}}; kwargs... ) - if RT <: Const - return Ty.val(; kwargs...) - elseif RT <: DuplicatedNoNeed - return Ty.val(; kwargs...) - elseif RT <: Duplicated - return RT(Ty.val(; kwargs...), Ty.val(; kwargs...)) - elseif RT <: BatchDuplicatedNoNeed - ntuple(Val(EnzymeRules.width(config))) do i - Base.@_inline_meta - Ty.val(; kwargs...) + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + return RT(Ty.val(; kwargs...), Ty.val(; kwargs...)) + else + tup = ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + Ty.val(; kwargs...) + end + return RT(Ty.val(; kwargs...), tup) end - else - @assert RT <: BatchDuplicated - tup = ntuple(Val(EnzymeRules.width(config))) do i - Base.@_inline_meta - Ty.val(; kwargs...) + elseif EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + return Ty.val(; kwargs...) + else + return ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + Ty.val(; kwargs...) + end end - RT(Ty.val(; kwargs...), tup) + elseif EnzymeRules.needs_primal(config) + return Ty.val(; kwargs...) 
+ else + return nothing end end @@ -973,26 +1005,28 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, ) where {rngty <: Union{TaskLocalRNG, Xoshiro}, FT <: Union{Float32, Float64}} Ty.val(rng.val, dst.val, smpl.val) - if RT <: Duplicated - fill!(dst.dval, 0) - Duplicated(dst.val, dst.dval) - elseif RT <: Const - dst.val - elseif RT <: DuplicatedNoNeed - fill!(dst.dval, 0) - dst.dval - else - ntuple(Val(EnzymeRules.width(config))) do i - Base.@_inline_meta - fill!(dst.dval[i], 0) - nothing - end - if RT <: BatchDuplicated - BatchDuplicated(dst.val, dst.dval) + + if !(dst isa Const) + if EnzymeRules.width(config) == 1 + fill!(dst.dval, 0) else - dst.dval + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + fill!(dst.dval[i], 0) + nothing + end end end + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + dst + elseif EnzymeRules.needs_shadow(config) + dst.dval + elseif EnzymeRules.needs_primal(config) + dst.val + else + nothing + end end function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 75e36370d8c..e0eae36e4d8 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -352,7 +352,6 @@ end end width = get_width(gutils) - C = EnzymeRules.FwdConfig{Int(width), get_runtime_activity(gutils)} if shadowR != C_NULL unsafe_store!(shadowR,UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))).ref) @@ -374,6 +373,8 @@ end args, activity, overwritten, actives, kwtup, _ = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#false, isKWCall) RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt, B) + C = EnzymeRules.FwdConfig{Bool(needsPrimal), Bool(needsShadow), Int(width), get_runtime_activity(gutils)} + alloctx = LLVM.IRBuilder() position!(alloctx, 
LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) mode = get_mode(gutils) @@ -494,8 +495,7 @@ end normalV = C_NULL if RT <: Const - # TODO introduce const-no-need - if needsPrimal || true + if needsPrimal if RealRt != fwd_RT emit_error(B, orig, "Enzyme: incorrect return type of const primal-only forward custom rule - "*(string(RT))*" "*string(activity)*" want just return type "*string(RealRt)*" found "*string(fwd_RT)) return false diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index d5818ecf2ff..01edec7118d 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -187,9 +187,9 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) end retres = if Width == 1 - :(return ReturnType((res[1], res[2]))) + :(return ReturnType((res[2], res[1]))) else - :(return ReturnType((res[1], res[2]...))) + :(return ReturnType((res[2], res[1]...))) end dup = if Width == 1 :(Duplicated(f, df)) @@ -764,9 +764,9 @@ function fwddiff_with_return(runtimeActivity::Val{RuntimeActivity}, ::Val{width} ReturnType(allFirst(Val(width+1), res)) else if width == 1 - ReturnType((res[1], res[2])) + ReturnType((res[2], res[1])) else - ReturnType((res[1], res[2]...)) + ReturnType((res[2], res[1]...)) end end end diff --git a/test/abi.jl b/test/abi.jl index e07b7403cee..63fe48dc61e 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -45,7 +45,7 @@ using Test @test_throws ErrorException autodiff(Reverse, f, Active, Active(1.5 + 0.7im)) cres, = autodiff(ReverseHolomorphic, f, Active, Active(1.5 + 0.7im))[1] @test cres ≈ 1.0 + 0.0im - cres, = autodiff(Forward, f, DuplicatedNoNeed, Duplicated(1.5 + 0.7im, 1.0 + 0im)) + cres, = autodiff(Forward, f, Duplicated, Duplicated(1.5 + 0.7im, 1.0 + 0im)) @test cres ≈ 1.0 + 0.0im @test_throws ErrorException autodiff(Reverse, f, Active(1.5 + 0.7im)) @@ -68,12 +68,12 @@ using Test _, res0 = autodiff(Enzyme.set_abi(Reverse, NonGenABI), unused, Active, Const(nothing), Active(2.0))[1] @test res0 ≈ 1.0 - res0, = autodiff(Forward, 
unused, DuplicatedNoNeed, Const(nothing), Duplicated(2.0, 1.0)) + res0, = autodiff(Forward, unused, Duplicated, Const(nothing), Duplicated(2.0, 1.0)) @test res0 ≈ 1.0 - res0, = autodiff(Forward, unused, DuplicatedNoNeed, Const(nothing), DuplicatedNoNeed(2.0, 1.0)) + res0, = autodiff(Forward, unused, Duplicated, Const(nothing), DuplicatedNoNeed(2.0, 1.0)) @test res0 ≈ 1.0 - res0, = autodiff(Enzyme.set_abi(Forward, NonGenABI), unused, DuplicatedNoNeed, Const(nothing), Duplicated(2.0, 1.0)) + res0, = autodiff(Enzyme.set_abi(Forward, NonGenABI), unused, Duplicated, Const(nothing), Duplicated(2.0, 1.0)) @test res0 ≈ 1.0 _, res0 = autodiff(Reverse, unused, Const(nothing), Active(2.0))[1] @@ -193,7 +193,7 @@ using Test res2, = autodiff(Reverse, g, Active, Active(Foo(3, 1.2)))[1] @test res2.qux ≈ 1.0 - @test 1.0≈ first(autodiff(Forward, g, DuplicatedNoNeed, Duplicated(Foo(3, 1.2), Foo(0, 1.0)))) + @test 1.0≈ first(autodiff(Forward, g, Duplicated, Duplicated(Foo(3, 1.2), Foo(0, 1.0)))) res2, = autodiff(Reverse, g, Active(Foo(3, 1.2)))[1] @test res2.qux ≈ 1.0 @@ -204,7 +204,7 @@ using Test _, resF = autodiff(Reverse, unused2, Active, Const(nothing), Active(Foo(3, 2.0)))[1] @test resF.qux ≈ 1.0 - @test 1.0≈ first(autodiff(Forward, unused2, DuplicatedNoNeed, Const(nothing), Duplicated(Foo(3, 1.2), Foo(0, 1.0)))) + @test 1.0≈ first(autodiff(Forward, unused2, Duplicated, Const(nothing), Duplicated(Foo(3, 1.2), Foo(0, 1.0)))) _, resF = autodiff(Reverse, unused2, Const(nothing), Active(Foo(3, 2.0)))[1] @test resF.qux ≈ 1.0 @@ -216,7 +216,7 @@ using Test @test res3[1].qux ≈ 3.4 @test res3[2].qux ≈ 1.2 - @test 7*3.4 + 9 * 1.2 ≈ first(autodiff(Forward, h, DuplicatedNoNeed, Duplicated(Foo(3, 1.2), Foo(0, 7.0)), Duplicated(Foo(5, 3.4), Foo(0, 9.0)))) + @test 7*3.4 + 9 * 1.2 ≈ first(autodiff(Forward, h, Duplicated, Duplicated(Foo(3, 1.2), Foo(0, 7.0)), Duplicated(Foo(5, 3.4), Foo(0, 9.0)))) res3 = autodiff(Reverse, h, Active(Foo(3, 1.2)), Active(Foo(5, 3.4)))[1] @test res3[1].qux ≈ 3.4 
@@ -228,7 +228,7 @@ using Test _, res4 = autodiff(Reverse, caller, Active, Const((x)->x), Active(3.0))[1] @test res4 ≈ 1.0 - res4, = autodiff(Forward, caller, DuplicatedNoNeed, Const((x)->x), Duplicated(3.0, 1.0)) + res4, = autodiff(Forward, caller, Duplicated, Const((x)->x), Duplicated(3.0, 1.0)) @test res4 ≈ 1.0 _, res4 = autodiff(Reverse, caller, Const((x)->x), Active(3.0))[1] @@ -257,7 +257,7 @@ using Test @test ad === ((nothing,),) @test shadow.val ≈ 1.0 && shadow.next.val ≈ 1.0 - @test 2.0 ≈ first(autodiff(Forward, sumlist, DuplicatedNoNeed, Duplicated(regular, shadow))) + @test 2.0 ≈ first(autodiff(Forward, sumlist, Duplicated, Duplicated(regular, shadow))) mulr(x, y) = x[] * y[] x = Ref(2.0) @@ -273,7 +273,7 @@ using Test y = Ref(3.0) dx = Ref(5.0) dy = Ref(7.0) - @test 5.0*3.0 + 2.0*7.0≈ first(autodiff(Forward, mulr, DuplicatedNoNeed, Duplicated(x, dx), Duplicated(y, dy))) + @test 5.0*3.0 + 2.0*7.0≈ first(autodiff(Forward, mulr, Duplicated, Duplicated(x, dx), Duplicated(y, dy))) _, mid = Enzyme.autodiff(Reverse, (fs, x) -> fs[1](x), Active, Const((x->x*x,)), Active(2.0))[1] @test mid ≈ 4.0 @@ -281,10 +281,10 @@ using Test _, mid = Enzyme.autodiff(Reverse, (fs, x) -> fs[1](x), Active, Const([x->x*x]), Active(2.0))[1] @test mid ≈ 4.0 - mid, = Enzyme.autodiff(Forward, (fs, x) -> fs[1](x), DuplicatedNoNeed, Const((x->x*x,)), Duplicated(2.0, 1.0)) + mid, = Enzyme.autodiff(Forward, (fs, x) -> fs[1](x), Duplicated, Const((x->x*x,)), Duplicated(2.0, 1.0)) @test mid ≈ 4.0 - mid, = Enzyme.autodiff(Forward, (fs, x) -> fs[1](x), DuplicatedNoNeed, Const([x->x*x]), Duplicated(2.0, 1.0)) + mid, = Enzyme.autodiff(Forward, (fs, x) -> fs[1](x), Duplicated, Const([x->x*x]), Duplicated(2.0, 1.0)) @test mid ≈ 4.0 @@ -394,8 +394,8 @@ end @test Enzyme.autodiff(Reverse, method, Active, Const(AFoo(2.0)), Active(3.0))[1][2] ≈ 2.0 @test Enzyme.autodiff(Reverse, AFoo(2.0), Active, Active(3.0))[1][1] ≈ 2.0 - @test Enzyme.autodiff(Forward, method, DuplicatedNoNeed, Const(AFoo(2.0)), 
Duplicated(3.0, 1.0))[1] ≈ 2.0 - @test Enzyme.autodiff(Forward, AFoo(2.0), DuplicatedNoNeed, Duplicated(3.0, 1.0))[1] ≈ 2.0 + @test Enzyme.autodiff(Forward, method, Duplicated, Const(AFoo(2.0)), Duplicated(3.0, 1.0))[1] ≈ 2.0 + @test Enzyme.autodiff(Forward, AFoo(2.0), Duplicated, Duplicated(3.0, 1.0))[1] ≈ 2.0 struct ABar end @@ -407,8 +407,8 @@ end @test Enzyme.autodiff(Reverse, method, Active, Const(ABar()), Active(3.0))[1][2] ≈ 2.0 @test Enzyme.autodiff(Reverse, ABar(), Active, Active(3.0))[1][1] ≈ 2.0 - @test Enzyme.autodiff(Forward, method, DuplicatedNoNeed, Const(ABar()), Duplicated(3.0, 1.0))[1] ≈ 2.0 - @test Enzyme.autodiff(Forward, ABar(), DuplicatedNoNeed, Duplicated(3.0, 1.0))[1] ≈ 2.0 + @test Enzyme.autodiff(Forward, method, Duplicated, Const(ABar()), Duplicated(3.0, 1.0))[1] ≈ 2.0 + @test Enzyme.autodiff(Forward, ABar(), Duplicated, Duplicated(3.0, 1.0))[1] ≈ 2.0 struct RWClos x::Vector{Float64} @@ -446,14 +446,14 @@ end @testset "Promotion" begin x = [1.0, 2.0]; dx_1 = [1.0, 0.0]; dx_2 = [0.0, 1.0]; rosenbrock_inp(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2 - r = autodiff(Forward, rosenbrock_inp, Duplicated, BatchDuplicated(x, (dx_1, dx_2))) - @test r[1] ≈ 100.0 - @test r[2][1] ≈ -400.0 - @test r[2][2] ≈ 200.0 - r = autodiff_deferred(Forward, rosenbrock_inp, Duplicated, BatchDuplicated(x, (dx_1, dx_2))) - @test r[1] ≈ 100.0 - @test r[2][1] ≈ -400.0 - @test r[2][2] ≈ 200.0 + r = autodiff(ForwardWithPrimal, rosenbrock_inp, Duplicated, BatchDuplicated(x, (dx_1, dx_2))) + @test r[2] ≈ 100.0 + @test r[1][1] ≈ -400.0 + @test r[1][2] ≈ 200.0 + r = autodiff_deferred(ForwardWithPrimal, rosenbrock_inp, Duplicated, BatchDuplicated(x, (dx_1, dx_2))) + @test r[2] ≈ 100.0 + @test r[1][1] ≈ -400.0 + @test r[1][2] ≈ 200.0 end abssum(x) = sum(abs2, x); @@ -467,11 +467,14 @@ mulsin(x) = sin(x[1] * x[2]) @inferred autodiff(Enzyme.ReverseHolomorphic, abssum, Duplicated(x,x)) @inferred autodiff(Enzyme.ReverseHolomorphicWithPrimal, abssum, Duplicated(x,x)) @inferred 
autodiff(Enzyme.Forward, abssum, Duplicated(x,x)) + @inferred autodiff(Enzyme.ForwardWithPrimal, abssum, Duplicated, Duplicated(x,x)) @inferred autodiff(Enzyme.Forward, abssum, Duplicated, Duplicated(x,x)) - @inferred autodiff(Enzyme.Forward, abssum, DuplicatedNoNeed, Duplicated(x,x)) @inferred gradient(Reverse, abssum, x) @inferred gradient!(Reverse, x, abssum, x) + + @inferred gradient(ReverseWithPrimal, abssum, x) + @inferred gradient!(ReverseWithPrimal, x, abssum, x) cx = ones(10) @inferred autodiff(Enzyme.ReverseHolomorphic, sum, Duplicated(cx,cx)) @@ -489,6 +492,9 @@ mulsin(x) = sin(x[1] * x[2]) @inferred gradient(Reverse, abssum, tx) @inferred gradient(Forward, abssum, tx) + @inferred gradient(ReverseWithPrimal, abssum, tx) + @inferred gradient(ForwardWithPrimal, abssum, tx) + @inferred hvp(mulsin, [2.0, 3.0], [5.0, 2.7]) @inferred hvp!(zeros(2), mulsin, [2.0, 3.0], [5.0, 2.7]) diff --git a/test/applyiter.jl b/test/applyiter.jl index 5b55617e553..642ad62035c 100644 --- a/test/applyiter.jl +++ b/test/applyiter.jl @@ -267,7 +267,7 @@ end @test dres[3] ≈ 100.02 @test dres[4] ≈ 304.1 - res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, Duplicated(x, dx)) + dres, res = Enzyme.autodiff(ForwardWithPrimal, metaconcat, Duplicated, Duplicated(x, dx)) @test length(res) == 4 @test res[1] ≈ 2.0 @test res[2] ≈ 3.0 @@ -290,7 +290,7 @@ end @test dres[3] == "c" @test dres[4] == "d" - res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, Duplicated(a, da)) + dres, res = Enzyme.autodiff(ForwardWithPrimal, metaconcat, Duplicated, Duplicated(a, da)) @test length(res) == 4 @test res[1] == "a" @test res[2] == "b" @@ -313,7 +313,7 @@ end @test dres[4] == "c" @test dres[5] == "d" - res, dres = Enzyme.autodiff(Forward, midconcat, Duplicated, Duplicated(1.0, 7.0), Duplicated(a, da)) + dres, res = Enzyme.autodiff(ForwardWithPrimal, midconcat, Duplicated, Duplicated(1.0, 7.0), Duplicated(a, da)) @test length(res) == 5 @test res[1] ≈ 1.0 @test res[2] == "a" @@ -337,7 
+337,7 @@ end @test dres[4] == "c" @test dres[5] == "d" - res, dres = Enzyme.autodiff(Forward, midconcat, Duplicated, Duplicated(1.0, 7.0), Const(a)) + dres, res = Enzyme.autodiff(ForwardWithPrimal, midconcat, Duplicated, Duplicated(1.0, 7.0), Const(a)) @test length(res) == 5 @test res[1] ≈ 1.0 @test res[2] == "a" @@ -365,7 +365,7 @@ end @test dres[7] ≈ -9100.02 @test dres[8] ≈ -9304.1 - res, dres = Enzyme.autodiff(Forward, metaconcat2, Duplicated, Duplicated(x, dx), Duplicated(y, dy)) + dres, res = Enzyme.autodiff(ForwardWithPrimal, metaconcat2, Duplicated, Duplicated(x, dx), Duplicated(y, dy)) @test length(res) == 8 @test res[1] ≈ 2.0 @test res[2] ≈ 3.0 @@ -403,7 +403,7 @@ end @test dres[11] ≈ -9100.02 @test dres[12] ≈ -9304.1 - res, dres = Enzyme.autodiff(Forward, metaconcat3, Duplicated, Duplicated(x, dx), Const(a), Duplicated(y, dy)) + dres, res = Enzyme.autodiff(ForwardWithPrimal, metaconcat3, Duplicated, Duplicated(x, dx), Const(a), Duplicated(y, dy)) @test length(res) == 12 @test res[1] ≈ 2.0 @test res[2] ≈ 3.0 @@ -449,7 +449,7 @@ end @test dres[2][3] ≈ -9100.02 @test dres[2][4] ≈ -9304.1 - res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, BatchDuplicated(x, (dx, dy))) + dres, res = Enzyme.autodiff(ForwardWithPrimal, metaconcat, Duplicated, BatchDuplicated(x, (dx, dy))) @test length(res) == 4 @test res[1] ≈ 2.0 @test res[2] ≈ 3.0 diff --git a/test/ext/chainrulescore.jl b/test/ext/chainrulescore.jl index b73117faf25..65984ef26f8 100644 --- a/test/ext/chainrulescore.jl +++ b/test/ext/chainrulescore.jl @@ -24,8 +24,11 @@ function ChainRulesCore.rrule(::typeof(MockModule.mock_function), x) return y, ȳ -> 2 * ȳ end -fdiff(f, x::Number) = autodiff(Forward, f, Duplicated, Duplicated(x, one(x)))[2] -fdiff(f, x::MockModule.MockType) = autodiff(Forward, f, Duplicated, Duplicated(x, MockModule.MockType(one(x.x))))[2] +fdiff(f, x::Number) = autodiff(ForwardWithPrimal, f, Duplicated, Duplicated(x, one(x)))[1] +fdiff(f, x::MockModule.MockType) = 
autodiff(ForwardWithPrimal, f, Duplicated, Duplicated(x, MockModule.MockType(one(x.x))))[1] + +fdiff2(f, x::Number) = autodiff(Forward, f, Duplicated, Duplicated(x, one(x)))[1] +fdiff2(f, x::MockModule.MockType) = autodiff(Forward, f, Duplicated, Duplicated(x, MockModule.MockType(one(x.x))))[1] @testset "import_frule" begin f1(x) = 2*x @@ -33,6 +36,8 @@ fdiff(f, x::MockModule.MockType) = autodiff(Forward, f, Duplicated, Duplicated(x Enzyme.@import_frule typeof(f1) Any @test fdiff(f1, 1f0) === 5f0 @test fdiff(f1, 1.0) === 5.0 + @test fdiff2(f1, 1f0) === 5f0 + @test fdiff2(f1, 1.0) === 5.0 # specific signature f2(x) = 2*x @@ -40,6 +45,8 @@ fdiff(f, x::MockModule.MockType) = autodiff(Forward, f, Duplicated, Duplicated(x Enzyme.@import_frule typeof(f2) Float32 @test fdiff(f2, 1f0) === 5f0 @test fdiff(f2, 1.0) === 2.0 + @test fdiff2(f2, 1f0) === 5f0 + @test fdiff2(f2, 1.0) === 2.0 # two arguments f3(x, y) = 2*x + y @@ -47,6 +54,8 @@ fdiff(f, x::MockModule.MockType) = autodiff(Forward, f, Duplicated, Duplicated(x Enzyme.@import_frule typeof(f3) Any Any @test fdiff(x -> f3(x, 1.0), 2.) === 5.0 @test fdiff(y -> f3(1.0, y), 2.) === 2.0 + @test fdiff2(x -> f3(x, 1.0), 2.) === 5.0 + @test fdiff2(y -> f3(1.0, y), 2.) 
=== 2.0 # external module (checks correct type escaping, PR #1446) Enzyme.@import_frule typeof(MockModule.mock_function) MockModule.MockType diff --git a/test/rules.jl b/test/rules.jl index 0ef2e0fe8e8..b306c353fbf 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -61,11 +61,11 @@ end @test autodiff(Forward, f, Duplicated(2.0, 1.0))[1] ≈ 14.0 @test autodiff(Forward, x->f(x)^2, Duplicated(2.0, 1.0))[1] ≈ 832.0 - res = autodiff(Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(2.0, (1.0, 3.0)))[1] + res = autodiff(Forward, f, BatchDuplicated, BatchDuplicated(2.0, (1.0, 3.0)))[1] @test res[1] ≈ 1004.0 @test res[2] ≈ 1012.0 - res = Enzyme.autodiff(Forward, x->f(x)^2, BatchDuplicatedNoNeed, BatchDuplicated(2.0, (1.0, 3.0)))[1] + res = Enzyme.autodiff(Forward, x->f(x)^2, BatchDuplicated, BatchDuplicated(2.0, (1.0, 3.0)))[1] @test res[1] ≈ 80032.0 @test res[2] ≈ 80096.0 @@ -129,7 +129,7 @@ end @testset "Shadow" begin @test Enzyme.autodiff(Forward, h, Duplicated(3.0, 1.0)) == (6000.0,) - @test Enzyme.autodiff(Forward, h, Duplicated, Duplicated(3.0, 1.0)) == (9.0, 60.0) + @test Enzyme.autodiff(ForwardWithPrimal, h, Duplicated(3.0, 1.0)) == (60.0, 9.0) @test Enzyme.autodiff(Forward, h2, Duplicated(3.0, 1.0)) == (1080.0,) @test_throws Enzyme.Compiler.EnzymeRuntimeException Enzyme.autodiff(Forward, h3, Duplicated(3.0, 1.0)) end @@ -149,10 +149,10 @@ function EnzymeRules.forward(config, end @testset "Batch complex" begin - res = autodiff(Forward, foo, BatchDuplicated, BatchDuplicated(0.1 + 0im, (0.2 + 0im, 0.3 + 0im))) # errors, see below - @test res[1] ≈ 0.2 + 0.0im - @test res[2][1] ≈ 0.4 + 0.0im - @test res[2][2] ≈ 0.6 + 0.0im + res = autodiff(ForwardWithPrimal, foo, BatchDuplicated(0.1 + 0im, (0.2 + 0im, 0.3 + 0im))) + @test res[2] ≈ 0.2 + 0.0im + @test res[1][1] ≈ 0.4 + 0.0im + @test res[1][2] ≈ 0.6 + 0.0im end end # module ForwardRules diff --git a/test/runtests.jl b/test/runtests.jl index bdda7604bf2..18d765938d0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ 
-342,8 +342,8 @@ make3() = (1.0, 2.0, 3.0) f1(x) = 1.0 + x f2(x) = x*x @test autodiff(Reverse, f1, Active, Active(1.0))[1][1] ≈ 1.0 - @test autodiff(Forward, f1, DuplicatedNoNeed, Duplicated(1.0, 1.0))[1] ≈ 1.0 - @test autodiff(Forward, f1, Duplicated, Duplicated(1.0, 1.0))[2] ≈ 1.0 + @test autodiff(Forward, f1, Duplicated, Duplicated(1.0, 1.0))[1] ≈ 1.0 + @test autodiff(ForwardWithPrimal, f1, Duplicated, Duplicated(1.0, 1.0))[1] ≈ 1.0 @test autodiff(Reverse, f2, Active, Active(1.0))[1][1] ≈ 2.0 @test autodiff(Forward, f2, Duplicated(1.0, 1.0))[1] ≈ 2.0 tup = autodiff(Forward, f2, BatchDuplicated(1.0, (1.0, 2.0, 3.0)))[1] @@ -1323,8 +1323,8 @@ end (sin(x)::Float64 + x)::Float64 end @test 0.5838531634528576 ≈ Enzyme.autodiff(Reverse, boxfloat, Active, Active(2.0))[1][1] - @test 0.5838531634528576 ≈ Enzyme.autodiff(Forward, boxfloat, DuplicatedNoNeed, Duplicated(2.0, 1.0))[1] - res = Enzyme.autodiff(Forward, boxfloat, BatchDuplicatedNoNeed, BatchDuplicated(2.0, (1.0, 2.0)))[1] + @test 0.5838531634528576 ≈ Enzyme.autodiff(Forward, boxfloat, Duplicated, Duplicated(2.0, 1.0))[1] + res = Enzyme.autodiff(Forward, boxfloat, BatchDuplicated, BatchDuplicated(2.0, (1.0, 2.0)))[1] @test 0.5838531634528576 ≈ res[1] @test 1.1677063269057153 ≈ res[2] end @@ -1420,9 +1420,9 @@ function rtg_f(V,@nospecialize(cv)) end @testset "RuntimeActivity generic call" begin - res = autodiff(set_runtime_activity(Forward), rtg_f, Duplicated, Duplicated([0.2], [1.0]), Const(RTGData(3.14))) - @test 3.14 ≈ res[1] - @test 0.0 ≈ res[2] + res = autodiff(set_runtime_activity(ForwardWithPrimal), rtg_f, Duplicated, Duplicated([0.2], [1.0]), Const(RTGData(3.14))) + @test 3.14 ≈ res[2] + @test 0.0 ≈ res[1] end @inline function myquantile(v::AbstractVector, p::Real; alpha) @@ -1452,9 +1452,9 @@ end @testset "Attributor issues" begin cor = fquantile(2.0) - res = autodiff(Forward, fquantile, Duplicated,Duplicated(2.0, 1.0)) - @test cor ≈ res[1] - @test 0.7 ≈ res[2] + res = autodiff(ForwardWithPrimal, 
fquantile, Duplicated,Duplicated(2.0, 1.0)) + @test cor ≈ res[2] + @test 0.7 ≈ res[1] end @@ -1739,13 +1739,13 @@ end dx = [1.0, 1.0, 1.0] dx2 = [10.0, 20.0, 30.0] - res = Enzyme.autodiff(Forward, fwdlatestfoo, BatchDuplicated, BatchDuplicated(x, (dx, dx2))) + res = Enzyme.autodiff(ForwardWithPrimal, fwdlatestfoo, BatchDuplicated, BatchDuplicated(x, (dx, dx2))) @test 2.0 ≈ res[1][1] + @test 20.0 ≈ res[1][2] @test 2.0 ≈ res[2][1] - @test 20.0 ≈ res[2][2] - res = Enzyme.autodiff(Forward, fwdlatestfoo, BatchDuplicatedNoNeed, BatchDuplicated(x, (dx, dx2))) + res = Enzyme.autodiff(Forward, fwdlatestfoo, BatchDuplicated, BatchDuplicated(x, (dx, dx2))) @test 2.0 ≈ res[1][1] @test 20.0 ≈ res[1][2] @@ -2712,14 +2712,14 @@ end @testset "Batch Forward" begin square(x)=x*x - bres = autodiff(Forward, square, BatchDuplicatedNoNeed, BatchDuplicated(3.0, (1.0, 2.0, 3.0))) + bres = autodiff(Forward, square, BatchDuplicated, BatchDuplicated(3.0, (1.0, 2.0, 3.0))) @test length(bres) == 1 @test length(bres[1]) == 3 @test bres[1][1] ≈ 6.0 @test bres[1][2] ≈ 12.0 @test bres[1][3] ≈ 18.0 - bres = autodiff(Forward, square, BatchDuplicatedNoNeed, BatchDuplicated(3.0 + 7.0im, (1.0+0im, 2.0+0im, 3.0+0im))) + bres = autodiff(Forward, square, BatchDuplicated, BatchDuplicated(3.0 + 7.0im, (1.0+0im, 2.0+0im, 3.0+0im))) @test bres[1][1] ≈ 6.0 + 14.0im @test bres[1][2] ≈ 12.0 + 28.0im @test bres[1][3] ≈ 18.0 + 42.0im @@ -2729,10 +2729,10 @@ end # Shadow offset is not the same as primal so following doesn't work # d_inp = Float32[1.0, 2.0, 3.0] - # autodiff(Forward, squareidx, BatchDuplicatedNoNeed, BatchDuplicated(view(inp, 1:1), (view(d_inp, 1:1), view(d_inp, 2:2), view(d_inp, 3:3)))) + # autodiff(Forward, squareidx, BatchDuplicated, BatchDuplicated(view(inp, 1:1), (view(d_inp, 1:1), view(d_inp, 2:2), view(d_inp, 3:3)))) d_inp = (Float32[1.0], Float32[2.0], Float32[3.0]) - bres = autodiff(Forward, squareidx, BatchDuplicatedNoNeed, BatchDuplicated(inp, d_inp)) + bres = autodiff(Forward, squareidx, 
BatchDuplicated, BatchDuplicated(inp, d_inp)) @test bres[1][1] ≈ 6.0 @test bres[1][2] ≈ 12.0 @test bres[1][3] ≈ 18.0 @@ -3559,11 +3559,11 @@ end fn(0.0) end - res = autodiff(set_runtime_activity(Forward), Const(f2), Duplicated, Duplicated(0.2, 1.0)) - @test res[1] ≈ 0.2 + res = autodiff(set_runtime_activity(ForwardWithPrimal), Const(f2), Duplicated, Duplicated(0.2, 1.0)) + @test res[2] ≈ 0.2 # broken as the return of an apply generic is {primal, primal} # but since the return is abstractfloat doing the - @test res[2] ≈ 1.0 + @test res[1] ≈ 1.0 end @inline function uns_mymean(f, A, ::Type{T}, c) where T From bd5dcd10c703c43d5fbabb1e851b849089244dfd Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 16 Sep 2024 14:37:41 -0500 Subject: [PATCH 33/87] Auto upgrade to autodiff_deferred in nested AD (#1839) * WIP * Upgrade non deferred to deferred * cleanup * Update Project.toml * cleanup nested AD example --------- Co-authored-by: Michel Schanen --- examples/autodiff.jl | 4 ++-- src/Enzyme.jl | 45 ++----------------------------------- src/compiler/interpreter.jl | 32 +++++++++++++++++++++++++- test/runtests.jl | 8 +++++++ 4 files changed, 43 insertions(+), 46 deletions(-) diff --git a/examples/autodiff.jl b/examples/autodiff.jl index 669f3b68093..6bd0b74fb57 100644 --- a/examples/autodiff.jl +++ b/examples/autodiff.jl @@ -98,7 +98,7 @@ dby = [0.0] Enzyme.autodiff( Forward, - (x,y) -> Enzyme.autodiff_deferred(Reverse, f, x, y), + (x,y) -> Enzyme.autodiff(Reverse, f, x, y), Duplicated(Duplicated(x, bx), Duplicated(dx, dbx)), Duplicated(Duplicated(y, by), Duplicated(dy, dby)), ) @@ -121,7 +121,7 @@ dbx[2] == 1.0 # \end{aligned} # ``` function grad(x, dx, y, dy) - Enzyme.autodiff_deferred(Reverse, f, Duplicated(x, dx), DuplicatedNoNeed(y, dy)) + Enzyme.autodiff(Reverse, f, Duplicated(x, dx), DuplicatedNoNeed(y, dy)) nothing end diff --git a/src/Enzyme.jl b/src/Enzyme.jl index bb86a33fc79..583035593d0 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1084,31 +1084,6 @@ 
grad = gradient(Reverse, only ∘ f, (a = 2.0, b = [3.0], c = "str")) end end -""" - gradient_deferred(::ReverseMode, f, x) - -Like [`gradient`](@ref), except it using deferred mode. -""" -@inline function gradient_deferred(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, f::F, x::X) where {F, X, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten} - if Compiler.active_reg_inner(X, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState - dx = Ref(make_zero(x)) - autodiff_deferred(rm, f, Active, MixedDuplicated(x, dx)) - if ReturnPrimal - return (only(dx), res[2]) - else - return only(dx) - end - else - dx = make_zero(x) - autodiff_deferred(rm, f, Active, Duplicated(x, dx)) - if ReturnPrimal - (dx, res[2]) - else - dx - end - end -end - """ gradient!(::ReverseMode, dx, f, x) @@ -1149,22 +1124,6 @@ gradient!(ReverseWithPrimal, dx, f, [2.0, 3.0]) end end - -""" - gradient_deferred!(::ReverseMode, f, x) - -Like [`gradient!`](@ref), except it using deferred mode. 
-""" -@inline function gradient_deferred!(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, dx::X, f::F, x::X) where {X<:Array, F, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten} - make_zero!(dx) - autodiff_deferred(rm, f, Active, Duplicated(x, dx)) - return if ReturnPrimal - (dx, res[2]) - else - dx - end -end - """ gradient(::ForwardMode, f, x; shadow=onehot(x)) @@ -1605,7 +1564,7 @@ res """ @inline function hvp!(res::X, f::F, x::X, v::X) where {F, X} grad = make_zero(x) - Enzyme.autodiff(Forward, gradient_deferred!, Const(Reverse), DuplicatedNoNeed(grad, res), Const(f), Duplicated(x, v)) + Enzyme.autodiff(Forward, gradient!, Const(Reverse), DuplicatedNoNeed(grad, res), Const(f), Duplicated(x, v)) return nothing end @@ -1640,7 +1599,7 @@ grad ``` """ @inline function hvp_and_gradient!(res::X, grad::X, f::F, x::X, v::X) where {F, X} - Enzyme.autodiff(Forward, gradient_deferred!, Const(Reverse), Duplicated(grad, res), Const(f), Duplicated(x, v)) + Enzyme.autodiff(Forward, gradient!, Const(Reverse), Duplicated(grad, res), Const(f), Duplicated(x, v)) return nothing end diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 2ef66a15713..482690e20f6 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -212,4 +212,34 @@ let # overload `inlining_policy` end end -end # module Interpreter +import Core.Compiler: abstract_call, abstract_call_known, ArgInfo, StmtInfo, AbsIntState, get_max_methods, + CallMeta, Effects, NoCallInfo, widenconst, mapany + +struct AutodiffCallInfo <: CallInfo + # ... 
+ info::CallInfo +end + +function abstract_call_known(interp::EnzymeInterpreter, @nospecialize(f), + arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, + max_methods::Int = get_max_methods(interp, f, sv)) + + (; fargs, argtypes) = arginfo + + if f === Enzyme.autodiff && length(argtypes) >= 4 + if widenconst(argtypes[2]) <: Enzyme.Mode && widenconst(argtypes[3]) <: Enzyme.Annotation && widenconst(argtypes[4]) <: Type{<:Enzyme.Annotation} + arginfo2 = ArgInfo( + fargs isa Nothing ? nothing : [:(Enzyme.autodiff_deferred), fargs[2:end]...], + [Core.Const(Enzyme.autodiff_deferred), argtypes[2:end]...] + ) + return abstract_call_known( + interp, Enzyme.autodiff_deferred, arginfo2, + si, sv, max_methods) + end + end + return Base.@invoke abstract_call_known( + interp::AbstractInterpreter, f, arginfo::ArgInfo, + si::StmtInfo, sv::AbsIntState, max_methods::Int) +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 18d765938d0..b079c0f5406 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -486,6 +486,14 @@ end end +@testset "Deferred upgrade" begin + function gradsin(x) + return gradient(Reverse, sin, x) + end + res = Enzyme.gradient(Reverse, gradsin, 3.1) + @test res ≈ -sin(3.1) +end + @testset "Simple Complex tests" begin mul2(z) = 2 * z square(z) = z * z From 786a998f0dc5343703c5420eae40cb790575e218 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 17 Sep 2024 15:30:36 -0500 Subject: [PATCH 34/87] Update sugar apis (#1844) * Update sugar apis * cleanup * cleanup * cleanup * fix * fix * fix stack * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * Update internal_rules.jl * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix --- ext/EnzymeStaticArraysExt.jl | 2 + src/Enzyme.jl | 734 ++++++++++++++++++++--------------- test/ext/logexpfunctions.jl | 4 +- test/internal_rules.jl | 38 +- test/runtests.jl | 326 ++++++++-------- 5 files changed, 609 
insertions(+), 495 deletions(-) diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index 6dbd390cb76..bcaa3ec6cbb 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -3,6 +3,8 @@ module EnzymeStaticArraysExt using StaticArrays using Enzyme +@inline Enzyme.tupstack(rows::(NTuple{N, <:StaticArrays.SArray} where N), inshape, outshape) = reshape(cat(rows..., dims=length(inshape)), (inshape..., outshape...)) + @inline function Enzyme.onehot(x::StaticArrays.SArray{S, T, N, L}) where {S, T, N, L} ntuple(Val(L)) do i Base.@_inline_meta diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 583035593d0..66551f29587 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1024,12 +1024,16 @@ end end """ - gradient(::ReverseMode, f, x) + gradient(::ReverseMode, f, args...) Compute the gradient of a real-valued function `f` using reverse mode. -This will allocate and return new array `make_zero(x)` with the gradient result. +For each differentiable argument, this function will allocate and return new derivative object, returning +a tuple of derivatives for each argument. If an argument is not differentiable, the element of the returned +tuple with be nothing. -Besides arrays, for struct `x` it returns another instance of the same type, +In reverse mode (here), the derivatives will be the same type as the original argument. + +This is a structure gradient. For a struct `x` it returns another instance of the same type, whose fields contain the components of the gradient. In the result, `grad.a` contains `∂f/∂x.a` for any differential `x.a`, while `grad.c == x.c` for other types. 
@@ -1042,44 +1046,128 @@ f(x) = x[1]*x[2] grad = gradient(Reverse, f, [2.0, 3.0]) # output +([3.0, 2.0],) +``` -2-element Vector{Float64}: - 3.0 - 2.0 +```jldoctest gradient +grad = gradient(Reverse, only ∘ f, (a = 2.0, b = [3.0], c = "str")) + +# output + +((a = 3.0, b = [2.0], c = "str"),) ``` ```jldoctest gradient +mul(x, y) = x[1]*y[1] -grad = gradient(ReverseWithPrimal, f, [2.0, 3.0]) +grad = gradient(Reverse, mul, [2.0], [3.0]) # output -([3.0, 2.0], 6.0) +([3.0], [2.0]) ``` ```jldoctest gradient -grad = gradient(Reverse, only ∘ f, (a = 2.0, b = [3.0], c = "str")) + +grad = gradient(Reverse, mul, [2.0], Const([3.0])) + +# output +([3.0], nothing) +``` + +If passing a mode that returns the primal (e.g. ReverseWithPrimal), the return type will instead be +a tuple where the first element contains the derivatives, and the second element contains the result of the original computation. + +```jldoctest gradient + +grad = gradient(ReverseWithPrimal, f, [2.0, 3.0]) + +# output +(([3.0, 2.0],), 6.0) +``` +```jldoctest gradient + +grad = gradient(ReverseWithPrimal, mul, [2.0], [3.0]) # output +(([3.0], [2.0]), 6.0) +``` + +```jldoctest gradient +grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0])) -(a = 3.0, b = [2.0], c = "str") +# output +(([3.0], nothing), 6.0) ``` + """ -@inline function gradient(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, f::F, x::X) where {F, X, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten} - if Compiler.active_reg_inner(X, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState - dx = Ref(make_zero(x)) - res = autodiff(rm, f, Active, MixedDuplicated(x, dx)) - if ReturnPrimal - (only(dx), res[2]) - else - only(dx) +@generated function gradient(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, f::F, x::ty_0, args::Vararg{<:Any, N}) where {F, ty_0, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten, N} + toemit= 
Expr[quote + act_0 = !(x isa Enzyme.Const) && Compiler.active_reg_inner(Core.Typeof(x), #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState + end] + rargs = Union{Symbol,Expr}[:x] + acts = Symbol[Symbol("act_0")] + + for i in 1:N + argidx = quote args[$i] end + push!(rargs, argidx) + sym = Symbol("act_$i") + push!(acts, sym) + push!(toemit, quote + $sym = !($argidx isa Enzyme.Const) && Compiler.active_reg_inner(Core.Typeof($argidx), #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState + end) + end + + idx = 0 + shadows = Symbol[] + enz_args = Expr[] + resargs = Expr[] + for (arg, act) in zip(rargs, acts) + shad = Symbol("shad_$idx") + push!(shadows, shad) + push!(toemit, quote + $shad = if $arg isa Enzyme.Const + nothing + elseif $act + Ref(make_zero($arg)) + else + make_zero($arg) + end + end) + push!(enz_args, quote + if $arg isa Enzyme.Const + $arg + elseif $act + MixedDuplicated($arg, $shad) + else + Duplicated($arg, $shad) + end + end) + push!(resargs, quote + if $arg isa Enzyme.Const + nothing + elseif $act + $shad[] + else + $shad + end + end) + idx+=1 + end + push!(toemit, quote + res = autodiff(rm, f, Active, $(enz_args...)) + end) + + if ReturnPrimal + return quote + Base.@_inline_meta + $(toemit...) + (($(resargs...),), res[2]) end else - dx = make_zero(x) - res = autodiff(rm, f, Active, Duplicated(x, dx)) - if ReturnPrimal - (dx, res[2]) - else - dx + return quote + Base.@_inline_meta + $(toemit...) 
+ ($(resargs...),) end end end @@ -1100,10 +1188,7 @@ dx = [0.0, 0.0] gradient!(Reverse, dx, f, [2.0, 3.0]) # output - -2-element Vector{Float64}: - 3.0 - 2.0 +([3.0, 2.0],) ``` ```jldoctest gradip @@ -1111,21 +1196,87 @@ dx = [0.0, 0.0] gradient!(ReverseWithPrimal, dx, f, [2.0, 3.0]) # output -([3.0, 2.0], 6.0) +(([3.0, 2.0],), 6.0) ``` """ @inline function gradient!(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, dx::X, f::F, x::X) where {X<:Array, F, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten} make_zero!(dx) res = autodiff(rm, f, Active, Duplicated(x, dx)) return if ReturnPrimal - (dx, res[2]) + ((dx,), res[2]) else - dx + (dx,) + end +end + +@inline function chunkedonehot(x, ::Val{chunk}) where chunk + sz = length(x) + num = ((sz + chunk - 1) ÷ chunk) + ntuple(Val(num)) do i + Base.@_inline_meta + onehot(x, (i-1)*chunk+1, i == num ? sz : (i*chunk) ) + end +end + +@inline function chunkedonehot(x::AbstractFloat, ::Val{chunk}) where chunk + return ((one(x),),) +end + +@inline tupleconcat(x) = x +@inline tupleconcat(x, y) = (x..., y...) +@inline tupleconcat(x, y, z...) = (x..., tupleconcat(y, z...)...) 
+ +function create_shadows(::Nothing, x) + return (onehot(x),) +end + +function create_shadows(::Val{1}, x) + return (onehot(x),) +end + +function create_shadows(::Val{chunk}, x) where chunk + return (chunkedonehot(x, Val(chunk)),) +end + +struct TupleArray{T, Shape, Length, N} <: AbstractArray{T,N} + data::NTuple{Length, T} +end +TupleArray(data::NTuple{Length, T}, Shape) where {Length, T} = TupleArray{T, Shape, Length, length(Shape)}(data) + +@inline Base.eltype(::TupleArray{T}) where T = T +@inline Base.eltype(::Type{<:TupleArray{T}}) where T = T +@inline Base.size(::TupleArray{<:Any, Shape}) where Shape = Shape +@inline Base.ndims(::TupleArray{<:Any, <:Any, <:Any, N}) where N = N + +function Base.convert(::Type{Array{T, N}}, X::TupleArray{T, Shape, Length, N}) where {T, Shape, Length, N} + vals = Array{T, N}(undef, Shape...) + for i in 1:Length + @inbounds val[i] = X.data[i] + end + return vals +end + +function Base.getindex(a::TupleArray, args::Vararg{Int,N}) where {N} + start = 0 + for i in 1:N + start *= size(a, N - i + 1) + start += (args[N - i + 1] - 1) + end + start += 1 + return a.data[start] +end + +@inline function tupstack(x, inshape, outshape) + st = Base.stack(x) + if length(outshape) == 1 + st + else + reshape(st, (inshape..., outshape...)) end end """ - gradient(::ForwardMode, f, x; shadow=onehot(x)) + gradient(::ForwardMode, f, x; shadows=onehot(x), chunk=nothing) Compute the gradient of an array-input function `f` using forward mode. 
The optional keyword argument `shadow` is a vector of one-hot vectors of type `x` @@ -1138,372 +1289,331 @@ Example: ```jldoctest gradfwd f(x) = x[1]*x[2] -grad = gradient(Forward, f, [2.0, 3.0]) +gradient(Forward, f, [2.0, 3.0]) # output -(3.0, 2.0) +([3.0, 2.0],) ``` ```jldoctest gradfwd gradient(ForwardWithPrimal, f, [2.0, 3.0]) # output -((3.0, 2.0), 6.0) +(([3.0, 2.0],), 6.0) ``` -""" -@inline function gradient(fm::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity}, f, x; shadow=onehot(x)) where {ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} - if length(shadow) == 0 - if ReturnPrimal - ((), f(x.val)) - else - return () - end - end - resp = autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadow)) - res = values(resp[1]) - dres = if x isa AbstractFloat - res[1] - else - res - end - if ReturnPrimal - (dres, resp[2]) - else - dres - end -end - -@inline function chunkedonehot(x, ::Val{chunk}) where chunk - sz = length(x) - num = ((sz + chunk - 1) ÷ chunk) - ntuple(Val(num)) do i - Base.@_inline_meta - onehot(x, (i-1)*chunk+1, i == num ? sz : (i*chunk) ) - end -end +```jldoctest gradfwd +gradient(Forward, f, [2.0, 3.0]; chunk=Val(1)) -@inline function chunkedonehot(x::AbstractFloat, ::Val{chunk}) where chunk - return ((one(x),),) -end +# output -@inline tupleconcat(x) = x -@inline tupleconcat(x, y) = (x..., y...) -@inline tupleconcat(x, y, z...) = (x..., tupleconcat(y, z...)...) +([3.0, 2.0],) +``` -""" - gradient(::ForwardMode, f, x::Union{Array,NTuple}, ::Val{chunk}; shadow=onehot(x)) +```jldoctest gradfwd +gradient(ForwardWithPrimal, f, [2.0, 3.0]; chunk=Val(1)) -Compute the gradient of an array-input function `f` using vector forward mode. -Like [`gradient`](@ref), except it uses a chunk size of `chunk` to compute -`chunk` derivatives in a single call. 
+# output +(([3.0, 2.0],), 6.0) +``` -Example: +For functions which return an AbstractArray or scalar, this function will return an AbstracttArray +whose shape is `(size(output)..., size(input)...)`. No guarantees are presently made +about the type of the AbstractArray returned by this function (which may or may not be the same +as the input AbstractArray if provided). +For functions who return other types, this function will retun an AbstractArray +of shape `size(input)` of values of the output type. ```jldoctest -f(x) = x[1]*x[2] +f(x) = [ x[1] * x[2], x[2] + x[3] ] -grad = gradient(Forward, f, [2.0, 3.0], Val(2)) +grad = gradient(Forward, f, [2.0, 3.0, 4.0]) # output - -(3.0, 2.0) +([3.0 2.0 0.0; 0.0 1.0 1.0],) ``` """ -@inline function gradient(fm::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity}, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X, chunk, ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} - if chunk == 0 - throw(ErrorException("Cannot differentiate with a batch size of 0")) - end - if ReturnPrimal - rp = autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadow[1]))[1] - dres1 = if chunk == 1 - (rp[1],) - else - values(rp[1]) - end - gres = if x isa AbstractFloat - dres1 +@inline function gradient(fm::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity}, f, x; chunk::CS=nothing, shadows=create_shadows(chunk, x)) where {ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity, CS} + if length(shadows[1]) == 0 + if ReturnPrimal + ((x,), f(x.val)) else - fm2 = ForwardMode{#=ReturnPrimal=#false, ABI, ErrIfFuncWritten,RuntimeActivity}() - tmp = ntuple(length(shadow)-1) do i - values(autodiff(fm2, f, BatchDuplicated, BatchDuplicated(x, shadow[i+1]))[1]) - end - tupleconcat(dres1, tmp...) - end - (gres, rp[2]) - else - tmp = ntuple(length(shadow)) do i - values(autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadow[i]))[1]) + return (x,) end - res = tupleconcat(tmp...) 
- if x isa AbstractFloat + end + if chunk == Val(0) + throw(ErrorException("Cannot differentiate with a batch size of 0")) + end + + gradtup = if chunk == nothing + resp = autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadows[1])) + + res = values(resp[1]) + dres = if x isa AbstractFloat res[1] else res end - end -end - -@inline function gradient(fm::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity}, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F, X, ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} - if ReturnPrimal - rp = autodiff(fm, f, Duplicated, Duplicated(x, shadow[1])) - dres1 = rp[1] - fm2 = ForwardMode{#=ReturnPrimal=#false, ABI, ErrIfFuncWritten,RuntimeActivity}() - - res = ntuple(length(shadow)-1) do i - autodiff(fm2, f, Duplicated, Duplicated(x, shadow[i+1]))[1] + if ReturnPrimal + ((dres,), resp[2]) + else + (dres,) end - gres = if x isa AbstractFloat - dres1 + elseif chunk == Val(1) + if ReturnPrimal + rp = autodiff(fm, f, Duplicated, Duplicated(x, shadows[1][1])) + dres1 = rp[1] + fm2 = ForwardMode{#=ReturnPrimal=#false, ABI, ErrIfFuncWritten,RuntimeActivity}() + + res = ntuple(length(shadows[1])-1) do i + autodiff(fm2, f, Duplicated, Duplicated(x, shadows[1][i+1]))[1] + end + gres = if x isa AbstractFloat + dres1[1] + else + (dres1, res...) + end + ((gres,), rp[2]) else - (dres1, res...) 
+ res = ntuple(length(shadows[1])) do i + autodiff(fm, f, Duplicated, Duplicated(x, shadows[1][i]))[1] + end + (if x isa AbstractFloat + res[1] + else + res + end,) end - (gres, rp[2]) else - res = ntuple(length(shadow)) do i - autodiff(fm, f, Duplicated, Duplicated(x, shadow[i]))[1] - end - if x isa AbstractFloat - res[1] + if ReturnPrimal + rp = autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadows[1][1])) + dres1 = values(rp[1]) + gres = if x isa AbstractFloat + dres1[1] + else + fm2 = ForwardMode{#=ReturnPrimal=#false, ABI, ErrIfFuncWritten,RuntimeActivity}() + tmp = ntuple(length(shadows[1])-1) do i + values(autodiff(fm2, f, BatchDuplicated, BatchDuplicated(x, shadows[1][i+1]))[1]) + end + tupleconcat(dres1, tmp...) + end + ((gres,), rp[2]) else - res + tmp = ntuple(length(shadows[1])) do i + values(autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadows[1][i]))[1]) + end + res = tupleconcat(tmp...) + (if x isa AbstractFloat + res[1] + else + res + end,) end end -end - -""" - jacobian(::ForwardMode, f, x; shadow=onehot(x)) - jacobian(::ForwardMode, f, x, ::Val{chunk}; shadow=onehot(x)) - -Compute the jacobian of an array or scalar-input function `f` using (potentially vector) -forward mode. All relevant arguments of the forward-mode [`gradient`](@ref) function -apply here. - -Example: - -```jldoctest -f(x) = [ x[1] * x[2], x[2] + x[3] ] - -grad = jacobian(Forward, f, [2.0, 3.0, 4.0]) - -# output -2×3 Matrix{Float64}: - 3.0 2.0 0.0 - 0.0 1.0 1.0 -``` - -For functions which return an AbstractArray, this function will return an array -whose shape is `(size(output)..., size(input)...)` - -For functions who return other types, this function will retun an array or tuple -of shape `size(input)` of values of the output type. -""" -@inline function jacobian(fm::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity}, args...; kwargs...) where {ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} - gradtup = gradient(fm, args...; kwargs...) 
cols = if ReturnPrimal - gradtup[1] + gradtup[1][1] else - gradtup + gradtup[1] end - x = args[2] res = if x isa AbstractFloat cols - elseif length(cols) > 0 && cols[1] isa AbstractArray + elseif length(cols) > 0 && cols[1] isa AbstractArray && x isa AbstractArray inshape = size(x) outshape = size(cols[1]) # st : outshape x total inputs - st = Base.stack(cols) - - st3 = if length(inshape) <= 1 - st - else - reshape(st, (outshape..., inshape...)) - end - - st3 + tupstack(cols, outshape, inshape) elseif x isa AbstractArray - inshape = size(x) - reshape(collect(cols), inshape) + TupleArray(cols, size(x)) else cols end if ReturnPrimal - (res, gradtup[2]) + ((res,), gradtup[2]) else - res + (res,) end end """ - jacobian(::ReverseMode, f, x, ::Val{num_outs}, ::Val{chunk}=Val(1)) + jacobian(::ForwardMode, args...; kwargs...) + +Equivalent to gradient(::ForwardMode, args...; kwargs...) +""" +@inline function jacobian(fm::ForwardMode, args...; kwargs...) + gradient(fm, args...; kwargs...) +end + +""" + jacobian(::ReverseMode, f, x; n_outs=nothing, chunk=nothing) jacobian(::ReverseMode, f, x) -Compute the jacobian of an array-output function `f` using (potentially vector) -reverse mode. The `chunk` argument denotes the chunk size to use and `num_outs` -denotes the number of outputs `f` will return in an array. +Compute the jacobian of a array-output function `f` using (potentially vector) +reverse mode. The `chunk` argument denotes the chunk size to use and `n_outs` +denotes the shape of the array returned by `f`. 
Example: ```jldoctest f(x) = [ x[1] * x[2], x[2] + x[3] ] -grad = jacobian(Reverse, f, [2.0, 3.0, 4.0], Val(2)) +jacobian(Reverse, f, [2.0, 3.0, 4.0]) # output +([3.0 2.0 0.0; 0.0 1.0 1.0],) +``` -2×3 transpose(::Matrix{Float64}) with eltype Float64: - 3.0 2.0 0.0 - 0.0 1.0 1.0 +```jldoctest +f(x) = [ x[1] * x[2], x[2] + x[3] ] + +grad = jacobian(Reverse, f, [2.0, 3.0, 4.0], n_outs=Val((2,))) + +# output +([3.0 2.0 0.0; 0.0 1.0 1.0],) ``` -For functions which return an AbstractArray, this function will return an array -whose shape is `(size(output)..., size(input)...)` +This function will return an AbstractArray whose shape is `(size(output)..., size(input)...)`. +No guarantees are presently made about the type of the AbstractArray returned by this function +(which may or may not be the same as the input AbstractArray if provided). -For functions who return other types, this function will retun an array or tuple -of shape `size(output)` of values of the input type. +In the future, when this function is extended to handle non-array return types, +this function will retun an AbstractArray of shape `size(output)` of values of the input type. ``` """ -@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RuntimeActivity, RABI, #=Holomorphic=#false, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{chunk}) where {F, X, chunk, n_out_val, RABI<:ABI, ErrIfFuncWritten, RuntimeActivity} - num = ((n_out_val + chunk - 1) ÷ chunk) - - if chunk == 0 - throw(ErrorException("Cannot differentiate with a batch size of 0")) - end - - XT = Core.Typeof(x) - MD = Compiler.active_reg_inner(XT, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState - tt′ = MD ? 
Tuple{BatchMixedDuplicated{XT, chunk}} : Tuple{BatchDuplicated{XT, chunk}} - tt = Tuple{XT} - rt = Core.Compiler.return_type(f, tt) - ModifiedBetween = Val((false, false)) - FA = Const{Core.Typeof(f)} - opt_mi = if RABI <: NonGenABI - Compiler.fspec(eltype(FA), tt′) - else - Val(codegen_world_age(Core.Typeof(f), tt)) - end - primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(chunk), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) - - if num * chunk == n_out_val - last_size = chunk - primal2, adjoint2 = primal, adjoint - else - last_size = n_out_val - (num-1)*chunk - tt′ = Tuple{BatchDuplicated{Core.Typeof(x), last_size}} - primal2, adjoint2 = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) - end +@inline function jacobian(::ReverseMode{ReturnPrimal,RuntimeActivity, RABI, Holomorphic, ErrIfFuncWritten}, f::F, x::X; n_outs::OutType=nothing, chunk::CT=nothing) where {ReturnPrimal, F, X, RABI<:ABI, ErrIfFuncWritten, RuntimeActivity, OutType, CT, Holomorphic} - tmp = ntuple(num) do i - Base.@_inline_meta - dx = ntuple(Val(i == num ? last_size : chunk)) do idx - Base.@_inline_meta - z = make_zero(x) - MD ? Ref(z) : z - end - res = (i == num ? primal2 : primal)(Const(f), MD ? BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx)) - tape = res[1] - j = 0 - for shadow in res[3] - j += 1 - @inbounds shadow[(i-1)*chunk+j] += Compiler.default_adjoint(eltype(typeof(shadow))) + if n_outs == nothing + res = if f isa Const + f.val(x) + else + f(x) end - (i == num ? adjoint2 : adjoint)(Const(f), MD ? BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx), tape) - return MD ? (ntuple(Val(i == num ? 
last_size : chunk)) do idx - Base.@_inline_meta - dx[idx][] - end) : dx, (i == 1 ? size(res[3][1]) : nothing) - end - rows = tupleconcat(map(first, tmp)...) - outshape = tmp[1][2] - if x isa AbstractArray - inshape = size(x) - - st = Base.stack(rows) - - st2 = if length(outshape) == 1 - st + jac = if res isa AbstractArray + jacobian(ReverseMode{false,RuntimeActivity,RABI, Holomorphic, ErrIfFuncWritten}(), f, x; n_outs=Val(size(res)), chunk) + elseif res isa AbstractFloat + gradient(ReverseMode{false,RuntimeActivity,RABI, Holomorphic, ErrIfFuncWritten}(), f, x) else - reshape(st, (inshape..., outshape...)) + throw(AssertionError("Unsupported return type of function for reverse-mode jacobian, $(Core.Typeof(res))")) end - st3 = if length(outshape) == 1 && length(inshape) == 1 - transpose(st2) + return if ReturnPrimal + (jac, res) else - transp = ( ((length(inshape)+1):(length(inshape)+length(outshape)))... , (1:length(inshape))... ) - PermutedDimsArray(st2, transp) + jac end - - st3 else - reshape(collect(rows), outshape) - end -end - -@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RuntimeActivity,RABI, #=Holomorphic=#false, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{1} = Val(1)) where {F, X, n_out_val,RuntimeActivity,RABI<:ABI, ErrIfFuncWritten} - XT = Core.Typeof(x) - MD = Compiler.active_reg_inner(XT, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState - tt′ = MD ? 
Tuple{MixedDuplicated{XT}} : Tuple{Duplicated{XT}} - tt = Tuple{XT} - rt = Core.Compiler.return_type(f, tt) - ModifiedBetween = Val((false, false)) - FA = Const{Core.Typeof(f)} - opt_mi = if RABI <: NonGenABI - Compiler.fspec(eltype(FA), tt′) - else - Val(codegen_world_age(Core.Typeof(f), tt)) - end - primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) - tmp = ntuple(n_outs) do i - Base.@_inline_meta - z = make_zero(x) - dx = MD ? Ref(z) : z - res = primal(Const(f), MD ? MixedDuplicated(x, dx) : Duplicated(x, dx)) - tape = res[1] - @inbounds res[3][i] += Compiler.default_adjoint(eltype(typeof(res[3]))) - adjoint(Const(f), MD ? MixedDuplicated(x, dx) : Duplicated(x, dx), tape) - return MD ? dx[] : dx, (i == 1 ? size(res[3]) : nothing) - end - rows = map(first, tmp) - outshape = tmp[1][2] - if x isa AbstractArray - inshape = size(x) - st = Base.stack(rows) - - st2 = if length(outshape) == 1 - st + @assert !Holomorphic + n_out_val = if length(Compiler.element(n_outs)) == 0 + 0 else - reshape(st, (inshape..., outshape...)) + prod(Compiler.element(n_outs)) end - - st3 = if length(outshape) == 1 && length(inshape) == 1 - transpose(st2) + + if chunk == Val(0) + throw(ErrorException("Cannot differentiate with a batch size of 0")) + end + + XT = Core.Typeof(x) + MD = Compiler.active_reg_inner(XT, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState + tt = Tuple{XT} + rt = if f isa Const + Core.Compiler.return_type(f.val, tt) + else + Core.Compiler.return_type(f, tt) + end + + ModifiedBetween = Val((false, false)) + FRT = Core.Typeof(f) + FA = Const{FRT} + + opt_mi = if RABI <: NonGenABI + Compiler.fspec(FRT, tt′) else - transp = ( ((length(inshape)+1):(length(inshape)+length(outshape)))... , (1:length(inshape))... 
) - PermutedDimsArray(st2, transp) + Val(codegen_world_age(FRT, tt)) end - st3 - else - reshape(collect(rows), outshape) - end -end + if chunk == Val(1) || chunk == nothing + tt′ = MD ? Tuple{MixedDuplicated{XT}} : Tuple{Duplicated{XT}} + primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + tmp = ntuple(Val(n_out_val)) do i + Base.@_inline_meta + z = make_zero(x) + dx = MD ? Ref(z) : z + res = primal(Const(f), MD ? MixedDuplicated(x, dx) : Duplicated(x, dx)) + tape = res[1] + @inbounds res[3][i] += Compiler.default_adjoint(eltype(typeof(res[3]))) + adjoint(Const(f), MD ? MixedDuplicated(x, dx) : Duplicated(x, dx), tape) + return MD ? dx[] : dx, (i == 1 ? size(res[3]) : nothing) + end + rows = map(first, tmp) + outshape = tmp[1][2] + rows, outshape + else + chunksize = Compiler.element(chunk) + tt′ = MD ? 
Tuple{BatchMixedDuplicated{XT, chunksize}} : Tuple{BatchDuplicated{XT, chunksize}} + primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#chunk, ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + + num = ((n_out_val + chunksize - 1) ÷ chunksize) + + if num * chunksize == n_out_val + last_size = chunksize + primal2, adjoint2 = primal, adjoint + else + last_size = n_out_val - (num-1)*chunksize + tt′ = Tuple{BatchDuplicated{Core.Typeof(x), last_size}} + primal2, adjoint2 = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + end -@inline function jacobian(::ReverseMode{ReturnPrimal,RuntimeActivity, RABI, Holomorphic, ErrIfFuncWritten}, f::F, x::X) where {ReturnPrimal, F, X, RABI<:ABI, ErrIfFuncWritten, RuntimeActivity, Holomorphic} - res = f(x) - jac = if res isa AbstractArray - jacobian(ReverseMode{false,RuntimeActivity,RABI, Holomorphic, ErrIfFuncWritten}(), f, x, Val(length(res))) - elseif res isa AbstractFloat - gradient(ReverseMode{false,RuntimeActivity,RABI, Holomorphic, ErrIfFuncWritten}(), f, x) - else - throw(AssertionError("Unsupported return type of function for reverse-mode jacobian, $(Core.Typeof(res))")) - end + tmp = ntuple(num) do i + Base.@_inline_meta + dx = ntuple(Val(i == num ? last_size : chunksize)) do idx + Base.@_inline_meta + z = make_zero(x) + MD ? Ref(z) : z + end + res = (i == num ? primal2 : primal)(Const(f), MD ? BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx)) + tape = res[1] + j = 0 + for shadow in res[3] + j += 1 + @inbounds shadow[(i-1)*chunksize+j] += Compiler.default_adjoint(eltype(typeof(shadow))) + end + (i == num ? adjoint2 : adjoint)(Const(f), MD ? 
BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx), tape) + return MD ? (ntuple(Val(i == num ? last_size : chunksize)) do idx + Base.@_inline_meta + dx[idx][] + end) : dx, (i == 1 ? size(res[3][1]) : nothing) + end + rows = tupleconcat(map(first, tmp)...) + outshape = tmp[1][2] + rows, outshape + end + res = if x isa AbstractArray + inshape = size(x) + st2 = tupstack(rows, inshape, outshape) - if ReturnPrimal - (res, jac) - else - jac + st3 = if length(outshape) == 1 && length(inshape) == 1 + transpose(st2) + else + transp = ( ((length(inshape)+1):(length(inshape)+length(outshape)))... , (1:length(inshape))... ) + PermutedDimsArray(st2, transp) + end + + st3 + else + reshape(collect(rows), outshape) + end + if ReturnPrimal + # TODO optimize away redundant fwd pass + (res, if f isa Enzyme.Const + f.val(x) + else + f(x) + end) + else + (res,) + end end end diff --git a/test/ext/logexpfunctions.jl b/test/ext/logexpfunctions.jl index 69ee7f2e733..51dbe2ec76f 100644 --- a/test/ext/logexpfunctions.jl +++ b/test/ext/logexpfunctions.jl @@ -9,6 +9,6 @@ xlogydiff(x) = xlogy(x[1], 23.0) grad_forward = Enzyme.gradient(Enzyme.Forward, xlogydiff, x) grad_reverse = Enzyme.gradient(Enzyme.Reverse, xlogydiff, x) - @test grad_forward[1] ≈ log(23.0) - @test grad_reverse[1] ≈ log(23.0) + @test grad_forward[1] ≈ [log(23.0)] + @test grad_reverse[1] ≈ [log(23.0)] end diff --git a/test/internal_rules.jl b/test/internal_rules.jl index b9a705941c9..32a206c62e6 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -197,14 +197,14 @@ end dL = zero(x) dL[2, 1] = 1.0 - @test Enzyme.gradient(Reverse, chol_lower0, x) ≈ [0.05270807565639164 0.0 0.0 0.0; 1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] + @test Enzyme.gradient(Reverse, chol_lower0, x)[1] ≈ [0.05270807565639164 0.0 0.0 0.0; 1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] - @test reshape(collect(Enzyme.gradient(Forward, chol_lower0, x)), 4, 4) ≈ [0.05270807565639164 0.0 0.0 0.0; 
1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] + @test Enzyme.gradient(Forward, chol_lower0, x)[1] ≈ [0.05270807565639164 0.0 0.0 0.0; 1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] @test FiniteDifferences.grad(central_fdm(5, 1), chol_lower0, x)[1] ≈ [0.05270807565639164 0.0 0.0 0.0; 1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] - @test reshape(collect(Enzyme.gradient(Forward, chol_upper0, x)), 4, 4) ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] - @test Enzyme.gradient(Reverse, chol_upper0, x) ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] + @test Enzyme.gradient(Forward, chol_upper0, x)[1] ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] + @test Enzyme.gradient(Reverse, chol_upper0, x)[1] ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] @test FiniteDifferences.grad(central_fdm(5, 1), chol_upper0, x)[1] ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0] end @@ -225,14 +225,14 @@ end x = [1.0 0.13147601759884564 0.5282944836504488; 0.13147601759884564 1.0 0.18506733179093515; 0.5282944836504488 0.18506733179093515 1.0] for i in 1:size(x, 1) for j in 1:size(x, 2) - reverse_grad = Enzyme.gradient(Reverse, x -> tchol_lower(x, i, j), x) - forward_grad = reshape(collect(Enzyme.gradient(Forward, x -> tchol_lower(x, i, j), x)), size(x)) + reverse_grad = Enzyme.gradient(Reverse, x -> tchol_lower(x, i, j), x)[1] + forward_grad = Enzyme.gradient(Forward, x -> tchol_lower(x, i, j), x)[1] finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tchol_lower(x, i, j), x)[1] @test reverse_grad ≈ finite_diff @test forward_grad ≈ finite_diff - reverse_grad = Enzyme.gradient(Reverse, x -> tchol_upper(x, i, j), x) - forward_grad = 
reshape(collect(Enzyme.gradient(Forward, x -> tchol_upper(x, i, j), x)), size(x)) + reverse_grad = Enzyme.gradient(Reverse, x -> tchol_upper(x, i, j), x)[1] + forward_grad = Enzyme.gradient(Forward, x -> tchol_upper(x, i, j), x)[1] finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tchol_upper(x, i, j), x)[1] @test reverse_grad ≈ finite_diff @test forward_grad ≈ finite_diff @@ -257,26 +257,26 @@ end x = [1.0 0.13147601759884564 0.5282944836504488; 0.13147601759884564 1.0 0.18506733179093515; 0.5282944836504488 0.18506733179093515 1.0] for i in 1:15 B = [3.1 2.7 5.9 2.4 1.6; 7.9 8.2 1.3 9.4 5.5; 4.7 2.9 9.8 7.1 4.3] - reverse_grad = Enzyme.gradient(Reverse, Const(B -> tcholsolv_lower(x, B, i)), B) - # forward_grad = reshape(collect(Enzyme.gradient(Forward, B -> tcholsolv_lower(x, B, i), B)), size(B)) + reverse_grad = Enzyme.gradient(Reverse, Const(B -> tcholsolv_lower(x, B, i)), B)[1] + # forward_grad = Enzyme.gradient(Forward, B -> tcholsolv_lower(x, B, i), B)[1] finite_diff = FiniteDifferences.grad(central_fdm(5, 1), B -> tcholsolv_lower(x, B, i), B)[1] @test reverse_grad ≈ finite_diff # @test forward_grad ≈ finite_diff - reverse_grad = Enzyme.gradient(Reverse, Const(B -> tcholsolv_upper(x, B, i)), B) - # forward_grad = reshape(collect(Enzyme.gradient(Forward, B -> tcholsolv_upper(x, B, i), B)), size(B)) + reverse_grad = Enzyme.gradient(Reverse, Const(B -> tcholsolv_upper(x, B, i)), B)[1] + # forward_grad = Enzyme.gradient(Forward, B -> tcholsolv_upper(x, B, i), B))[1] finite_diff = FiniteDifferences.grad(central_fdm(5, 1), B -> tcholsolv_upper(x, B, i), B)[1] @test reverse_grad ≈ finite_diff # @test forward_grad ≈ finite_diff - reverse_grad = Enzyme.gradient(Reverse, Const(x -> tcholsolv_lower(x, B, i)), x) - #forward_grad = reshape(collect(Enzyme.gradient(Forward, x -> tcholsolv_lower(x, B, i), x)), size(x)) + reverse_grad = Enzyme.gradient(Reverse, Const(x -> tcholsolv_lower(x, B, i)), x)[1] + #forward_grad = Enzyme.gradient(Forward, x -> 
tcholsolv_lower(x, B, i), x)[1] finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tcholsolv_lower(x, B, i), x)[1] @test reverse_grad ≈ finite_diff #@test forward_grad ≈ finite_diff # - reverse_grad = Enzyme.gradient(Reverse, Const(x -> tcholsolv_upper(x, B, i)), x) - #forward_grad = reshape(collect(Enzyme.gradient(Forward, x -> tcholsolv_upper(x, B, i), x)), size(x)) + reverse_grad = Enzyme.gradient(Reverse, Const(x -> tcholsolv_upper(x, B, i)), x)[1] + #forward_grad = Enzyme.gradient(Forward, x -> tcholsolv_upper(x, B, i), x)[1] finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tcholsolv_upper(x, B, i), x)[1] @test reverse_grad ≈ finite_diff #@test forward_grad ≈ finite_diff @@ -554,7 +554,7 @@ end b = [1., 2.] dA = zero(A) Enzyme.autodiff(Reverse, h, Active, Duplicated(A, dA), Const(b)) - # dA_fwd = Enzyme.gradient(Forward, A->h(A, b), A) + # dA_fwd = Enzyme.gradient(Forward, A->h(A, b), A)[1] dA_fd = FiniteDifferences.grad(central_fdm(5, 1), A->h(A, b), A)[1] @test isapprox(dA, dA_fd) @@ -571,9 +571,9 @@ end @testset "Cholesky upper triangular v1" begin x = [1.0, -0.10541615131279458, 0.6219810761363638, 0.293343219811946, -0.10541615131279458, 1.0, -0.05258941747718969, 0.34629296878264443, 0.6219810761363638, -0.05258941747718969, 1.0, 0.4692436399208845, 0.293343219811946, 0.34629296878264443, 0.4692436399208845, 1.0] - @test collect(Enzyme.gradient(Forward, chol_upper, x)) ≈ [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + @test Enzyme.gradient(Forward, chol_upper, x)[1] ≈ [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] - @test Enzyme.gradient(Reverse, chol_upper, x) ≈ [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + @test Enzyme.gradient(Reverse, chol_upper, x)[1] ≈ [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] end @testset "Linear solve for triangular matrices" begin diff --git a/test/runtests.jl b/test/runtests.jl index b079c0f5406..65ad4e3fd4e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -488,9 +488,9 @@ end @testset "Deferred upgrade" begin function gradsin(x) - return gradient(Reverse, sin, x) + return gradient(Reverse, sin, x)[1] end - res = Enzyme.gradient(Reverse, gradsin, 3.1) + res = Enzyme.gradient(Reverse, gradsin, 3.1)[1] @test res ≈ -sin(3.1) end @@ -2794,43 +2794,43 @@ end @testset "Gradient & NamedTuples" begin xy = (x = [1.0, 2.0], y = [3.0, 4.0]) - grad = Enzyme.gradient(Reverse, z -> sum(z.x .* z.y), xy) + grad = Enzyme.gradient(Reverse, z -> sum(z.x .* z.y), xy)[1] @test grad == (x = [3.0, 4.0], y = [1.0, 2.0]) xp = (x = [1.0, 2.0], p = 3) # 3::Int is non-diff - grad = Enzyme.gradient(Reverse, z -> sum(z.x .^ z.p), xp) + grad = Enzyme.gradient(Reverse, z -> sum(z.x .^ z.p), xp)[1] @test grad.x == [3.0, 12.0] xp2 = (x = [1.0, 2.0], p = 3.0) # mixed activity - grad = Enzyme.gradient(Reverse, z -> sum(z.x .^ z.p), xp2) + grad = Enzyme.gradient(Reverse, z -> sum(z.x .^ z.p), xp2)[1] @test grad.x == [3.0, 12.0] @test grad.p ≈ 5.545177444479562 xy = (x = [1.0, 2.0], y = [3, 4]) # y is non-diff - grad = Enzyme.gradient(Reverse, z -> sum(z.x .* z.y), xy) + grad = Enzyme.gradient(Reverse, z -> sum(z.x .* z.y), xy)[1] @test grad.x == [3.0, 4.0] @test grad.y === xy.y # make_zero did not copy this - grad = Enzyme.gradient(Reverse, z -> (z.x * z.y), (x=5.0, y=6.0)) + grad = Enzyme.gradient(Reverse, z -> (z.x * z.y), (x=5.0, y=6.0))[1] @test grad == (x = 6.0, y = 5.0) - grad = Enzyme.gradient(Reverse, abs2, 7.0) + grad = Enzyme.gradient(Reverse, abs2, 7.0)[1] @test grad == 14.0 end @testset "Gradient & SparseArrays / StaticArrays" begin x = sparse([5.0, 0.0, 6.0]) - dx = Enzyme.gradient(Reverse, sum, x) + dx = Enzyme.gradient(Reverse, sum, x)[1] @test dx isa SparseVector @test dx ≈ [1, 0, 1] x = sparse([5.0 0.0 6.0]) - 
dx = Enzyme.gradient(Reverse, sum, x) + dx = Enzyme.gradient(Reverse, sum, x)[1] @test dx isa SparseMatrixCSC @test dx ≈ [1 0 1] x = @SArray [5.0 0.0 6.0] - dx = Enzyme.gradient(Reverse, prod, x) + dx = Enzyme.gradient(Reverse, prod, x)[1] @test dx isa SArray @test dx ≈ [0 30 0] @@ -2851,7 +2851,7 @@ end @test y[2] == [0.0, 0.0, 1.0] x = @SArray [5.0 0.0 6.0] - dx = Enzyme.gradient(Forward, prod, x) + dx = Enzyme.gradient(Forward, prod, x)[1] @test dx[1] ≈ 0 @test dx[2] ≈ 30 @test dx[3] ≈ 0 @@ -2906,264 +2906,266 @@ mkarray(sz, args...) = reshape(vcat(args...), sz) scalar = 3.0 # ∂ scalar / ∂ scalar - @test Enzyme.gradient(Enzyme.Forward, x -> x^2, scalar) ≈ 6.0 - @test Enzyme.gradient(Enzyme.Reverse, x -> x^2, scalar) ≈ 6.0 - @test Enzyme.jacobian(Enzyme.Forward, x -> x^2, scalar) ≈ 6.0 - @test Enzyme.jacobian(Enzyme.Reverse, x -> x^2, scalar) ≈ 6.0 - @test Enzyme.gradient(Enzyme.Forward, x -> 2*x, scalar) ≈ 2.0 - @test Enzyme.gradient(Enzyme.Reverse, x -> 2*x, scalar) ≈ 2.0 - @test Enzyme.jacobian(Enzyme.Forward, x -> 2*x, scalar) ≈ 2.0 - @test Enzyme.jacobian(Enzyme.Reverse, x -> 2*x, scalar) ≈ 2.0 + @test Enzyme.gradient(Enzyme.Forward, x -> x^2, scalar)[1] ≈ 6.0 + @test Enzyme.gradient(Enzyme.Reverse, x -> x^2, scalar)[1] ≈ 6.0 + @test Enzyme.jacobian(Enzyme.Forward, x -> x^2, scalar)[1] ≈ 6.0 + @test Enzyme.jacobian(Enzyme.Reverse, x -> x^2, scalar)[1] ≈ 6.0 + @test Enzyme.gradient(Enzyme.Forward, x -> 2*x, scalar)[1] ≈ 2.0 + @test Enzyme.gradient(Enzyme.Reverse, x -> 2*x, scalar)[1] ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Forward, x -> 2*x, scalar)[1] ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Reverse, x -> 2*x, scalar)[1] ≈ 2.0 # ∂ vector / ∂ scalar - @test Enzyme.gradient(Enzyme.Forward, x -> [2*x, x^2], scalar) ≈ [2.0, 6.0] - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [2*x, x^2], scalar) ≈ [2.0, 6.0] + @test Enzyme.gradient(Enzyme.Forward, x -> [2*x, x^2], scalar)[1] ≈ [2.0, 6.0] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [2*x, x^2], scalar)[1] ≈ 
[2.0, 6.0] - @test Enzyme.jacobian(Enzyme.Forward, x -> [2*x, x^2], scalar) ≈ [2.0, 6.0] - @test Enzyme.jacobian(Enzyme.Reverse, x -> [2*x, x^2], scalar) ≈ [2.0, 6.0] + @test Enzyme.jacobian(Enzyme.Forward, x -> [2*x, x^2], scalar)[1] ≈ [2.0, 6.0] + @test Enzyme.jacobian(Enzyme.Reverse, x -> [2*x, x^2], scalar)[1] ≈ [2.0, 6.0] # ∂ tuple / ∂ scalar - @test Enzyme.gradient(Enzyme.Forward, x -> (2*x, x^2), scalar) ≃ (2.0, 6.0) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (2*x, x^2), scalar) ≈ [2.0, 6.0] + @test Enzyme.gradient(Enzyme.Forward, x -> (2*x, x^2), scalar)[1] ≃ (2.0, 6.0) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (2*x, x^2), scalar)[1] ≈ [2.0, 6.0] - @test Enzyme.jacobian(Enzyme.Forward, x -> (2*x, x^2), scalar) ≃ (2.0, 6.0) - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (2*x, x^2), scalar) ≃ (2.0, 6.0) + @test Enzyme.jacobian(Enzyme.Forward, x -> (2*x, x^2), scalar)[1] ≃ (2.0, 6.0) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (2*x, x^2), scalar)[1] ≃ (2.0, 6.0) mkarray1 = x -> mkarray((2,2),2*x,sin(x),x^2,exp(x)) # ∂ matrix / ∂ scalar - @test Enzyme.gradient(Enzyme.Forward, mkarray1, scalar) ≈ [2.0 6.0; cos(scalar) exp(scalar)] - @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray1, scalar) ≈ [2.0 6.0; cos(scalar) exp(scalar)] + @test Enzyme.gradient(Enzyme.Forward, mkarray1, scalar)[1] ≈ [2.0 6.0; cos(scalar) exp(scalar)] + @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray1, scalar)[1] ≈ [2.0 6.0; cos(scalar) exp(scalar)] - @test Enzyme.jacobian(Enzyme.Forward, mkarray1, scalar) ≈ [2.0 6.0; cos(scalar) exp(scalar)] - @test Enzyme.jacobian(Enzyme.Reverse, mkarray1, scalar) ≈ [2.0 6.0; cos(scalar) exp(scalar)] + @test Enzyme.jacobian(Enzyme.Forward, mkarray1, scalar)[1] ≈ [2.0 6.0; cos(scalar) exp(scalar)] + @test Enzyme.jacobian(Enzyme.Reverse, mkarray1, scalar)[1] ≈ [2.0 6.0; cos(scalar) exp(scalar)] # ∂ struct / ∂ scalar - @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x, x^2, x^3), scalar) == 
OutStruct(1.0,2*scalar,3*scalar^2) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> InpStruct(x, x^2, x^3), scalar) == (OutStruct(1.0,2.0,3.0),) - @test Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x, x^2, x^3), scalar) == OutStruct(1.0,2*scalar,3*scalar^2) - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> InpStruct(x, x^2, x^3), scalar) == (OutStruct(1.0,2.0,3.0),) + @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x, x^2, x^3), scalar)[1] == OutStruct(1.0,2*scalar,3*scalar^2) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> InpStruct(x, x^2, x^3), scalar)[1] == (OutStruct(1.0,2.0,3.0),) + @test Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x, x^2, x^3), scalar)[1] == OutStruct(1.0,2*scalar,3*scalar^2) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> InpStruct(x, x^2, x^3), scalar)[1] == (OutStruct(1.0,2.0,3.0),) vector = [2.7, 3.1] # ∂ scalar / ∂ vector - @test Enzyme.gradient(Enzyme.Forward, x -> x[1] * x[2], vector) ≃ (vector[2],vector[1]) - @test Enzyme.gradient(Enzyme.Reverse, x -> x[1] * x[2], vector) ≈ [vector[2], vector[1]] - @test Enzyme.jacobian(Enzyme.Forward, x -> x[1] * x[2], vector) ≈ [vector[2], vector[1]] - @test Enzyme.jacobian(Enzyme.Reverse, x -> x[1] * x[2], vector) ≈ [vector[2], vector[1]] + @test Enzyme.gradient(Enzyme.Forward, x -> x[1] * x[2], vector)[1] ≈ [vector[2],vector[1]] + @test Enzyme.gradient(Enzyme.Reverse, x -> x[1] * x[2], vector)[1] ≈ [vector[2], vector[1]] + @test Enzyme.jacobian(Enzyme.Forward, x -> x[1] * x[2], vector)[1] ≈ [vector[2], vector[1]] + @test Enzyme.jacobian(Enzyme.Reverse, x -> x[1] * x[2], vector)[1] ≈ [vector[2], vector[1]] # ∂ vector / ∂ vector - @test Enzyme.gradient(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector) ≃ - ([vector[2], -sin(vector[1])], [vector[1], 1.0]) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector) ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) - @test Enzyme.jacobian(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + 
x[2]], vector) ≈ + @test Enzyme.gradient(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector)[1] ≈ + [vector[2] vector[1]; -sin(vector[1]) 1.0] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector)[1] ≈ + [vector[2] vector[1]; -sin(vector[1]) 1.0] + @test Enzyme.jacobian(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector)[1] ≈ [vector[2] vector[1]; -sin(vector[1]) 1.0] - @test Enzyme.jacobian(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector) ≈ + @test Enzyme.jacobian(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector)[1] ≈ [vector[2] vector[1]; -sin(vector[1]) 1.0] # ∂ tuple / ∂ vector - @test Enzyme.gradient(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector) ≃ - ((vector[2], -sin(vector[1])), (vector[1], 1.0)) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector) ≈ + @test Enzyme.gradient(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≃ + [(vector[2], -sin(vector[1])), (vector[1], 1.0)] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) - @test Enzyme.jacobian(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector) ≃ + @test Enzyme.jacobian(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≃ [(vector[2], -sin(vector[1])), (vector[1], 1.0)] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] mkarray2 = x -> mkarray((2,2), x[1]*x[2], exp(x[2]), cos(x[1])+x[2], x[1]) # ∂ matrix / ∂ vector - @test Enzyme.gradient(Enzyme.Forward, mkarray2, vector) ≃ - ([vector[2] -sin(vector[1]); 0.0 1.0], [vector[1] 1.0; exp(vector[2]) 0.0]) - @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray2, vector) - @test Enzyme.jacobian(Enzyme.Forward, mkarray2, vector) ≈ + @test 
Enzyme.gradient(Enzyme.Forward, mkarray2, vector)[1] ≈ + mkarray((2,2,2), vector[2], 0.0, -sin(vector[1]), 1.0, vector[1], exp(vector[2]), 1.0, 0.0) + @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray2, vector)[1] + @test Enzyme.jacobian(Enzyme.Forward, mkarray2, vector)[1] ≈ mkarray((2,2,2), vector[2], 0.0, -sin(vector[1]), 1.0, vector[1], exp(vector[2]), 1.0, 0.0) - @test Enzyme.jacobian(Enzyme.Reverse, mkarray2, vector) ≈ + @test Enzyme.jacobian(Enzyme.Reverse, mkarray2, vector)[1] ≈ mkarray((2,2,2), vector[2], 0.0, -sin(vector[1]), 1.0, vector[1], exp(vector[2]), 1.0, 0.0) # ∂ struct / ∂ vector - @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), vector) ≃ - (OutStruct(vector[2], -sin(vector[1]), 0.0), OutStruct(vector[1], 1.0, exp(vector[2]))) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector) ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) + @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), vector)[1] ≃ + [OutStruct(vector[2], -sin(vector[1]), 0.0), OutStruct(vector[1], 1.0, exp(vector[2]))] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) - @test Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), vector) ≃ + @test Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), vector)[1] ≃ [OutStruct(vector[2], -sin(vector[1]), 0.0), OutStruct(vector[1], 1.0, exp(vector[2]))] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector) ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) tuplev = (2.7, 3.1) # ∂ scalar / ∂ tuple - @test Enzyme.gradient(Enzyme.Forward, x -> x[1] * x[2], 
tuplev) ≃ (tuplev[2],tuplev[1]) - @test Enzyme.gradient(Enzyme.Reverse, x -> x[1] * x[2], tuplev) ≃ (tuplev[2],tuplev[1]) - @test Enzyme.jacobian(Enzyme.Forward, x -> x[1] * x[2], tuplev) ≃ (tuplev[2],tuplev[1]) - @test Enzyme.jacobian(Enzyme.Reverse, x -> x[1] * x[2], tuplev) ≃ (tuplev[2],tuplev[1]) + @test Enzyme.gradient(Enzyme.Forward, x -> x[1] * x[2], tuplev)[1] ≃ (tuplev[2],tuplev[1]) + @test Enzyme.gradient(Enzyme.Reverse, x -> x[1] * x[2], tuplev)[1] ≃ (tuplev[2],tuplev[1]) + @test Enzyme.jacobian(Enzyme.Forward, x -> x[1] * x[2], tuplev)[1] ≃ (tuplev[2],tuplev[1]) + @test Enzyme.jacobian(Enzyme.Reverse, x -> x[1] * x[2], tuplev)[1] ≃ (tuplev[2],tuplev[1]) # ∂ vector / ∂ tuple - @test Enzyme.gradient(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev) ≃ + @test Enzyme.gradient(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev)[1] ≃ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev) ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev) ≈ + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev)[1] ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev)[1] ≈ [tuplev[2] tuplev[1]; -sin(tuplev[1]) 1.0] - @test Enzyme.jacobian(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev) ≃ + @test Enzyme.jacobian(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev)[1] ≃ [(tuplev[2], tuplev[1]), (-sin(tuplev[1]), 1.0)] # ∂ tuple / ∂ tuple - @test Enzyme.gradient(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev) ≃ + @test Enzyme.gradient(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≃ ((vector[2], -sin(vector[1])), (vector[1], 1.0)) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], 
cos(x[1]) + x[2]), tuplev) ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) - @test Enzyme.jacobian(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev) ≃ + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + @test Enzyme.jacobian(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≃ ((tuplev[2], -sin(tuplev[1])), (tuplev[1], 1.0)) - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev) ≈ + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≈ [tuplev[2] tuplev[1]; -sin(tuplev[1]) 1.0] # ∂ matrix / ∂ tuple - @test Enzyme.gradient(Enzyme.Forward, mkarray2, tuplev) ≃ + @test Enzyme.gradient(Enzyme.Forward, mkarray2, tuplev)[1] ≃ ([tuplev[2] -sin(tuplev[1]); 0.0 1.0], [tuplev[1] 1.0; exp(tuplev[2]) 0.0]) - @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray2, tuplev) - @test_broken Enzyme.jacobian(Enzyme.Forward, mkarray2, tuplev) ≈ + @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray2, tuplev)[1] + @test_broken Enzyme.jacobian(Enzyme.Forward, mkarray2, tuplev)[1] ≈ [tuplev[2] -sin(tuplev[1]); 0.0 1.0;;; tuplev[1] 1.0; exp(tuplev[2]) 0.0] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> mkarray2, tuplev) ≈ + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> mkarray2, tuplev)[1] ≈ [tuplev[2] -sin(tuplev[1]); 0.0 1.0;;; tuplev[1] 1.0; exp(tuplev[2]) 0.0] # ∂ struct / ∂ tuple - @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), tuplev) ≃ + @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), tuplev)[1] ≃ (OutStruct(tuplev[2], -sin(tuplev[1]), 0.0), OutStruct(tuplev[1], 1.0, exp(tuplev[2]))) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev) ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> 
(x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), tuplev) ≃ + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), tuplev)[1] ≃ [OutStruct(tuplev[2], -sin(tuplev[1]), 0.0), OutStruct(tuplev[1], 1.0, exp(tuplev[2]))] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev) ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) matrix = [2.7 3.1; 4.7 5.6] # ∂ scalar / ∂ matrix - @test Enzyme.gradient(Enzyme.Forward, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix) ≃ - (matrix[1,2], matrix[2,2], matrix[1,1], matrix[2,1]) - @test Enzyme.gradient(Enzyme.Reverse, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix) ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] - @test Enzyme.jacobian(Enzyme.Forward, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix) ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] - @test Enzyme.jacobian(Enzyme.Reverse, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix) ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] + @test Enzyme.gradient(Enzyme.Forward, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix)[1] ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] + @test Enzyme.gradient(Enzyme.Reverse, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix)[1] ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] + @test Enzyme.jacobian(Enzyme.Forward, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix)[1] ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] + @test Enzyme.jacobian(Enzyme.Reverse, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix)[1] ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] # ∂ vector / ∂ matrix - @test Enzyme.gradient(Enzyme.Forward, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix) ≃ - ([matrix[1,2], 0.0], [0.0, matrix[2,2]], [matrix[1,1], 
0.0], [0.0, matrix[2,1]]) - @test_broken Enzyme.gradient(Enzyme.Reverse, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix) + @test Enzyme.gradient(Enzyme.Forward, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix)[1] ≈ + mkarray((2,2,2), matrix[1,2], 0.0, 0.0, matrix[2,2], matrix[1,1], 0.0, 0.0, matrix[2,1]) + @test_broken Enzyme.gradient(Enzyme.Reverse, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix)[1] # again we can't use array construction syntax because of 1.6 - @test Enzyme.jacobian(Enzyme.Forward, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix) ≈ + @test Enzyme.jacobian(Enzyme.Forward, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix)[1] ≈ mkarray((2,2,2), matrix[1,2], 0.0, 0.0, matrix[2,2], matrix[1,1], 0.0, 0.0, matrix[2,1]) - @test Enzyme.jacobian(Enzyme.Reverse, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix) ≈ + @test Enzyme.jacobian(Enzyme.Reverse, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix)[1] ≈ mkarray((2,2,2), matrix[1,2], 0.0, 0.0, matrix[2,2], matrix[1,1], 0.0, 0.0, matrix[2,1]) # ∂ tuple / ∂ matrix - @test Enzyme.gradient(Enzyme.Forward, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix) ≃ ((matrix[1,2], 0.0), (0.0, matrix[2,2]), (matrix[1,1], 0.0), (0.0, matrix[2,1])) + @test Enzyme.gradient(Enzyme.Forward, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix)[1] ≃ + [(matrix[1,2],0.0) (matrix[1,1],0.0); (0.0,matrix[2,2]) (0.0,matrix[2,1])] @test_broken Enzyme.gradient(Enzyme.Reverse, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix) - @test Enzyme.jacobian(Enzyme.Forward, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix) ≃ + @test Enzyme.jacobian(Enzyme.Forward, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix)[1] ≃ [(matrix[1,2],0.0) (matrix[1,1],0.0); (0.0,matrix[2,2]) (0.0,matrix[2,1])] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix)[1] mkarray3 = x -> mkarray((2,2), x[1,1]*x[1,2], exp(x[1,1])+x[2,2], x[2,1]*x[2,2], sin(x[1,2])+x[2,1]) # ∂ matrix / ∂ matrix - @test Enzyme.gradient(Enzyme.Forward, 
mkarray3, matrix) ≃ - ([matrix[1,2] 0.0; exp(matrix[1,1]) 0.0], [0.0 matrix[2,2]; 0.0 1.0], [matrix[1,1] 0.0; 0.0 cos(matrix[1,2])], [0.0 matrix[2,1]; 1.0 0.0]) - @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray3, matrix) + @test Enzyme.gradient(Enzyme.Forward, mkarray3, matrix)[1] ≈ + mkarray((2,2,2,2), matrix[1,2],exp(matrix[1,1]),0.0,0.0,0.0,0.0,matrix[2,2],1.0, + matrix[1,1],0.0,0.0,cos(matrix[1,2]),0.0,1.0,matrix[2,1],0.0) + @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray3, matrix)[1] # array construction syntax broken on 1.6 - @test Enzyme.jacobian(Enzyme.Forward, mkarray3, matrix) ≈ + @test Enzyme.jacobian(Enzyme.Forward, mkarray3, matrix)[1] ≈ mkarray((2,2,2,2), matrix[1,2],exp(matrix[1,1]),0.0,0.0,0.0,0.0,matrix[2,2],1.0, matrix[1,1],0.0,0.0,cos(matrix[1,2]),0.0,1.0,matrix[2,1],0.0) - @test Enzyme.jacobian(Enzyme.Reverse, mkarray3, matrix) ≈ + @test Enzyme.jacobian(Enzyme.Reverse, mkarray3, matrix)[1] ≈ mkarray((2,2,2,2), matrix[1,2],exp(matrix[1,1]),0.0,0.0,0.0,0.0,matrix[2,2],1.0, matrix[1,1],0.0,0.0,cos(matrix[1,2]),0.0,1.0,matrix[2,1],0.0) # ∂ tuple / ∂ matrix - @test Enzyme.gradient(Enzyme.Forward, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix) ≃ - (OutStruct(matrix[1,2], 0.0, exp(matrix[1,1])), OutStruct(0.0, matrix[2,2], 0.0), OutStruct(matrix[1,1], 0.0, 0.0), OutStruct(0.0, matrix[2,1], 1.0)) - @test_broken Enzyme.gradient(Enzyme.Reverse, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix) - @test Enzyme.jacobian(Enzyme.Forward, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix) ≃ + @test Enzyme.gradient(Enzyme.Forward, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix)[1] ≃ + [OutStruct(matrix[1,2],0.0, exp(matrix[1,1])) OutStruct(matrix[1,1],0.0,0.0); OutStruct(0.0,matrix[2,2],0.0) OutStruct(0.0,matrix[2,1], 1.0)] + @test_broken Enzyme.gradient(Enzyme.Reverse, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix)[1] + @test 
Enzyme.jacobian(Enzyme.Forward, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix)[1] ≃ [OutStruct(matrix[1,2],0.0, exp(matrix[1,1])) OutStruct(matrix[1,1],0.0,0.0); OutStruct(0.0,matrix[2,2],0.0) OutStruct(0.0,matrix[2,1], 1.0)] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix)[1] istruct = InpStruct(2.7, 3.1, 4.7) # ∂ scalar / ∂ struct - @test_broken Enzyme.gradient(Enzyme.Forward, x -> x.i1 * x.i2 + x.i3, istruct) - @test Enzyme.gradient(Enzyme.Reverse, x -> x.i1 * x.i2 + x.i3, istruct) ≃ InpStruct(istruct.i2, istruct.i1, 1.0) - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> x.i1 * x.i2 + x.i3, istruct) - @test Enzyme.jacobian(Enzyme.Reverse, x -> x.i1 * x.i2 + x.i3, istruct) ≃ InpStruct(istruct.i2, istruct.i1, 1.0) + @test_broken Enzyme.gradient(Enzyme.Forward, x -> x.i1 * x.i2 + x.i3, istruct)[1] + @test Enzyme.gradient(Enzyme.Reverse, x -> x.i1 * x.i2 + x.i3, istruct)[1] ≃ InpStruct(istruct.i2, istruct.i1, 1.0) + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> x.i1 * x.i2 + x.i3, istruct)[1] + @test Enzyme.jacobian(Enzyme.Reverse, x -> x.i1 * x.i2 + x.i3, istruct)[1] ≃ InpStruct(istruct.i2, istruct.i1, 1.0) # ∂ vector / ∂ struct - @test_broken Enzyme.gradient(Enzyme.Forward, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct) - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct) - @test Enzyme.jacobian(Enzyme.Reverse, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct) ≃ [InpStruct(istruct.i2, istruct.i1, 0.0), InpStruct(1.0, 0.0, -sin(istruct.i3))] + @test_broken Enzyme.gradient(Enzyme.Forward, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct)[1] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], 
istruct)[1] + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct)[1] + @test Enzyme.jacobian(Enzyme.Reverse, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct)[1] ≃ [InpStruct(istruct.i2, istruct.i1, 0.0), InpStruct(1.0, 0.0, -sin(istruct.i3))] # ∂ tuple / ∂ struct - @test_broken Enzyme.gradient(Enzyme.Forward, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct) - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct) - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct) + @test_broken Enzyme.gradient(Enzyme.Forward, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct)[1] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct)[1] + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct)[1] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct)[1] mkarray4 = x -> mkarray((2,2), x.i1*x.i2, exp(x.i2), cos(x.i3)+x.i1, x.i1) # ∂ matrix / ∂ struct - @test_broken Enzyme.gradient(Enzyme.Forward, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct) - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct) - @test Enzyme.jacobian(Enzyme.Reverse, mkarray4, istruct) ≃ + @test_broken Enzyme.gradient(Enzyme.Forward, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct)[1] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct)[1] + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct)[1] + @test Enzyme.jacobian(Enzyme.Reverse, mkarray4, istruct)[1] ≃ [InpStruct(istruct.i2, istruct.i1, 0.0) 
InpStruct(1.0, 0.0, -sin(istruct.i3)); InpStruct(0.0, exp(istruct.i2), 0.0) InpStruct(1.0, 0.0, 0.0)] # ∂ struct / ∂ struct - @test_broken Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct) - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct) - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct) + @test_broken Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] end @testset "Simple Jacobian" begin - @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0) ≈ 2.0 - @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0) ≈ [1.0, 2.0] - @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0]) ≈ [4.0, 6.0] + @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0)[1] ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0)[1] ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0])[1] ≈ [4.0, 6.0] - @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0, Val(1)) ≈ 2.0 - @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0, Val(1)) ≈ [1.0, 2.0] - @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0], Val(1)) ≈ [4.0, 6.0] + @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0, chunk=Val(1))[1] ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0, chunk=Val(1))[1] ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0], 
chunk=Val(1))[1] ≈ [4.0, 6.0] - @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0, Val(2)) ≈ 2.0 - @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0, Val(2)) ≈ [1.0, 2.0] - @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0], Val(2)) ≈ [4.0, 6.0] + @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0, chunk=Val(2))[1] ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0, chunk=Val(2))[1] ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0], chunk=Val(2))[1] ≈ [4.0, 6.0] - @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, Val(2)) ≈ [1.0, 2.0] - @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, Val(2), Val(1)) ≈ [1.0, 2.0] - @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, Val(2), Val(2)) ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, n_outs=Val((2,)))[1] ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, n_outs=Val((2,)), chunk=Val(1))[1] ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, n_outs=Val((2,)), chunk=Val(2))[1] ≈ [1.0, 2.0] x = float.(reshape(1:6, 2, 3)) fillabs2(x) = [sum(abs2, x), 10*sum(abs2, x), 100*sum(abs2, x), 1000*sum(abs2, x)] - jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x) + jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x)[1] @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] - jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x, Val(1)) + jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x, chunk=Val(1))[1] @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] - jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x, Val(2)) + jac = 
Enzyme.jacobian(Enzyme.Forward, fillabs2, x, chunk=Val(2))[1] @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] @@ -3171,14 +3173,14 @@ end @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] - jac = Enzyme.jacobian(Enzyme.Reverse, fillabs2, x, Val(4), Val(1)) + jac = Enzyme.jacobian(Enzyme.Reverse, fillabs2, x, n_outs=Val((4,)), chunk=Val(1))[1] @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] - jac = Enzyme.jacobian(Enzyme.Reverse, fillabs2, x, Val(4), Val(2)) + jac = Enzyme.jacobian(Enzyme.Reverse, fillabs2, x, n_outs=Val((4,)), chunk=Val(2))[1] @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] @@ -3189,14 +3191,14 @@ end x2 = InpStruct(1.0, 2.0, 3.0) - jac = Enzyme.jacobian(Enzyme.Reverse, fillinpabs2, x2, Val(4), Val(1)) + jac = Enzyme.jacobian(Enzyme.Reverse, fillinpabs2, x2, n_outs=Val((4,)), chunk=Val(1))[1] @test jac[1] == InpStruct(2.0, 4.0, 6.0) @test jac[2] == InpStruct(20.0, 40.0, 60.0) @test jac[3] == InpStruct(200.0, 400.0, 600.0) @test jac[4] == InpStruct(2000.0, 4000.0, 6000.0) - jac = Enzyme.jacobian(Enzyme.Reverse, fillinpabs2, x2, Val(4), Val(2)) + jac = Enzyme.jacobian(Enzyme.Reverse, fillinpabs2, x2, n_outs=Val((4,)), chunk=Val(2))[1] @test jac[1] == InpStruct(2.0, 4.0, 6.0) @test jac[2] == InpStruct(20.0, 40.0, 60.0) @@ -3205,7 +3207,7 @@ end filloutabs2(x) = OutStruct(sum(abs2, x), 10*sum(abs2, x), 100*sum(abs2, x)) - jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x) + jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x)[1] @test jac[1, 1] == OutStruct(2.0, 20.0, 200.0) @test jac[2, 1] == OutStruct(4.0, 40.0, 400.0) @@ -3216,7 +3218,7 @@ end @test jac[1, 3] == OutStruct(10.0, 100.0, 1000.0) @test jac[2, 3] == OutStruct(12.0, 
120.0, 1200.0) - jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x, Val(1)) + jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x, chunk=Val(1))[1] @test jac[1, 1] == OutStruct(2.0, 20.0, 200.0) @test jac[2, 1] == OutStruct(4.0, 40.0, 400.0) @@ -3227,7 +3229,7 @@ end @test jac[1, 3] == OutStruct(10.0, 100.0, 1000.0) @test jac[2, 3] == OutStruct(12.0, 120.0, 1200.0) - jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x, Val(2)) + jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x, chunk=Val(2))[1] @test jac[1, 1] == OutStruct(2.0, 20.0, 200.0) @test jac[2, 1] == OutStruct(4.0, 40.0, 400.0) @@ -3245,27 +3247,27 @@ end [v[2], v[1]*v[1], v[1]*v[1]*v[1]] end - jac = Enzyme.jacobian(Reverse, inout, [2.0, 3.0], #=n_outs=# Val(3), Val(1)) + jac = Enzyme.jacobian(Reverse, inout, [2.0, 3.0], n_outs=Val((3,)), chunk=Val(1))[1] @test size(jac) == (3, 2) @test jac ≈ [ 0.0 1.0; 4.0 0.0; 12.0 0.0] - jac = Enzyme.jacobian(Forward, inout, [2.0, 3.0], Val(1)) + jac = Enzyme.jacobian(Forward, inout, [2.0, 3.0], chunk=Val(1))[1] @test size(jac) == (3, 2) @test jac ≈ [ 0.0 1.0; 4.0 0.0; 12.0 0.0] - @test jac == Enzyme.jacobian(Forward, inout, [2.0, 3.0]) + @test jac == Enzyme.jacobian(Forward, inout, [2.0, 3.0])[1] - jac = Enzyme.jacobian(Reverse, inout, [2.0, 3.0], #=n_outs=# Val(3), Val(2)) + jac = Enzyme.jacobian(Reverse, inout, [2.0, 3.0], n_outs=Val((3,)), chunk=Val(2))[1] @test size(jac) == (3, 2) @test jac ≈ [ 0.0 1.0; 4.0 0.0; 12.0 0.0] - jac = Enzyme.jacobian(Forward, inout, [2.0, 3.0], Val(2)) + jac = Enzyme.jacobian(Forward, inout, [2.0, 3.0], chunk=Val(2))[1] @test size(jac) == (3, 2) @test jac ≈ [ 0.0 1.0; 4.0 0.0; @@ -3286,13 +3288,13 @@ end utmp .= A*x[2:end] .+ x[1] end - J_r_1(A, x) = Enzyme.jacobian(Reverse, θ -> f_test_1(A, θ), x, Val(5)) - J_r_2(A, x) = Enzyme.jacobian(Reverse, θ -> f_test_2(A, θ), x, Val(5)) - J_r_3(u, A, x) = Enzyme.jacobian(Reverse, θ -> f_test_3!(u, A, θ), x, Val(5)) + J_r_1(A, x) = Enzyme.jacobian(Reverse, θ -> f_test_1(A, θ), x, 
n_outs=Val((5,)))[1] + J_r_2(A, x) = Enzyme.jacobian(Reverse, θ -> f_test_2(A, θ), x, n_outs=Val((5,)))[1] + J_r_3(u, A, x) = Enzyme.jacobian(Reverse, θ -> f_test_3!(u, A, θ), x, n_outs=Val((5,)))[1] - J_f_1(A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_1(A, θ)), x) - J_f_2(A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_2(A, θ)), x) - J_f_3(u, A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_3!(u, A, θ)), x) + J_f_1(A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_1(A, θ)), x)[1] + J_f_2(A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_2(A, θ)), x)[1] + J_f_3(u, A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_3!(u, A, θ)), x)[1] x = ones(6) A = Matrix{Float64}(LinearAlgebra.I, 5, 5) @@ -3351,7 +3353,7 @@ end dry = zeros(2) function foo(y, dy, x, dx) - autodiff_deferred(Reverse, speelpenning, Const, Duplicated(y, dy), Duplicated(x, dx)) + autodiff(Reverse, speelpenning, Const, Duplicated(y, dy), Duplicated(x, dx)) return nothing end @@ -3776,8 +3778,8 @@ end @testset "Constant Complex return" begin vec = [0.5] - @test Enzyme.gradient(Enzyme.Reverse, fexpandempty, vec)[1] ≈ 1.0 - @test Enzyme.gradient(Enzyme.Forward, fexpandempty, vec)[1] ≈ 1.0 + @test Enzyme.gradient(Enzyme.Reverse, fexpandempty, vec)[1] ≈ [1.0] + @test Enzyme.gradient(Enzyme.Forward, fexpandempty, vec)[1] ≈ [1.0] end const CUmemoryPool2 = Ptr{Float64} @@ -3924,10 +3926,10 @@ const objective3 = params -> mixture_loglikelihood3(params, data) -13.935687326484112, -38.00044665702692, 12.87712891527131] - @test expected ≈ Enzyme.gradient(Reverse, objective1, params0) + @test expected ≈ Enzyme.gradient(Reverse, objective1, params0)[1] # objective2 fails from runtime activity requirements - # @test expected ≈ Enzyme.gradient(Reverse, objective2, params0) - @test expected ≈ Enzyme.gradient(Reverse, objective3, params0) + # @test expected ≈ Enzyme.gradient(Reverse, objective2, params0)[1] + @test expected ≈ Enzyme.gradient(Reverse, objective3, params0)[1] end struct HarmonicAngle From 
bbaa1f8d8c83daf4b28018d6387148be9121bdb1 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 17 Sep 2024 21:06:17 -0500 Subject: [PATCH 35/87] Use namedtuple for grad/jacobian (#1850) * Use namedtuple for grad/jacobian * Update index.md * Update Enzyme.jl * Update Enzyme.jl * Update Enzyme.jl * Update Enzyme.jl * Update index.md --- docs/src/faq.md | 2 +- docs/src/index.md | 88 ++++++++++++++++++++++++------------------- src/Enzyme.jl | 28 +++++++------- test/ext/bfloat16s.jl | 4 +- 4 files changed, 67 insertions(+), 55 deletions(-) diff --git a/docs/src/faq.md b/docs/src/faq.md index 88c0cce3b98..6b3bbce6b40 100644 --- a/docs/src/faq.md +++ b/docs/src/faq.md @@ -193,7 +193,7 @@ That is why Enzyme provides a helper function `Enzyme.make_zero` that does this ```jldoctest sparse Enzyme.make_zero(a) -Enzyme.gradient(Reverse, sum, a) # This calls make_zero(a) +Enzyme.gradient(Reverse, sum, a)[1] # This calls make_zero(a) # output diff --git a/docs/src/index.md b/docs/src/index.md index 1f7f092a994..2643b87a1e1 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -76,24 +76,32 @@ Both the inplace and "normal" variant return the gradient. The difference is tha ## Forward mode -The return value of forward mode with a `Duplicated` return is a tuple containing as the first value -the primal return value and as the second value the derivative. +The return value when using `ForwardWithPrimal` is a tuple containing as the first value +the derivative return value and as the second value the original value. + +The return value when using `Forward` is a single-element tuple containing the derivative. In forward mode `Duplicated(x, 0.0)` is equivalent to `Const(x)`, except that we can perform more optimizations for `Const`. 
```jldoctest rosenbrock -julia> autodiff(Forward, rosenbrock, Duplicated, Const(1.0), Duplicated(3.0, 1.0)) +julia> autodiff(ForwardWithPrimal, rosenbrock, Const(1.0), Duplicated(3.0, 1.0)) (400.0, 400.0) -julia> autodiff(Forward, rosenbrock, Duplicated, Duplicated(1.0, 1.0), Const(3.0)) -(400.0, -800.0) +julia> autodiff(Forward, rosenbrock, Const(1.0), Duplicated(3.0, 1.0)) +(400.0,) + +julia> autodiff(ForwardWithPrimal, rosenbrock, Duplicated(1.0, 1.0), Const(3.0)) +(-800.0, 400.0) + +julia> autodiff(Forward, rosenbrock, Duplicated(1.0, 1.0), Const(3.0)) +(-800.0,) ``` Of note, when we seed both arguments at once the tangent return is the sum of both. ```jldoctest rosenbrock -julia> autodiff(Forward, rosenbrock, Duplicated, Duplicated(1.0, 1.0), Duplicated(3.0, 1.0)) +julia> autodiff(ForwardWithPrimal, rosenbrock, Duplicated(1.0, 1.0), Duplicated(3.0, 1.0)) (400.0, -400.0) ``` @@ -121,7 +129,7 @@ Note the seeding through `dx`. We can also use vector mode to calculate both derivatives at once. ```jldoctest rosenbrock -julia> autodiff(Forward, rosenbrock, BatchDuplicated, BatchDuplicated(1.0, (1.0, 0.0)), BatchDuplicated(3.0, (0.0, 1.0))) +julia> autodiff(ForwardWithPrimal, rosenbrock, BatchDuplicated(1.0, (1.0, 0.0)), BatchDuplicated(3.0, (0.0, 1.0))) (400.0, (var"1" = -800.0, var"2" = 400.0)) julia> x = [1.0, 3.0] @@ -131,7 +139,7 @@ julia> x = [1.0, 3.0] julia> dx_1 = [1.0, 0.0]; dx_2 = [0.0, 1.0]; -julia> autodiff(Forward, rosenbrock_inp, BatchDuplicated, BatchDuplicated(x, (dx_1, dx_2))) +julia> autodiff(ForwardWithPrimal, rosenbrock_inp, BatchDuplicated(x, (dx_1, dx_2))) (400.0, (var"1" = -800.0, var"2" = 400.0)) ``` @@ -145,18 +153,20 @@ Like [`autodiff`](@ref), the mode (forward or reverse) is determined by the firs The functions [`gradient`](@ref) and [`gradient!`](@ref) compute the gradient of function with vector input and scalar return. +Gradient functions take a mode as the first argument. 
If the mode is `Reverse` or `Forward`, the return type is a tuple of gradients of each argument. +If the mode is `ReverseWithPrimal` or `ForwardWithPrimal`, the return type is a named tuple containing both the derivatives and the original return result. + ```jldoctest rosenbrock julia> gradient(Reverse, rosenbrock_inp, [1.0, 2.0]) -2-element Vector{Float64}: - -400.0 - 200.0 +([-400.0, 200.0],) + +julia> gradient(ReverseWithPrimal, rosenbrock_inp, [1.0, 2.0]) +(derivs=[-400.0, 200.0], val=100.0) julia> # inplace variant dx = [0.0, 0.0]; gradient!(Reverse, dx, rosenbrock_inp, [1.0, 2.0]) -2-element Vector{Float64}: - -400.0 - 200.0 +([-400.0, 200.0],) julia> dx 2-element Vector{Float64}: @@ -164,14 +174,16 @@ julia> dx 200.0 julia> gradient(Forward, rosenbrock_inp, [1.0, 2.0]) -(-400.0, 200.0) +([-400.0, 200.0],) + +julia> gradient(ForwardWithPrimal, rosenbrock_inp, [1.0, 2.0]) +(derivs = [-400.0, 200.0], val = 100.0) julia> # in forward mode, we can also optionally pass a chunk size # to specify the number of derivatives computed simulateneously # using vector forward mode - chunk_size = Val(2) - gradient(Forward, rosenbrock_inp, [1.0, 2.0], chunk_size) -(-400.0, 200.0) + gradient(Forward, rosenbrock_inp, [1.0, 2.0]; chunk=Val(1)) +([-400.0, 200.0],) ``` ## Jacobian Convenience functions @@ -179,31 +191,31 @@ julia> # in forward mode, we can also optionally pass a chunk size The function [`jacobian`](@ref) computes the Jacobian of a function vector input and vector return. Like [`autodiff`](@ref) and [`gradient`](@ref), the mode (forward or reverse) is determined by the first argument. +Again like [`gradient`](@ref), if the mode is `Reverse` or `Forward`, the return type is a tuple of jacobians of each argument. +If the mode is `ReverseWithPrimal` or `ForwardWithPrimal`, the return type is a named tuple containing both the derivatives and the original return result. 
+ +Both forward and reverse modes take an optional chunk size to compute several derivatives simultaneously using vector mode, and reverse mode optionally takes `n_outs` which describes the shape of the output value. + ```jldoctest rosenbrock julia> foo(x) = [rosenbrock_inp(x), prod(x)]; -julia> output_size = Val(2) # here we have to provide the output size of `foo` since it cannot be statically inferred - jacobian(Reverse, foo, [1.0, 2.0], output_size) -2×2 transpose(::Matrix{Float64}) with eltype Float64: - -400.0 200.0 - 2.0 1.0 +julia> jacobian(Reverse, foo, [1.0, 2.0]) +([-400.0 200.0; 2.0 1.0],) -julia> chunk_size = Val(2) # By specifying the optional chunk size argument, we can use vector inverse mode to propogate derivatives of multiple outputs at once. - jacobian(Reverse, foo, [1.0, 2.0], output_size, chunk_size) -2×2 transpose(::Matrix{Float64}) with eltype Float64: - -400.0 200.0 - 2.0 1.0 +julia> jacobian(ReverseWithPrimal, foo, [1.0, 2.0]) +(derivs = ([-400.0 200.0; 2.0 1.0],), val = [100.0, 2.0]) + +julia> jacobian(Reverse, foo, [1.0, 2.0]; chunk=Val(2)) +([-400.0 200.0; 2.0 1.0],) + +julia> jacobian(Reverse, foo, [1.0, 2.0]; chunk=Val(2), n_outs=Val((2,))) +([-400.0 200.0; 2.0 1.0],) julia> jacobian(Forward, foo, [1.0, 2.0]) -2×2 Matrix{Float64}: - -400.0 200.0 - 2.0 1.0 - -julia> # Again, the optinal chunk size argument allows us to use vector forward mode - jacobian(Forward, foo, [1.0, 2.0], chunk_size) -2×2 Matrix{Float64}: - -400.0 200.0 - 2.0 1.0 +([-400.0 200.0; 2.0 1.0],) + +julia> jacobian(Forward, foo, [1.0, 2.0], chunk=Val(2)) +([-400.0 200.0; 2.0 1.0],) ``` ## Hessian Vector Product Convenience functions @@ -257,4 +269,4 @@ julia> grad 2-element Vector{Float64}: 2.880510859951098 1.920340573300732 -``` \ No newline at end of file +``` diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 66551f29587..c4994fb363b 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1082,21 +1082,21 @@ a tuple where the first element contains the derivatives, and 
the second element grad = gradient(ReverseWithPrimal, f, [2.0, 3.0]) # output -(([3.0, 2.0],), 6.0) +(derivs = ([3.0, 2.0],), val = 6.0) ``` ```jldoctest gradient grad = gradient(ReverseWithPrimal, mul, [2.0], [3.0]) # output -(([3.0], [2.0]), 6.0) +(derivs = ([3.0], [2.0]), val = 6.0) ``` ```jldoctest gradient grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0])) # output -(([3.0], nothing), 6.0) +(derivs = ([3.0], nothing), val = 6.0) ``` """ @@ -1161,7 +1161,7 @@ grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0])) return quote Base.@_inline_meta $(toemit...) - (($(resargs...),), res[2]) + (; derivs=($(resargs...),), val=res[2]) end else return quote @@ -1196,14 +1196,14 @@ dx = [0.0, 0.0] gradient!(ReverseWithPrimal, dx, f, [2.0, 3.0]) # output -(([3.0, 2.0],), 6.0) +(derivs = ([3.0, 2.0],), val = 6.0) ``` """ @inline function gradient!(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, dx::X, f::F, x::X) where {X<:Array, F, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten} make_zero!(dx) res = autodiff(rm, f, Active, Duplicated(x, dx)) return if ReturnPrimal - ((dx,), res[2]) + (; derivs=(dx,), val=res[2]) else (dx,) end @@ -1300,7 +1300,7 @@ gradient(Forward, f, [2.0, 3.0]) gradient(ForwardWithPrimal, f, [2.0, 3.0]) # output -(([3.0, 2.0],), 6.0) +(derivs = ([3.0, 2.0],), val = 6.0) ``` ```jldoctest gradfwd @@ -1315,7 +1315,7 @@ gradient(Forward, f, [2.0, 3.0]; chunk=Val(1)) gradient(ForwardWithPrimal, f, [2.0, 3.0]; chunk=Val(1)) # output -(([3.0, 2.0],), 6.0) +(derivs = ([3.0, 2.0],), val = 6.0) ``` For functions which return an AbstractArray or scalar, this function will return an AbstracttArray @@ -1336,10 +1336,10 @@ grad = gradient(Forward, f, [2.0, 3.0, 4.0]) """ @inline function gradient(fm::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity}, f, x; chunk::CS=nothing, shadows=create_shadows(chunk, x)) where {ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity, CS} if 
length(shadows[1]) == 0 - if ReturnPrimal - ((x,), f(x.val)) + return if ReturnPrimal + (; derivs=(x,), val=f(x.val)) else - return (x,) + (x,) end end if chunk == Val(0) @@ -1430,7 +1430,7 @@ grad = gradient(Forward, f, [2.0, 3.0, 4.0]) cols end if ReturnPrimal - ((res,), gradtup[2]) + (; derivs=(res,), val=gradtup[2]) else (res,) end @@ -1498,7 +1498,7 @@ this function will retun an AbstractArray of shape `size(output)` of values of t end return if ReturnPrimal - (jac, res) + (; derivs=jac, val=res) else jac end @@ -1606,7 +1606,7 @@ this function will retun an AbstractArray of shape `size(output)` of values of t end if ReturnPrimal # TODO optimize away redundant fwd pass - (res, if f isa Enzyme.Const + (; derivs=res, val=if f isa Enzyme.Const f.val(x) else f(x) diff --git a/test/ext/bfloat16s.jl b/test/ext/bfloat16s.jl index 0a47f48f031..daaf6ef74cd 100644 --- a/test/ext/bfloat16s.jl +++ b/test/ext/bfloat16s.jl @@ -2,6 +2,6 @@ using Enzyme using Test using BFloat16s -@test_broken Enzyme.gradient(Reverse, sum, ones(BFloat16, 10)) ≈ ones(BFloat16, 10) +@test_broken Enzyme.gradient(Reverse, sum, ones(BFloat16, 10))[1] ≈ ones(BFloat16, 10) -@test_broken Enzyme.gradient(Forward, sum, ones(BFloat16, 10)) ≈ ones(BFloat16, 10) +@test_broken Enzyme.gradient(Forward, sum, ones(BFloat16, 10))[1] ≈ ones(BFloat16, 10) From 6a19be2cfb982b1d12cacc7c5aa182ed7321801f Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Sep 2024 00:26:29 -0500 Subject: [PATCH 36/87] Simplify deferred functions (#1849) * Simplify deferred functions * fix * Update runtests.jl * fix * fix * fix * fix * fix --- docs/Project.toml | 3 ++- docs/src/faq.md | 2 +- docs/src/index.md | 29 ++++++++++++++++------------- examples/custom_rule.jl | 19 +++++++++++-------- lib/EnzymeCore/src/rules.jl | 12 ++++++++++++ src/Enzyme.jl | 33 --------------------------------- test/abi.jl | 30 +++++++++++++++--------------- test/amdgpu.jl | 6 +++--- test/cuda.jl | 10 +++++----- test/metal.jl | 4 ++-- 
test/runtests.jl | 12 ++++++------ 11 files changed, 73 insertions(+), 87 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 56dd852972d..14301cf64d7 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,8 +1,9 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -Literate = "2" Documenter = "1" +Literate = "2" diff --git a/docs/src/faq.md b/docs/src/faq.md index 6b3bbce6b40..c5a80a976d2 100644 --- a/docs/src/faq.md +++ b/docs/src/faq.md @@ -627,7 +627,7 @@ Presently Enzyme only considers floats as base types. As a result, Enzyme does n ```jldoctest types f_int(x) = x * x -Enzyme.autodiff(Forward, f_int, DuplicatedNoNeed, Duplicated(3, 1)) +Enzyme.autodiff(Forward, f_int, Duplicated, Duplicated(3, 1)) # output diff --git a/docs/src/index.md b/docs/src/index.md index 2643b87a1e1..3c1c31b4af2 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -102,7 +102,10 @@ Of note, when we seed both arguments at once the tangent return is the sum of bo ```jldoctest rosenbrock julia> autodiff(ForwardWithPrimal, rosenbrock, Duplicated(1.0, 1.0), Duplicated(3.0, 1.0)) -(400.0, -400.0) +(-400.0, 400.0) + +julia> autodiff(Forward, rosenbrock, Duplicated(1.0, 1.0), Duplicated(3.0, 1.0)) +(-400.0,) ``` We can also use forward mode with our inplace method. @@ -118,8 +121,8 @@ julia> dx = [1.0, 1.0] 1.0 1.0 -julia> autodiff(Forward, rosenbrock_inp, Duplicated, Duplicated(x, dx)) -(400.0, -400.0) +julia> autodiff(ForwardWithPrimal, rosenbrock_inp, Duplicated, Duplicated(x, dx)) +(-400.0, 400.0) ``` Note the seeding through `dx`. @@ -130,7 +133,7 @@ We can also use vector mode to calculate both derivatives at once. 
```jldoctest rosenbrock julia> autodiff(ForwardWithPrimal, rosenbrock, BatchDuplicated(1.0, (1.0, 0.0)), BatchDuplicated(3.0, (0.0, 1.0))) -(400.0, (var"1" = -800.0, var"2" = 400.0)) +((var"1" = -800.0, var"2" = 400.0), 400.0) julia> x = [1.0, 3.0] 2-element Vector{Float64}: @@ -140,7 +143,7 @@ julia> x = [1.0, 3.0] julia> dx_1 = [1.0, 0.0]; dx_2 = [0.0, 1.0]; julia> autodiff(ForwardWithPrimal, rosenbrock_inp, BatchDuplicated(x, (dx_1, dx_2))) -(400.0, (var"1" = -800.0, var"2" = 400.0)) +((var"1" = -800.0, var"2" = 400.0), 400.0) ``` ## Gradient Convenience functions @@ -161,7 +164,7 @@ julia> gradient(Reverse, rosenbrock_inp, [1.0, 2.0]) ([-400.0, 200.0],) julia> gradient(ReverseWithPrimal, rosenbrock_inp, [1.0, 2.0]) -(derivs=[-400.0, 200.0], val=100.0) +(derivs = ([-400.0, 200.0],), val = 100.0) julia> # inplace variant dx = [0.0, 0.0]; @@ -177,7 +180,7 @@ julia> gradient(Forward, rosenbrock_inp, [1.0, 2.0]) ([-400.0, 200.0],) julia> gradient(ForwardWithPrimal, rosenbrock_inp, [1.0, 2.0]) -(derivs = [-400.0, 200.0], val = 100.0) +(derivs = ([-400.0, 200.0],), val = 100.0) julia> # in forward mode, we can also optionally pass a chunk size # to specify the number of derivatives computed simulateneously @@ -200,22 +203,22 @@ Both forward and reverse modes take an optional chunk size to compute several de julia> foo(x) = [rosenbrock_inp(x), prod(x)]; julia> jacobian(Reverse, foo, [1.0, 2.0]) -([-400.0 200.0; 2.0 1.0],) +([-400.0 200.0; 2.0 1.0],) julia> jacobian(ReverseWithPrimal, foo, [1.0, 2.0]) -(derivs = ([-400.0 200.0; 2.0 1.0],), val = [100.0, 2.0]) +(derivs = ([-400.0 200.0; 2.0 1.0],), val = [100.0, 2.0]) julia> jacobian(Reverse, foo, [1.0, 2.0]; chunk=Val(2)) -([-400.0 200.0; 2.0 1.0],) +([-400.0 200.0; 2.0 1.0],) julia> jacobian(Reverse, foo, [1.0, 2.0]; chunk=Val(2), n_outs=Val((2,))) -([-400.0 200.0; 2.0 1.0],) +([-400.0 200.0; 2.0 1.0],) julia> jacobian(Forward, foo, [1.0, 2.0]) -([-400.0 200.0; 2.0 1.0],) +([-400.0 200.0; 2.0 1.0],) julia> 
jacobian(Forward, foo, [1.0, 2.0], chunk=Val(2)) -([-400.0 200.0; 2.0 1.0],) +([-400.0 200.0; 2.0 1.0],) ``` ## Hessian Vector Product Convenience functions diff --git a/examples/custom_rule.jl b/examples/custom_rule.jl index 2b3f226fb01..f7789390324 100644 --- a/examples/custom_rule.jl +++ b/examples/custom_rule.jl @@ -65,7 +65,7 @@ function forward(config::FwdConfig, func::Const{typeof(f)}, ::Type{<:Duplicated} end # In the signature of our rule, we have made use of `Enzyme`'s activity annotations. Let's break down each one: -# - the [`FwdConfig`](@ref) configuration passes certain compile-time information about differentiation procedure (the width, and if we're using runtime activity), +# - the [`EnzymeRules.FwdConfig`](@ref) configuration passes certain compile-time information about differentiation procedure (the width, and if we're using runtime activity), # - the [`Const`](@ref) annotation on `f` indicates that we accept a function `f` that does not have a derivative component, # which makes sense since `f` is not a closure with data that could be differentiated. # - the [`Duplicated`](@ref) annotation given in the second argument annotates the return value of `f`. This means that @@ -123,8 +123,9 @@ dy = [0.0, 0.0] # If a custom rule is specified for the correct function/argument types, but not the correct activity annotation, # a runtime error will be thrown alerting the user to the missing activity rule rather than silently ignoring the rule." -# Finally, it may be that either `x`, `y`, or the return value are marked as [`Const`](@ref). We can in fact handle this case, -# along with the previous two cases, all together in a single rule: +# Finally, it may be that either `x`, `y`, or the return value are marked as [`Const`](@ref), in which case we can simply return the original result. However, Enzyme also may determine the return is not differentiable and also not needed for other computations, in which case we should simply return nothing. 
+# +# We can in fact handle this case, along with the previous two cases, all together in a single rule by leveraging utility functions [`EnzymeRules.needs_primal`](@ref) and [`EnzymeRules.needs_shadow`](@ref), which return true if the original return or the derivative is needed to be returned, respectively: Base.delete_method.(methods(forward, (Const{typeof(f)}, Vararg{Any}))) # delete our old rules @@ -138,12 +139,14 @@ function forward(config, func::Const{typeof(f)}, RT::Type{<:Union{Const, Duplica make_zero!(y.dval) end dret = !(y isa Const) ? sum(y.dval) : zero(eltype(y.val)) - if RT <: Const + if needs_primal(config) && needs_shadow(config) + return Duplicated(sum(y.val), dret) + elseif needs_primal(config) return sum(y.val) - elseif RT <: DuplicatedNoNeed + elseif needs_shadow(config) return dret else - return Duplicated(sum(y.val), dret) + return nothing end end @@ -189,7 +192,7 @@ function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(f)}, ::T end # Let's unpack our signature for `augmented_primal` : -# * We accepted a [`EnzymeRules.Config`](@ref) object with a specified width of 1, which means that our rule does not support batched reverse mode. +# * We accepted a [`EnzymeRules.RevConfig`](@ref) object with a specified width of 1, which means that our rule does not support batched reverse mode. # * We annotated `f` with [`Const`](@ref) as usual. # * We dispatched on an [`Active`](@ref) annotation for the return value. This is a special annotation for scalar values, such as our return value, # that indicates that that we care about the value's derivative but we need not explicitly allocate a mutable shadow since it is a scalar value. @@ -197,7 +200,7 @@ end # Now, let's unpack the body of our `augmented_primal` rule: # * We checked if the `config` requires the primal. If not, we need not compute the return value, but we make sure to mutate `y` in all cases. 
-# * We checked if `x` could possibly be overwritten using the `Overwritten` attribute of [`EnzymeRules.Config`](@ref). +# * We checked if `x` could possibly be overwritten using the `Overwritten` attribute of [`EnzymeRules.RevConfig`](@ref). # If so, we save the elements of `x` on the `tape` of the returned [`EnzymeRules.AugmentedReturn`](@ref) object. # * We return a shadow of `nothing` since the return value is [`Active`](@ref) and hence does not need a shadow. diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl index 8d01d321da8..d4469e97931 100644 --- a/lib/EnzymeCore/src/rules.jl +++ b/lib/EnzymeCore/src/rules.jl @@ -35,7 +35,19 @@ Getters for the type parameters are provided by `needs_primal`, `needs_shadow`, struct FwdConfig{NeedsPrimal, NeedsShadow, Width, RuntimeActivity} end const FwdConfigWidth{Width} = FwdConfig{<:Any,<:Any,Width} +""" + needs_primal(::FwdConfig) + needs_primal(::RevConfig) + +Whether a custom rule should return the original result of the function. +""" @inline needs_primal(::FwdConfig{NeedsPrimal}) where NeedsPrimal = NeedsPrimal +""" + needs_shadow(::FwdConfig) + needs_shadow(::RevConfig) + +Whether a custom rule should return the shadow (derivative) of the function result. +""" @inline needs_shadow(::FwdConfig{<:Any, NeedsShadow}) where NeedsShadow = NeedsShadow @inline width(::FwdConfig{<:Any, <:Any, Width}) where Width = Width diff --git a/src/Enzyme.jl b/src/Enzyme.jl index c4994fb363b..2b6f1f3627d 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -531,39 +531,6 @@ code, as well as high-order differentiation. thunk(f, args...) end -""" - autodiff_deferred(mode::Mode, f, ::Type{A}, args) - -Like [`autodiff_deferred`](@ref) but will try to extend f to an annotation, if needed. -""" -@inline function autodiff_deferred(mode::CMode, f::F, args::Vararg{Annotation, Nargs}) where {F, CMode<:Mode, Nargs} - autodiff_deferred(EnzymeCore.set_err_if_func_written(mode), Const(f), args...) 
-end -@inline function autodiff_deferred(mode::CMode, f::F, ::Type{RT}, args::Vararg{Annotation, Nargs}) where {F, RT<:Annotation, CMode<:Mode, Nargs} - autodiff_deferred(EnzymeCore.set_err_if_func_written(mode), Const(f), RT, args...) -end - -""" - autodiff_deferred(mode, f, args...) - -Like [`autodiff_deferred`](@ref) but will try to guess the activity of the return value. -""" - -@inline function autodiff_deferred(mode::M, f::FA, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, M<:Mode, Nargs} - tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} - rt = if mode isa ReverseMode - Compiler.primal_return_type(mode, Val(codegen_world_age(eltype(FA), tt)), eltype(FA), tt) - else - Core.Compiler.return_type(f.val, tt) - end - - if rt === Union{} - error("return type is Union{}, giving up.") - end - rt = guess_activity(rt, mode) - autodiff_deferred(mode, f, rt, args...) -end - """ autodiff_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Vararg{Type{<:Annotation, Nargs}) diff --git a/test/abi.jl b/test/abi.jl index 63fe48dc61e..342722c44dc 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -20,13 +20,13 @@ using Test @test () === autodiff(Forward, f, Const(nothing)) - res = autodiff_deferred(Reverse, f, Const(nothing)) + res = autodiff_deferred(Reverse, Const(f), Const, Const(nothing)) @test res === ((nothing,),) - res = autodiff_deferred(Enzyme.set_abi(Reverse, NonGenABI), f, Const, Const(nothing)) + res = autodiff_deferred(Enzyme.set_abi(Reverse, NonGenABI), Const(f), Const, Const(nothing)) @test res === ((nothing,),) - @test () === autodiff_deferred(Forward, f, Const(nothing)) - @test () === autodiff_deferred(Enzyme.set_abi(Forward, NonGenABI), f, Const, Const(nothing)) + @test () === autodiff_deferred(Forward, Const(f), Const, Const(nothing)) + @test () === autodiff_deferred(Enzyme.set_abi(Forward, NonGenABI), Const(f), Const, Const(nothing)) # ConstType -> Type{Int} res = autodiff(Reverse, f, Const, Const(Int)) @@ -37,9 +37,9 @@ using Test @test res === 
((nothing,),) @test () === autodiff(Forward, f, Const(Int)) - res = autodiff_deferred(Reverse, f, Const(Int)) + res = autodiff_deferred(Reverse, Const(f), Const, Const(Int)) @test res === ((nothing,),) - @test () === autodiff_deferred(Forward, f, Const(Int)) + @test () === autodiff_deferred(Forward, Const(f), Const, Const(Int)) # Complex numbers @test_throws ErrorException autodiff(Reverse, f, Active, Active(1.5 + 0.7im)) @@ -54,10 +54,10 @@ using Test cres, = autodiff(Forward, f, Duplicated(1.5 + 0.7im, 1.0+0im)) @test cres ≈ 1.0 + 0.0im - @test_throws ErrorException autodiff_deferred(Reverse, f, Active(1.5 + 0.7im)) - @test_throws ErrorException autodiff_deferred(ReverseHolomorphic, f, Active(1.5 + 0.7im)) + @test_throws ErrorException autodiff_deferred(Reverse, Const(f), Active, Active(1.5 + 0.7im)) + @test_throws ErrorException autodiff_deferred(ReverseHolomorphic, Const(f), Active, Active(1.5 + 0.7im)) - cres, = autodiff_deferred(Forward, f, Duplicated(1.5 + 0.7im, 1.0+0im)) + cres, = autodiff_deferred(Forward, Const(f), Duplicated, Duplicated(1.5 + 0.7im, 1.0+0im)) @test cres ≈ 1.0 + 0.0im # Unused singleton argument @@ -97,7 +97,7 @@ using Test x = [0.0] dx = [1.2] - autodiff_deferred(Reverse, squareRetArray, Const, Duplicated(x, dx)) + autodiff_deferred(Reverse, Const(squareRetArray), Const, Duplicated(x, dx)) dx = [1.2] @test () === autodiff(Forward, squareRetArray, Const, Duplicated(x, dx)) @@ -113,7 +113,7 @@ using Test @test pair[1] ≈ 3.0 @test pair[2] ≈ 2.0 - pair = autodiff_deferred(Reverse, mul, Active(2.0), Active(3.0))[1] + pair = autodiff_deferred(Reverse, Const(mul), Active, Active(2.0), Active(3.0))[1] @test pair[1] ≈ 3.0 @test pair[2] ≈ 2.0 @@ -122,7 +122,7 @@ using Test @test pair[2] ≈ 2.0 @test orig ≈ 6.0 - pair, orig = autodiff_deferred(ReverseWithPrimal, mul, Active(2.0), Active(3.0)) + pair, orig = autodiff_deferred(ReverseWithPrimal, Const(mul), Active, Active(2.0), Active(3.0)) @test pair[1] ≈ 3.0 @test pair[2] ≈ 2.0 @test orig ≈ 6.0 @@ 
-142,7 +142,7 @@ using Test res = Ref(3.0) dres = Ref(1.0) - pair, orig = autodiff_deferred(ReverseWithPrimal, inplace, Const, Duplicated(res, dres)) + pair, orig = autodiff_deferred(ReverseWithPrimal, Const(inplace), Const, Duplicated(res, dres)) @test pair == (nothing,) @test res[] ≈ 6.0 @test dres[] ≈ 2.0 @@ -163,7 +163,7 @@ using Test res = Ref(3.0) dres = Ref(1.0) - pair, orig = autodiff_deferred(ReverseWithPrimal, inplace2, Const, Duplicated(res, dres)) + pair, orig = autodiff_deferred(ReverseWithPrimal, Const(inplace2), Const, Duplicated(res, dres)) @test pair == (nothing,) @test res[] ≈ 6.0 @test dres[] ≈ 2.0 @@ -450,7 +450,7 @@ end @test r[2] ≈ 100.0 @test r[1][1] ≈ -400.0 @test r[1][2] ≈ 200.0 - r = autodiff_deferred(ForwardWithPrimal, rosenbrock_inp, Duplicated, BatchDuplicated(x, (dx_1, dx_2))) + r = autodiff_deferred(ForwardWithPrimal, Const(rosenbrock_inp), Duplicated, BatchDuplicated(x, (dx_1, dx_2))) @test r[2] ≈ 100.0 @test r[1][1] ≈ -400.0 @test r[1][2] ≈ 200.0 diff --git a/test/amdgpu.jl b/test/amdgpu.jl index 9c9b0974229..75318ac97d1 100644 --- a/test/amdgpu.jl +++ b/test/amdgpu.jl @@ -11,7 +11,7 @@ function mul_kernel(A) end function grad_mul_kernel(A, dA) - autodiff_deferred(Reverse, mul_kernel, Const, Duplicated(A, dA)) + autodiff_deferred(Reverse, Const(mul_kernel), Const, Duplicated(A, dA)) return nothing end @@ -34,7 +34,7 @@ function exp_kernel(A) end function grad_exp_kernel(A, dA) - autodiff_deferred(Reverse, exp_kernel, Const, Duplicated(A, dA)) + autodiff_deferred(Reverse, Const(exp_kernel), Const, Duplicated(A, dA)) return nothing end @@ -57,7 +57,7 @@ function cos_kernel(A) end function grad_cos_kernel(A, dA) - autodiff_deferred(Reverse, cos_kernel, Const, Duplicated(A, dA)) + autodiff_deferred(Reverse, Const(cos_kernel), Const, Duplicated(A, dA)) return nothing end diff --git a/test/cuda.jl b/test/cuda.jl index 29a55dcfc83..736f667a879 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -11,7 +11,7 @@ function mul_kernel(A) end 
function grad_mul_kernel(A, dA) - autodiff_deferred(Reverse, mul_kernel, Const, Duplicated(A, dA)) + autodiff_deferred(Reverse, Const(mul_kernel), Const, Duplicated(A, dA)) return nothing end @@ -34,7 +34,7 @@ function exp_kernel(A) end function grad_exp_kernel(A, dA) - autodiff_deferred(Reverse, exp_kernel, Const, Duplicated(A, dA)) + autodiff_deferred(Reverse, Const(exp_kernel), Const, Duplicated(A, dA)) return nothing end @@ -57,7 +57,7 @@ function cos_kernel(A) end function grad_cos_kernel(A, dA) - autodiff_deferred(Reverse, cos_kernel, Const, Duplicated(A, dA)) + autodiff_deferred(Reverse, Const(cos_kernel), Const, Duplicated(A, dA)) return nothing end @@ -76,7 +76,7 @@ function val_kernel!(_, ::Val{N}) where N end function dval_kernel!(du, ::Val{N}) where {N} - autodiff_deferred(Reverse, val_kernel!, Const, du, Const(Val(N))) + autodiff_deferred(Reverse, Const(val_kernel!), Const, du, Const(Val(N))) return nothing end @@ -123,7 +123,7 @@ function ddense!( autodiff_deferred( Reverse, - dense!, + Const(dense!), Const, dfeats_out, dfeats_in, dW, db, Const(Val(nfeat_out)), Const(Val(nfeat_in)), Const(Val(ndof)) diff --git a/test/metal.jl b/test/metal.jl index 661bcfbedc8..588357c92e4 100644 --- a/test/metal.jl +++ b/test/metal.jl @@ -16,12 +16,12 @@ function fun_gpu!(A, B, a) end function ∇_fun_cpu!(A, Ā, B, B̄, a) - Enzyme.autodiff_deferred(Reverse, fun_cpu!, Const, DuplicatedNoNeed(A, Ā), DuplicatedNoNeed(B, B̄), Const(a)) + Enzyme.autodiff_deferred(Reverse, Const(fun_cpu!), Const, DuplicatedNoNeed(A, Ā), DuplicatedNoNeed(B, B̄), Const(a)) nothing end function ∇_fun_gpu!(A_d, Ā_d, B_d, B̄_d, a) - Enzyme.autodiff_deferred(Reverse, fun_gpu!, Const, Duplicated(A_d, Ā_d), Duplicated(B_d, B̄_d), Const(a)) + Enzyme.autodiff_deferred(Reverse, Const(fun_gpu!), Const, Duplicated(A_d, Ā_d), Duplicated(B_d, B̄_d), Const(a)) nothing end diff --git a/test/runtests.jl b/test/runtests.jl index 65ad4e3fd4e..d99a28832bd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ 
-436,7 +436,7 @@ end def_A, thunk_A = copy(A), copy(A) primal = Enzyme.autodiff(ReverseWithPrimal, dot, Active, Duplicated(A, dA))[2] @test primal == 34.0 - primal = Enzyme.autodiff_deferred(ReverseWithPrimal, dot, Active, Duplicated(def_A, def_dA))[2] + primal = Enzyme.autodiff_deferred(ReverseWithPrimal, Const(dot), Active, Duplicated(def_A, def_dA))[2] @test primal == 34.0 dup = Duplicated(thunk_A, thunk_dA) @@ -752,7 +752,7 @@ end @testset "Nested AD" begin tonest(x,y) = (x + y)^2 - @test autodiff(Forward, (x,y) -> autodiff_deferred(Forward, tonest, Duplicated(x, 1.0), Const(y))[1], Const(1.0), Duplicated(2.0, 1.0))[1] ≈ 2.0 + @test autodiff(Forward, (x,y) -> autodiff(Forward, Const(tonest), Duplicated(x, 1.0), Const(y))[1], Const(1.0), Duplicated(2.0, 1.0))[1] ≈ 2.0 end @testset "Hessian" begin @@ -762,7 +762,7 @@ end end function grad(x, dx, y, dy) - Enzyme.autodiff_deferred(Reverse, origf, Duplicated(x, dx), DuplicatedNoNeed(y, dy)) + Enzyme.autodiff(Reverse, Const(origf), Duplicated(x, dx), DuplicatedNoNeed(y, dy)) nothing end @@ -797,7 +797,7 @@ end function f_gradient_deferred!(dx, x, tmp) dtmp = make_zero(tmp) - autodiff_deferred(Reverse, f_ip, Active, Duplicated(x, dx), Duplicated(tmp, dtmp)) + autodiff_deferred(Reverse, Const(f_ip), Active, Duplicated(x, dx), Duplicated(tmp, dtmp)) return nothing end @@ -828,7 +828,7 @@ end function nested_df!(dx, x) make_zero!(dx) - autodiff_deferred(Reverse, nested_f, Active, Duplicated(x, dx)) + autodiff_deferred(Reverse, Const(nested_f), Active, Duplicated(x, dx)) return nothing end @@ -1869,7 +1869,7 @@ end @testset "Mismatched return" begin @test_throws ErrorException autodiff(Reverse, _->missing, Active, Active(2.1)) - @test_throws ErrorException autodiff_deferred(Reverse, _->missing, Active, Active(2.1)) + @test_throws ErrorException autodiff_deferred(Reverse, Const(_->missing), Active, Active(2.1)) end @testset "GCPreserve" begin From efbe9a110ebc89c7b4534c03c93bba46eaa5a634 Mon Sep 17 00:00:00 2001 From: 
William Moses Date: Wed, 18 Sep 2024 13:32:33 -0500 Subject: [PATCH 37/87] Add within autodiff cmd (#1851) * Add within autodiff cmd * fix * fixup * fix * fix * fix --- lib/EnzymeCore/src/EnzymeCore.jl | 8 ++++++++ src/Enzyme.jl | 11 +++++++++-- src/compiler/interpreter.jl | 17 ++++++++++++++++- test/abi.jl | 6 ++++++ 4 files changed, 39 insertions(+), 3 deletions(-) diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index cc71f0f9c6b..f51c742f5d4 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -6,6 +6,7 @@ export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplic export MixedDuplicated, BatchMixedDuplicated export DefaultABI, FFIABI, InlineABI, NonGenABI export BatchDuplicatedFunc +export within_autodiff function batch_size end @@ -338,4 +339,11 @@ if !isdefined(Base, :get_extension) include("../ext/AdaptExt.jl") end +""" + within_autodiff() + +Returns true if within autodiff, otherwise false. +""" +function within_autodiff end + end # module EnzymeCore diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 2b6f1f3627d..985e8deeeab 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -5,8 +5,8 @@ import EnzymeCore import EnzymeCore: Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal export Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal -import EnzymeCore: Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, ABI, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi, set_runtime_activity, clear_runtime_activity -export Annotation, Const, Active, Duplicated, 
DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi, set_runtime_activity, clear_runtime_activity +import EnzymeCore: Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, ABI, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi, set_runtime_activity, clear_runtime_activity, within_autodiff +export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi, set_runtime_activity, clear_runtime_activity, within_autodiff import EnzymeCore: BatchDuplicatedFunc export BatchDuplicatedFunc @@ -1744,4 +1744,11 @@ macro import_rrule(args...) return _import_rrule(args...) end +""" + within_autodiff() + +Returns true if within autodiff, otherwise false. +""" +@inline EnzymeCore.within_autodiff() = false + end # module diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 482690e20f6..c167581c3a9 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -213,7 +213,7 @@ let # overload `inlining_policy` end import Core.Compiler: abstract_call, abstract_call_known, ArgInfo, StmtInfo, AbsIntState, get_max_methods, - CallMeta, Effects, NoCallInfo, widenconst, mapany + CallMeta, Effects, NoCallInfo, widenconst, mapany, MethodResultPure struct AutodiffCallInfo <: CallInfo # ... 
@@ -225,6 +225,21 @@ function abstract_call_known(interp::EnzymeInterpreter, @nospecialize(f), max_methods::Int = get_max_methods(interp, f, sv)) (; fargs, argtypes) = arginfo + + if f === Enzyme.within_autodiff + if length(argtypes) != 1 + @static if VERSION < v"1.11.0-" + return CallMeta(Union{}, Effects(), NoCallInfo()) + else + return CallMeta(Union{}, Union{}, Effects(), NoCallInfo()) + end + end + @static if VERSION < v"1.11.0-" + return CallMeta(Core.Const(true), Core.Compiler.EFFECTS_TOTAL, MethodResultPure()) + else + return CallMeta(Core.Const(true), Union{}, Core.Compiler.EFFECTS_TOTAL, MethodResultPure()) + end + end if f === Enzyme.autodiff && length(argtypes) >= 4 if widenconst(argtypes[2]) <: Enzyme.Mode && widenconst(argtypes[3]) <: Enzyme.Annotation && widenconst(argtypes[4]) <: Type{<:Enzyme.Annotation} diff --git a/test/abi.jl b/test/abi.jl index 342722c44dc..cbd467c1555 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -460,6 +460,12 @@ abssum(x) = sum(abs2, x); mulsin(x) = sin(x[1] * x[2]) +@testset "within_autodiff" begin + @test !Enzyme.within_autodiff() + @test_broken Enzyme.autodiff(ForwardWithPrimal, Enzyme.within_autodiff)[1] + @test Enzyme.autodiff(ForwardWithPrimal, () -> Enzyme.within_autodiff())[1] +end + @testset "Type inference" begin x = ones(10) @inferred autodiff(Enzyme.Reverse, abssum, Duplicated(x,x)) From f49d1fc68c3a66aeed2883fb308bec92ebd30103 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Sep 2024 16:42:17 -0500 Subject: [PATCH 38/87] Fix enzymetestutils tests (#1858) --- lib/EnzymeTestUtils/test/test_forward.jl | 1 + lib/EnzymeTestUtils/test/test_reverse.jl | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/EnzymeTestUtils/test/test_forward.jl b/lib/EnzymeTestUtils/test/test_forward.jl index 57385a1dd98..24de5b2f44f 100644 --- a/lib/EnzymeTestUtils/test/test_forward.jl +++ b/lib/EnzymeTestUtils/test/test_forward.jl @@ -20,6 +20,7 @@ function f_kwargs_fwd!(x; kwargs...) 
end function EnzymeRules.forward( + config, func::Const{typeof(f_kwargs_fwd)}, RT::Type{ <:Union{Const,Duplicated,DuplicatedNoNeed,BatchDuplicated,BatchDuplicatedNoNeed} diff --git a/lib/EnzymeTestUtils/test/test_reverse.jl b/lib/EnzymeTestUtils/test/test_reverse.jl index b394fa171d0..901c259af88 100644 --- a/lib/EnzymeTestUtils/test/test_reverse.jl +++ b/lib/EnzymeTestUtils/test/test_reverse.jl @@ -17,7 +17,7 @@ function f_kwargs_rev!(x; kwargs...) end function EnzymeRules.augmented_primal( - config::EnzymeRules.ConfigWidth{1}, + config::EnzymeRules.RevConfigWidth{1}, func::Const{typeof(f_kwargs_rev)}, RT::Type{<:Union{Const,Duplicated,DuplicatedNoNeed}}, x::Union{Const,Duplicated}; @@ -39,7 +39,7 @@ function EnzymeRules.augmented_primal( end function EnzymeRules.reverse( - config::EnzymeRules.ConfigWidth{1}, + config::EnzymeRules.RevConfigWidth{1}, func::Const{typeof(f_kwargs_rev)}, dret::Type{<:Union{Const,Duplicated,DuplicatedNoNeed}}, tape, From 6e867ba81bab2abafaed85f56a0f6e7cc38b01a2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Sep 2024 16:46:45 -0500 Subject: [PATCH 39/87] Try llvm.jl 9.1 (#1857) * Try llvm.jl 9.1 * fixups * more fix * bump version * fix * fix * fix --- Project.toml | 4 ++-- src/compiler.jl | 46 +++++++++++++++++++++++++------------- src/compiler/optimize.jl | 20 ++++++++--------- src/compiler/orcv2.jl | 2 +- src/compiler/utils.jl | 10 ++++++++- src/compiler/validation.jl | 4 ++-- 6 files changed, 55 insertions(+), 31 deletions(-) diff --git a/Project.toml b/Project.toml index 19315d01dca..9cf8028a76e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.0" +version = "0.13.1" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -37,7 +37,7 @@ ChainRulesCore = "1" EnzymeCore = "0.8" Enzyme_jll = "0.0.150" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" -LLVM = "6.1, 7, 8, =9.0" 
+LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" ObjectFile = "0.4" Preferences = "1.4" diff --git a/src/compiler.jl b/src/compiler.jl index 1d21fb99a1d..be4679d263f 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4045,6 +4045,22 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr return adjointf, augmented_primalf, TapeType end +function get_subprogram(f::LLVM.Function) + @static if isdefined(LLVM, :subprogram) + LLVM.subprogram(f) + else + LLVM.get_subprogram(f) + end +end + +function set_subprogram!(f::LLVM.Function, sp) + @static if isdefined(LLVM, :subprogram) + LLVM.subprogram!(f, sp) + else + LLVM.set_subprogram!(f, sp) + end +end + function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, Mode::API.CDerivativeMode, augmented, width, returnPrimal, shadow_init, world, interp) is_adjoint = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModeCombined is_split = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModePrimal @@ -4422,8 +4438,8 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, push!(args, psret) end res = LLVM.call!(builder, LLVM.function_type(llvmf), llvmf, args) - if LLVM.get_subprogram(llvmf) !== nothing - metadata(res)[LLVM.MD_dbg] = DILocation( 0, 0, LLVM.get_subprogram(llvm_f) ) + if get_subprogram(llvmf) !== nothing + metadata(res)[LLVM.MD_dbg] = DILocation( 0, 0, get_subprogram(llvm_f) ) end if psret !== nothing res = load!(builder, convert(LLVMType, Func_RT), psret) @@ -4449,8 +4465,8 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end val = call!(builder, LLVM.function_type(enzymefn), enzymefn, realparms) - if LLVM.get_subprogram(llvm_f) !== nothing - metadata(val)[LLVM.MD_dbg] = DILocation( 0, 0, LLVM.get_subprogram(llvm_f) ) + if get_subprogram(llvm_f) !== nothing + metadata(val)[LLVM.MD_dbg] = DILocation( 0, 0, get_subprogram(llvm_f) ) end @inline function fixup_abi(index, value) @@ 
-4514,8 +4530,8 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, push!(function_attributes(cf), EnumAttribute("alwaysinline", 0)) for shadowv in shadows c = call!(builder, LLVM.function_type(cf), cf, [shadowv]) - if LLVM.get_subprogram(llvm_f) !== nothing - metadata(c)[LLVM.MD_dbg] = DILocation( 0, 0, LLVM.get_subprogram(llvm_f) ) + if get_subprogram(llvm_f) !== nothing + metadata(c)[LLVM.MD_dbg] = DILocation( 0, 0, get_subprogram(llvm_f) ) end end end @@ -5027,9 +5043,9 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function wrapper_ft = LLVM.FunctionType(RT, wrapper_types) wrapper_f = LLVM.Function(mod, LLVM.name(entry_f), wrapper_ft) callconv!(wrapper_f, callconv(entry_f)) - sfn = LLVM.get_subprogram(entry_f) + sfn = get_subprogram(entry_f) if sfn !== nothing - LLVM.set_subprogram!(wrapper_f, sfn) + set_subprogram!(wrapper_f, sfn) end hasReturnsTwice = any(map(k->kind(k)==kind(EnumAttribute("returns_twice")), collect(function_attributes(entry_f)))) @@ -5107,8 +5123,8 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function entry = BasicBlock(wrapper_f, "entry") position!(builder, entry) - if LLVM.get_subprogram(entry_f) !== nothing - debuglocation!(builder, DILocation(0, 0, LLVM.get_subprogram(entry_f))) + if get_subprogram(entry_f) !== nothing + debuglocation!(builder, DILocation(0, 0, get_subprogram(entry_f))) end wrapper_args = Vector{LLVM.Value}() @@ -5178,8 +5194,8 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function end res = call!(builder, LLVM.function_type(entry_f), entry_f, wrapper_args) - if LLVM.get_subprogram(entry_f) !== nothing - metadata(res)[LLVM.MD_dbg] = DILocation( 0, 0, LLVM.get_subprogram(entry_f) ) + if get_subprogram(entry_f) !== nothing + metadata(res)[LLVM.MD_dbg] = DILocation( 0, 0, get_subprogram(entry_f) ) end callconv!(res, LLVM.callconv(entry_f)) @@ -5411,10 +5427,10 @@ function lower_convention(functy::Type, 
mod::LLVM.Module, entry_f::LLVM.Function LLVM.run!(pm, mod) end if haskey(globals(mod), "llvm.used") - unsafe_delete!(mod, globals(mod)["llvm.used"]) + eraseInst(mod, globals(mod)["llvm.used"]) for u in user.(collect(uses(entry_f))) if isa(u, LLVM.GlobalVariable) && endswith(LLVM.name(u), "_slot") && startswith(LLVM.name(u), "julia") - unsafe_delete!(mod, u) + eraseInst(mod, u) end end end @@ -6469,7 +6485,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; st = LLVM.user(u) LLVM.API.LLVMInstructionEraseFromParent(st) end - LLVM.unsafe_delete!(mod, f) + eraseInst(mod, f) end linkage!(adjointf, LLVM.API.LLVMExternalLinkage) diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 8c6385edb8c..2e3e8194c97 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -533,7 +533,7 @@ function memcpy_alloca_to_loadstore(mod::LLVM.Module) end end for inst in todel - unsafe_delete!(LLVM.parent(inst), inst) + eraseInst(LLVM.parent(inst), inst) end end end @@ -1145,7 +1145,7 @@ function prop_global!(g) end end replace_uses!(var, res) - unsafe_delete!(LLVM.parent(var), var) + eraseInst(LLVM.parent(var), var) continue end if isa(var, LLVM.AddrSpaceCastInst) @@ -1441,7 +1441,7 @@ function propagate_returned!(mod::LLVM.Module) end if !illegalUse for c in reverse(torem) - unsafe_delete!(LLVM.parent(c), c) + eraseInst(LLVM.parent(c), c) end B = IRBuilder() position!(B, first(instructions(first(blocks(fn))))) @@ -1617,7 +1617,7 @@ function propagate_returned!(mod::LLVM.Module) end API.EnzymeSetCalledFunction(un, nfn, toremove) end - unsafe_delete!(mod, fn) + eraseInst(mod, fn) changed = true catch break @@ -2030,26 +2030,26 @@ function removeDeadArgs!(mod::LLVM.Module, tm) for u in LLVM.uses(rfunc) u = LLVM.user(u) - unsafe_delete!(LLVM.parent(u), u) + eraseInst(LLVM.parent(u), u) end - unsafe_delete!(mod, rfunc) + eraseInst(mod, rfunc) for u in LLVM.uses(sfunc) u = LLVM.user(u) - unsafe_delete!(LLVM.parent(u), u) + 
eraseInst(LLVM.parent(u), u) end - unsafe_delete!(mod, sfunc) + eraseInst(mod, sfunc) for fn in functions(mod) for b in blocks(fn) inst = first(LLVM.instructions(b)) if isa(inst, LLVM.CallInst) fn = LLVM.called_operand(inst) if fn == func - unsafe_delete!(b, inst) + eraseInst(b, inst) end end end end - unsafe_delete!(mod, func) + eraseInst(mod, func) end function optimize!(mod::LLVM.Module, tm) diff --git a/src/compiler/orcv2.jl b/src/compiler/orcv2.jl index 40d13eea805..78ff089e7d3 100644 --- a/src/compiler/orcv2.jl +++ b/src/compiler/orcv2.jl @@ -224,7 +224,7 @@ function get_trampoline(job) # but it would be nicer if _thunk just codegen'd the half # we need. other_func = functions(mod)[other_name] - LLVM.unsafe_delete!(mod, other_func) + Compiler.eraseInst(mod, other_func) end tsm = move_to_threadsafe(mod) diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index 6615b6bd405..cde5d2cade4 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -313,6 +313,14 @@ function reinsert_gcmarker!(func, PB=nothing) end end +function eraseInst(bb, inst) + @static if isdefined(LLVM, Symbol("erase!")) + LLVM.erase!(inst) + else + unsafe_delete!(bb, inst) + end +end + function unique_gcmarker!(func) entry_bb = first(blocks(func)) pgcstack_func = declare_pgcstack!(LLVM.parent(func)) @@ -327,7 +335,7 @@ function unique_gcmarker!(func) for i in 2:length(found) LLVM.replace_uses!(found[i], found[1]) ops = LLVM.collect(operands(found[i])) - Base.unsafe_delete!(entry_bb, found[i]) + eraseInst(entry_bb, found[i]) end end return nothing diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 51aeacf675a..3df37be1175 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -112,7 +112,7 @@ function restore_lookups(mod::LLVM.Module) if haskey(functions(mod), k) f = functions(mod)[k] replace_uses!(f, LLVM.Value(LLVM.API.LLVMConstIntToPtr(ConstantInt(T_size_t, convert(UInt, v)), value_type(f)))) - unsafe_delete!(mod, f) + 
eraseInst(mod, f) end end end @@ -272,7 +272,7 @@ function check_ir!(job, errors, mod::LLVM.Module) mfn = LLVM.API.LLVMAddFunction(mod, "malloc", LLVM.FunctionType(ptr8, parameters(prev_ft))) replace_uses!(f, LLVM.Value(LLVM.API.LLVMConstPointerCast(mfn, value_type(f)))) - unsafe_delete!(mod, f) + eraseInst(mod, f) end rewrite_ccalls!(mod) for f in collect(functions(mod)) From 5a5beea54eabc9117a371397847dfaa62cd6c161 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 19 Sep 2024 14:47:17 -0500 Subject: [PATCH 40/87] runtime activity lookup on val not type (#1862) * runtime activity lookup on val not type * fix --------- Co-authored-by: William Moses --- lib/EnzymeCore/Project.toml | 2 +- lib/EnzymeCore/src/rules.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index 2d39f92f452..0b688f27a87 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.8.0" +version = "0.8.1" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl index d4469e97931..3da3e318a7b 100644 --- a/lib/EnzymeCore/src/rules.jl +++ b/lib/EnzymeCore/src/rules.jl @@ -293,6 +293,6 @@ Mark a particular type `Ty` as always being inactive. 
""" inactive_type(::Type) = false -@inline EnzymeCore.set_runtime_activity(::M, ::Config) where {M<:Mode, Config <: Union{FwdConfig, RevConfig}} = EnzymeCore.set_runtime_activity(M, runtime_activity(Config)) +@inline EnzymeCore.set_runtime_activity(mode::M, config::Config) where {M<:Mode, Config <: Union{FwdConfig, RevConfig}} = EnzymeCore.set_runtime_activity(mode, runtime_activity(config)) end # EnzymeRules From f14bd4a6bb3de73c205ca181bf8f69d71365822c Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 19 Sep 2024 20:49:22 -0500 Subject: [PATCH 41/87] Fix jac nout (#1864) --- Project.toml | 2 +- src/Enzyme.jl | 24 +++++++++++++++++++++--- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 9cf8028a76e..3c93057f906 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.1" +version = "0.13.2" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 985e8deeeab..a439cf430cf 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1417,8 +1417,8 @@ end jacobian(::ReverseMode, f, x) Compute the jacobian of a array-output function `f` using (potentially vector) -reverse mode. The `chunk` argument denotes the chunk size to use and `n_outs` -denotes the shape of the array returned by `f`. +reverse mode. The `chunk` argument optionally denotes the chunk size to use and +`n_outs` optionally denotes the shape of the array returned by `f` (e.g `size(f(x))`). 
Example: @@ -1434,12 +1434,30 @@ jacobian(Reverse, f, [2.0, 3.0, 4.0]) ```jldoctest f(x) = [ x[1] * x[2], x[2] + x[3] ] +grad = jacobian(ReverseWithPrimal, f, [2.0, 3.0, 4.0]) + +# output +(derivs = ([3.0 2.0 0.0; 0.0 1.0 1.0],), val = [6.0, 7.0]) +``` + +```jldoctest +f(x) = [ x[1] * x[2], x[2] + x[3] ] + grad = jacobian(Reverse, f, [2.0, 3.0, 4.0], n_outs=Val((2,))) # output ([3.0 2.0 0.0; 0.0 1.0 1.0],) ``` +```jldoctest +f(x) = [ x[1] * x[2], x[2] + x[3] ] + +grad = jacobian(ReverseWithPrimal, f, [2.0, 3.0, 4.0], n_outs=Val((2,))) + +# output +(derivs = ([3.0 2.0 0.0; 0.0 1.0 1.0],), val = [6.0, 7.0]) +``` + This function will return an AbstractArray whose shape is `(size(output)..., size(input)...)`. No guarantees are presently made about the type of the AbstractArray returned by this function (which may or may not be the same as the input AbstractArray if provided). @@ -1573,7 +1591,7 @@ this function will retun an AbstractArray of shape `size(output)` of values of t end if ReturnPrimal # TODO optimize away redundant fwd pass - (; derivs=res, val=if f isa Enzyme.Const + (; derivs=(res,), val=if f isa Enzyme.Const f.val(x) else f(x) From 00037e7ff8fb32f36691bbdba5ce8dc251fe2dec Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 21 Sep 2024 02:16:23 -0500 Subject: [PATCH 42/87] Cleanup (#1872) * Return type config * cleanup * Update runtests.jl --------- Co-authored-by: William Moses --- lib/EnzymeCore/Project.toml | 2 +- lib/EnzymeCore/src/rules.jl | 15 ++---- lib/EnzymeTestUtils/src/generate_tangent.jl | 10 +--- lib/EnzymeTestUtils/src/test_reverse.jl | 7 +-- lib/EnzymeTestUtils/test/test_forward.jl | 53 +++++++-------------- src/rules/customrules.jl | 8 ++-- test/runtests.jl | 15 +++--- test/threads.jl | 5 -- 8 files changed, 35 insertions(+), 80 deletions(-) diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index 0b688f27a87..37ddaf64576 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name 
= "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.8.1" +version = "0.8.2" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl index 3da3e318a7b..a7563a2ef73 100644 --- a/lib/EnzymeCore/src/rules.jl +++ b/lib/EnzymeCore/src/rules.jl @@ -77,17 +77,21 @@ const RevConfigWidth{Width} = RevConfig{<:Any,<:Any, Width} @inline runtime_activity(::RevConfig{<:Any, <:Any, <:Any, <:Any, RuntimeActivity}) where RuntimeActivity = RuntimeActivity """ + primal_type(::FwdConfig, ::Type{<:Annotation{RT}}) primal_type(::RevConfig, ::Type{<:Annotation{RT}}) Compute the exepcted primal return type given a reverse mode config and return activity """ +@inline primal_type(config::FwdConfig, ::Type{<:Annotation{RT}}) where RT = needs_primal(config) ? RT : Nothing @inline primal_type(config::RevConfig, ::Type{<:Annotation{RT}}) where RT = needs_primal(config) ? RT : Nothing """ + shadow_type(::FwdConfig, ::Type{<:Annotation{RT}}) shadow_type(::RevConfig, ::Type{<:Annotation{RT}}) Compute the exepcted shadow return type given a reverse mode config and return activity """ +@inline shadow_type(config::FwdConfig, ::Type{<:Annotation{RT}}) where RT = needs_shadow(config) ? (width(config) == 1 ? RT : NTuple{width(config), RT}) : Nothing @inline shadow_type(config::RevConfig, ::Type{<:Annotation{RT}}) where RT = needs_shadow(config) ? (width(config) == 1 ? 
RT : NTuple{width(config), RT}) : Nothing """ @@ -191,9 +195,6 @@ function isapplicable(@nospecialize(f), @nospecialize(TT); caller::Union{Nothing,Core.MethodInstance}=nothing) tt = Base.to_tuple_type(TT) sig = Base.signature_type(f, tt) - @static if VERSION < v"1.7.0" - return !isempty(Base._methods_by_ftype(sig, -1, world)) - end mt = ccall(:jl_method_table_for, Any, (Any,), sig) mt isa Core.MethodTable || return false if method_table === nothing @@ -234,14 +235,6 @@ function add_mt_backedge!(caller::Core.MethodInstance, mt::Core.MethodTable, @no return nothing end -function issupported() - @static if VERSION < v"1.7.0" - return false - else - return true - end -end - """ inactive(func::typeof(f), args...) diff --git a/lib/EnzymeTestUtils/src/generate_tangent.jl b/lib/EnzymeTestUtils/src/generate_tangent.jl index d774036e7e0..91822a509fc 100644 --- a/lib/EnzymeTestUtils/src/generate_tangent.jl +++ b/lib/EnzymeTestUtils/src/generate_tangent.jl @@ -60,14 +60,8 @@ end # get around the constructors and make the type directly # Note this is moderately evil accessing julia's internals -if VERSION >= v"1.3" - @generated function _force_construct(T, args...) - return Expr(:splatnew, :T, :args) - end -else - @generated function _force_construct(T, args...) - return Expr(:new, :T, Any[:(args[$i]) for i in 1:length(args)]...) - end +@generated function _force_construct(T, args...) + return Expr(:splatnew, :T, :args) end function _construct(T, args...) diff --git a/lib/EnzymeTestUtils/src/test_reverse.jl b/lib/EnzymeTestUtils/src/test_reverse.jl index 6c20aebb7aa..543f5de6999 100644 --- a/lib/EnzymeTestUtils/src/test_reverse.jl +++ b/lib/EnzymeTestUtils/src/test_reverse.jl @@ -7,12 +7,7 @@ for N in 1:30 eval(quote function call_with_kwargs(fkwargs::NT, f::FT, $(argexprs...)) where {NT, FT} Base.@_inline_meta - @static if VERSION ≤ v"1.8" - # callsite inline syntax unsupported in <= 1.8 - f($(argexprs...); fkwargs...) - else - @inline f($(argexprs...); fkwargs...) 
- end + @inline f($(argexprs...); fkwargs...) end end) end diff --git a/lib/EnzymeTestUtils/test/test_forward.jl b/lib/EnzymeTestUtils/test/test_forward.jl index 24de5b2f44f..5f8e5e7c6cf 100644 --- a/lib/EnzymeTestUtils/test/test_forward.jl +++ b/lib/EnzymeTestUtils/test/test_forward.jl @@ -87,9 +87,6 @@ end elseif TT <: NamedTuple x = (a=randn(T), b=randn(T)) else # TT <: TestStruct - if VERSION <= v"1.8" && Tx == BatchDuplicated - continue - end x = TestStruct(randn(T, 5), randn(T)) end atol = rtol = sqrt(eps(real(T))) @@ -117,38 +114,26 @@ end a = randn(T) atol = rtol = sqrt(eps(real(T))) - if VERSION < v"1.8" && ( - Tret <: BatchDuplicated || - Tx <: BatchDuplicated || - Ta <: BatchDuplicated - ) - @test !fails() do - test_forward(f_multiarg, Tret, (x, Tx), (a, Ta); atol, rtol) - end skip = true - else - @test !fails() do - test_forward(f_multiarg, Tret, (x, Tx), (a, Ta); atol, rtol) - end broken = ( - VERSION < v"1.8" && Tx <: Const && !(Ta <: Const) && T <: Complex - ) - end + @test !fails() do + test_forward(f_multiarg, Tret, (x, Tx), (a, Ta); atol, rtol) + end end end - VERSION >= v"1.8" && @testset "structured array inputs/outputs" begin - @testset for Tret in (Const, Duplicated, BatchDuplicated), - Tx in (Const, Duplicated, BatchDuplicated), - T in (Float32, Float64, ComplexF32, ComplexF64) + @testset "structured array inputs/outputs" begin + @testset for Tret in (Const, Duplicated, BatchDuplicated), + Tx in (Const, Duplicated, BatchDuplicated), + T in (Float32, Float64, ComplexF32, ComplexF64) - # if some are batch, none must be duplicated - are_activities_compatible(Tret, Tx) || continue + # if some are batch, none must be duplicated + are_activities_compatible(Tret, Tx) || continue - x = Hermitian(randn(T, 5, 5)) + x = Hermitian(randn(T, 5, 5)) - atol = rtol = sqrt(eps(real(T))) - test_forward(f_structured_array, Tret, (x, Tx); atol, rtol) - end - end + atol = rtol = sqrt(eps(real(T))) + test_forward(f_structured_array, Tret, (x, Tx); atol, rtol) + end 
+ end @testset "equivalent arrays in output" begin function f(x) @@ -197,7 +182,7 @@ end atol = rtol = sqrt(eps(real(T))) @test !fails() do test_forward(f_mut_fwd!, Tret, (y, Ty), (x, Tx), (a, Ta); atol, rtol, runtime_activity=true) - end skip = (VERSION < v"1.8" && T <: Complex) + end end end @@ -230,13 +215,7 @@ end atol = rtol = sqrt(eps(real(T))) @test !fails() do test_forward((c, Tc), Tret, (y, Ty); atol, rtol) - end skip = ( - VERSION < v"1.8" && ( - Tret <: BatchDuplicated || - Tc <: BatchDuplicated || - Ty <: BatchDuplicated - ) - ) + end end end end diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index e0eae36e4d8..08cd15facbc 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -497,7 +497,7 @@ end if RT <: Const if needsPrimal if RealRt != fwd_RT - emit_error(B, orig, "Enzyme: incorrect return type of const primal-only forward custom rule - "*(string(RT))*" "*string(activity)*" want just return type "*string(RealRt)*" found "*string(fwd_RT)) + emit_error(B, orig, "Enzyme: incorrect return type of const primal-only forward custom rule - $C "*(string(RT))*" "*string(activity)*" want just return type "*string(RealRt)*" found "*string(fwd_RT)) return false end if get_return_info(RealRt)[2] !== nothing @@ -508,7 +508,7 @@ end end else if Nothing != fwd_RT - emit_error(B, orig, "Enzyme: incorrect return type of const no-primal forward custom rule - "*(string(RT))*" "*string(activity)*" want just return type Nothing found "*string(fwd_RT)) + emit_error(B, orig, "Enzyme: incorrect return type of const no-primal forward custom rule - $C "*(string(RT))*" "*string(activity)*" want just return type Nothing found "*string(fwd_RT)) return false end end @@ -519,7 +519,7 @@ end ST = NTuple{Int(width), ST} end if ST != fwd_RT - emit_error(B, orig, "Enzyme: incorrect return type of shadow-only forward custom rule - "*(string(RT))*" "*string(activity)*" want just shadow type "*string(ST)*" found "*string(fwd_RT)) + emit_error(B, orig, 
"Enzyme: incorrect return type of shadow-only forward custom rule - $C "*(string(RT))*" "*string(activity)*" want just shadow type "*string(ST)*" found "*string(fwd_RT)) return false end if get_return_info(RealRt)[2] !== nothing @@ -539,7 +539,7 @@ end BatchDuplicated{RealRt, Int(width)} end if ST != fwd_RT - emit_error(B, orig, "Enzyme: incorrect return type of prima/shadow forward custom rule - "*(string(RT))*" "*string(activity)*" want just shadow type "*string(ST)*" found "*string(fwd_RT)) + emit_error(B, orig, "Enzyme: incorrect return type of prima/shadow forward custom rule - $C "*(string(RT))*" "*string(activity)*" want just shadow type "*string(ST)*" found "*string(fwd_RT)) return false end if get_return_info(RealRt)[2] !== nothing diff --git a/test/runtests.jl b/test/runtests.jl index d99a28832bd..573140f2c27 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -85,14 +85,13 @@ end include("abi.jl") include("typetree.jl") -@static if Enzyme.EnzymeRules.issupported() - include("rules.jl") - include("rrules.jl") - include("kwrules.jl") - include("kwrrules.jl") - include("internal_rules.jl") - include("ruleinvalidation.jl") -end +include("rules.jl") +include("rrules.jl") +include("kwrules.jl") +include("kwrrules.jl") +include("internal_rules.jl") +include("ruleinvalidation.jl") + @static if !Sys.iswindows() include("blas.jl") end diff --git a/test/threads.jl b/test/threads.jl index 6899d8d2d66..9a06869c884 100644 --- a/test/threads.jl +++ b/test/threads.jl @@ -73,14 +73,9 @@ end out = [1.0, 2.0] dout = [1.0, 1.0] -@static if VERSION < v"1.8" - # GPUCompiler causes a stack overflow due to https://github.com/JuliaGPU/GPUCompiler.jl/issues/587 - # @test_throws AssertionError autodiff(Reverse, f_multi, Const, Duplicated(out, dout), Active(2.0)) -else res = autodiff(Reverse, f_multi, Const, Duplicated(out, dout), Active(2.0)) @test res[1][2] ≈ 2.0 end -end @testset "Closure-less threads $(Threads.nthreads())" begin function bf(i, x) From 
0d6fe67ff400218d24d8c4aee9591852d8f90710 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 21 Sep 2024 14:48:31 -0500 Subject: [PATCH 43/87] fix (#1877) --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index be4679d263f..07cbdeb65a1 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5746,7 +5746,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end end - if !(haskey(functions(mod), k_name) || has_custom_rule) + if !haskey(functions(mod), k_name) continue end From 29ed385d498a70b8d41da9d88a366707c7263388 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 21 Sep 2024 15:38:09 -0500 Subject: [PATCH 44/87] Try fixing buildkite (#1843) * Try fixing buildkite * Update pipeline.yml * Update pipeline.yml * Update pipeline.yml --- .buildkite/pipeline.yml | 67 ++++++++++++++++++++--------------------- 1 file changed, 33 insertions(+), 34 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 6de936a558b..1a9f70d04c6 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -15,10 +15,10 @@ steps: commands: | echo "--- Setup Julia packages" julia --color=yes -e ' - import Pkg - Pkg.develop(; path = pwd()) - Pkg.develop(; path = joinpath(pwd(), "lib", "EnzymeCore")) - Pkg.develop(; name = "CUDA")' || exit 3 + using Pkg + pkgs = [PackageSpec(; path) for path in (".", "lib/EnzymeCore", "lib/EnzymeTestUtils")] + push!(pkgs, PackageSpec(; name="CUDA")) + Pkg.develop(pkgs)' || exit 3 echo "+++ Run tests" julia --color=yes test/cuda.jl @@ -41,40 +41,39 @@ steps: commands: | echo "--- Setup Julia packages" julia --color=yes -e ' - import Pkg - Pkg.develop(; path = pwd()) - Pkg.develop(; path = joinpath(pwd(), "lib", "EnzymeCore")) - Pkg.develop(; name = "AMDGPU")' || exit 3 + using Pkg + pkgs = [PackageSpec(; path) for path in (".", "lib/EnzymeCore", "lib/EnzymeTestUtils")] + push!(pkgs, PackageSpec(; name="AMDGPU")) + 
Pkg.develop(pkgs)' || exit 3 echo "+++ Run tests" julia --color=yes test/amdgpu.jl env: JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager - # - label: "Metal Julia v{{matrix.version}}" - # matrix: - # setup: - # version: - # - "1.8" - # - "1.9" - # plugins: - # - JuliaCI/julia#v1: - # version: "{{matrix.version}}" - # agents: - # queue: "juliaecosystem" - # os: "macos" - # arch: "aarch64" - # if: build.message !~ /\[skip tests\]/ - # timeout_in_minutes: 60 - # commands: | - # echo "--- Setup Julia packages" - # julia --color=yes -e ' - # import Pkg - # Pkg.develop(; path = pwd()) - # Pkg.develop(; path = joinpath(pwd(), "lib", "EnzymeCore")) - # Pkg.develop(; name = "Metal")' || exit 3 + - label: "Metal Julia v{{matrix.version}}" + matrix: + setup: + version: + - "1.10" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.version}}" + agents: + queue: "juliaecosystem" + os: "macos" + arch: "aarch64" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + commands: | + echo "--- Setup Julia packages" + julia --color=yes -e ' + using Pkg + pkgs = [PackageSpec(; path) for path in (".", "lib/EnzymeCore", "lib/EnzymeTestUtils")] + push!(pkgs, PackageSpec(; name="Metal")) + Pkg.develop(pkgs)' || exit 3 - # echo "+++ Run tests" - # julia --color=yes test/metal.jl - # env: - # JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager + echo "+++ Run tests" + julia --color=yes test/metal.jl + env: + JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager From a08b81033b7f1a335807d56ef3cd2251781b9ad5 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Sat, 21 Sep 2024 16:40:58 -0400 Subject: [PATCH 45/87] Remove deprecated UnionAll Vararg (#1859) * Remove deprecated UnionAll Vararg * Replace remaining uses of Vararg in docstrings `...` is better understood by more users and easier on the eyes --- src/Enzyme.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index a439cf430cf..7a864aa51a8 100644 --- a/src/Enzyme.jl +++ 
b/src/Enzyme.jl @@ -170,7 +170,7 @@ end end """ - autodiff(::ReverseMode, f, Activity, args::Vararg{<:Annotation, Nargs}) + autodiff(::ReverseMode, f, Activity, args::Annotation...) Auto-differentiate function `f` at arguments `args` using reverse mode. @@ -317,7 +317,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) end """ - autodiff(mode::Mode, f, ::Type{A}, args::Vararg{Annotation, Nargs}) + autodiff(mode::Mode, f, ::Type{A}, args::Annotation...) Like [`autodiff`](@ref) but will try to extend f to an annotation, if needed. """ @@ -345,7 +345,7 @@ Like [`autodiff`](@ref) but will try to guess the activity of the return value. end """ - autodiff(::ForwardMode, f, Activity, args::Vararg{<:Annotation, Nargs}) + autodiff(::ForwardMode, f, Activity, args::Annotation...) Auto-differentiate function `f` at arguments `args` using forward mode. @@ -431,7 +431,7 @@ f(x) = x*x end """ - autodiff_deferred(::ReverseMode, f, Activity, args::Vararg{<:Annotation, Nargs}) + autodiff_deferred(::ReverseMode, f, Activity, args::Annotation...) Same as [`autodiff`](@ref) but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. @@ -472,9 +472,9 @@ code, as well as high-order differentiation. end """ - autodiff_deferred(::ForwardMode, f, Activity, args::Vararg{<:Annotation, Nargs}) + autodiff_deferred(::ForwardMode, f, Activity, args::Annotation...) -Same as `autodiff(::ForwardMode, f, Activity, args)` but uses deferred compilation to support usage in GPU +Same as `autodiff(::ForwardMode, f, Activity, args...)` but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. """ @inline function autodiff_deferred(::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten, RuntimeActivity}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {ReturnPrimal, FA<:Annotation, A<:Annotation, Nargs, ABI, ErrIfFuncWritten, RuntimeActivity} @@ -532,7 +532,7 @@ code, as well as high-order differentiation. 
end """ - autodiff_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Vararg{Type{<:Annotation, Nargs}) + autodiff_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Type{<:Annotation}...) Provide the split forward and reverse pass functions for annotated function type ftype when called with args of type `argtypes` when using reverse mode. @@ -628,7 +628,7 @@ end end """ - autodiff_thunk(::ForwardMode, ftype, Activity, argtypes::Vararg{Type{<:Annotation}, Nargs}) + autodiff_thunk(::ForwardMode, ftype, Activity, argtypes::Type{<:Annotation}...) Provide the thunk forward mode function for annotated function type ftype when called with args of type `argtypes`. @@ -798,7 +798,7 @@ import .Compiler: fspec, remove_innerty, UnknownTapeType end """ - autodiff_deferred_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Vararg{Type{<:Annotation}, Nargs}) + autodiff_deferred_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Type{<:Annotation}...) Provide the split forward and reverse pass functions for annotated function type ftype when called with args of type `argtypes` when using reverse mode. 
@@ -1067,7 +1067,7 @@ grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0])) ``` """ -@generated function gradient(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, f::F, x::ty_0, args::Vararg{<:Any, N}) where {F, ty_0, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten, N} +@generated function gradient(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, f::F, x::ty_0, args::Vararg{Any, N}) where {F, ty_0, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten, N} toemit= Expr[quote act_0 = !(x isa Enzyme.Const) && Compiler.active_reg_inner(Core.Typeof(x), #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState end] From 0c36c5af6e7ed02a25e5bb7485d6477f97ca7eed Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 21 Sep 2024 21:16:38 -0500 Subject: [PATCH 46/87] Use correct triple (#1878) * Use correct triple * fix * fix * fix --- src/compiler.jl | 4 +-- src/compiler/orcv2.jl | 63 +++++-------------------------------------- 2 files changed, 9 insertions(+), 58 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 07cbdeb65a1..ce51a6e7f53 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3293,9 +3293,9 @@ end # Define EnzymeTarget Base.@kwdef struct EnzymeTarget <: AbstractCompilerTarget end -GPUCompiler.llvm_triple(::EnzymeTarget) = Sys.MACHINE -# GPUCompiler.llvm_datalayout(::EnzymeTarget) = nothing +GPUCompiler.llvm_triple(::EnzymeTarget) = LLVM.triple(JIT.get_jit()) +GPUCompiler.llvm_datalayout(::EnzymeTarget) = LLVM.datalayout(JIT.get_jit()) function GPUCompiler.llvm_machine(::EnzymeTarget) return JIT.get_tm() diff --git a/src/compiler/orcv2.jl b/src/compiler/orcv2.jl index 78ff089e7d3..482a961b52d 100644 --- a/src/compiler/orcv2.jl +++ b/src/compiler/orcv2.jl @@ -9,24 +9,12 @@ import GPUCompiler import ..Compiler import ..Compiler: API, cpu_name, cpu_features -@inline function use_ojit() - return !Sys.iswindows() -end - export 
get_trampoline -@static if use_ojit() - struct CompilerInstance - jit::LLVM.JuliaOJIT - lctm::Union{LLVM.LazyCallThroughManager, Nothing} - ism::Union{LLVM.IndirectStubsManager, Nothing} - end -else - struct CompilerInstance - jit::LLVM.LLJIT - lctm::Union{LLVM.LazyCallThroughManager, Nothing} - ism::Union{LLVM.IndirectStubsManager, Nothing} - end +struct CompilerInstance + jit::LLVM.JuliaOJIT + lctm::Union{LLVM.LazyCallThroughManager, Nothing} + ism::Union{LLVM.IndirectStubsManager, Nothing} end function LLVM.dispose(ci::CompilerInstance) @@ -44,6 +32,7 @@ const jit = Ref{CompilerInstance}() const tm = Ref{TargetMachine}() # for opt pipeline get_tm() = tm[] +get_jit() = jit[].jit function absolute_symbol_materialization(name, ptr) address = LLVM.API.LLVMOrcJITTargetAddress(reinterpret(UInt, ptr)) @@ -80,37 +69,7 @@ function __init__() LLVM.asm_verbosity!(tempTM, true) tm[] = tempTM - lljit = @static if !use_ojit() - tempTM = LLVM.JITTargetMachine(LLVM.triple(), cpu_name(), cpu_features(); optlevel) - LLVM.asm_verbosity!(tempTM, true) - - gdb = haskey(ENV, "ENABLE_GDBLISTENER") - perf = haskey(ENV, "ENABLE_JITPROFILING") - if gdb || perf - ollc = LLVM.ObjectLinkingLayerCreator() do es, triple - oll = ObjectLinkingLayer(es) - if gdb - register!(oll, GDBRegistrationListener()) - end - if perf - register!(oll, IntelJITEventListener()) - register!(oll, PerfJITEventListener()) - end - return oll - end - GC.@preserve ollc begin - builder = LLJITBuilder() - LLVM.linkinglayercreator!(builder, ollc) - tmb = TargetMachineBuilder(tempTM) - LLVM.targetmachinebuilder!(builder, tmb) - LLJIT(builder) - end - else - LLJIT(;tm=tempTM) - end - else - JuliaOJIT() - end + lljit = JuliaOJIT() jd_main = JITDylib(lljit) @@ -145,10 +104,6 @@ function __init__() end atexit() do - @static if !use_ojit() - ci = jit[] - dispose(ci) - end dispose(tm[]) end end @@ -229,11 +184,7 @@ function get_trampoline(job) tsm = move_to_threadsafe(mod) - il = @static if use_ojit() - 
LLVM.IRCompileLayer(lljit) - else - LLVM.IRTransformLayer(lljit) - end + il = LLVM.IRCompileLayer(lljit) LLVM.emit(il, mr, tsm) end return nothing From 6bfe8e0bf09e4bba62da490e073a95105ceed20a Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 23 Sep 2024 22:36:50 -0500 Subject: [PATCH 47/87] Cleanup absint (#1880) * Cleanup absint * cleanup * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * Update Project.toml * fix * Update runtests.jl * Update runtests.jl --- Project.toml | 2 +- src/Enzyme.jl | 1126 +++++-- src/absint.jl | 365 ++- src/api.jl | 1171 +++++-- src/compiler.jl | 5296 +++++++++++++++++++++++--------- src/compiler/interpreter.jl | 210 +- src/compiler/optimize.jl | 1304 +++++--- src/compiler/orcv2.jl | 89 +- src/compiler/passes.jl | 2 +- src/compiler/reflection.jl | 107 +- src/compiler/utils.jl | 169 +- src/compiler/validation.jl | 864 ++++-- src/gradientutils.jl | 66 +- src/internal_rules.jl | 590 ++-- src/pmap.jl | 66 +- src/rules/activityrules.jl | 70 +- src/rules/allocrules.jl | 105 +- src/rules/customrules.jl | 815 +++-- src/rules/jitrules.jl | 1476 +++++++-- src/rules/llvmrules.jl | 1082 +++++-- src/rules/parallelrules.jl | 359 ++- src/rules/typerules.jl | 18 +- src/rules/typeunstablerules.jl | 981 ++++-- src/typeanalysis.jl | 7 +- src/typetree.jl | 104 +- src/utils.jl | 64 +- test/runtests.jl | 11 + test/typetree.jl | 13 + 28 files changed, 12040 insertions(+), 4492 deletions(-) diff --git a/Project.toml b/Project.toml index 3c93057f906..5a0e192de5f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.2" +version = "0.13.3" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 7a864aa51a8..c99114e038f 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -2,11 +2,68 @@ module Enzyme import EnzymeCore 
-import EnzymeCore: Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal -export Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal - -import EnzymeCore: Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, ABI, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi, set_runtime_activity, clear_runtime_activity, within_autodiff -export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi, set_runtime_activity, clear_runtime_activity, within_autodiff +import EnzymeCore: + Forward, + ForwardWithPrimal, + Reverse, + ReverseWithPrimal, + ReverseSplitNoPrimal, + ReverseSplitWithPrimal, + ReverseSplitModified, + ReverseSplitWidth, + ReverseMode, + ForwardMode, + ReverseHolomorphic, + ReverseHolomorphicWithPrimal +export Forward, + ForwardWithPrimal, + Reverse, + ReverseWithPrimal, + ReverseSplitNoPrimal, + ReverseSplitWithPrimal, + ReverseSplitModified, + ReverseSplitWidth, + ReverseMode, + ForwardMode, + ReverseHolomorphic, + ReverseHolomorphicWithPrimal + +import EnzymeCore: + Annotation, + Const, + Active, + Duplicated, + DuplicatedNoNeed, + BatchDuplicated, + BatchDuplicatedNoNeed, + ABI, + DefaultABI, + FFIABI, + InlineABI, + NonGenABI, + set_err_if_func_written, + clear_err_if_func_written, + set_abi, + set_runtime_activity, + clear_runtime_activity, + within_autodiff +export Annotation, + Const, + Active, + Duplicated, + DuplicatedNoNeed, + BatchDuplicated, + BatchDuplicatedNoNeed, + DefaultABI, + FFIABI, 
+ InlineABI, + NonGenABI, + set_err_if_func_written, + clear_err_if_func_written, + set_abi, + set_runtime_activity, + clear_runtime_activity, + within_autodiff import EnzymeCore: BatchDuplicatedFunc export BatchDuplicatedFunc @@ -14,11 +71,24 @@ export BatchDuplicatedFunc import EnzymeCore: MixedDuplicated, BatchMixedDuplicated export MixedDuplicated, BatchMixedDuplicated -import EnzymeCore: batch_size, get_func +import EnzymeCore: batch_size, get_func export batch_size, get_func -import EnzymeCore: autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero, make_zero! -export autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero, make_zero! +import EnzymeCore: + autodiff, + autodiff_deferred, + autodiff_thunk, + autodiff_deferred_thunk, + tape_type, + make_zero, + make_zero! +export autodiff, + autodiff_deferred, + autodiff_thunk, + autodiff_deferred_thunk, + tape_type, + make_zero, + make_zero! export jacobian, gradient, gradient!, hvp, hvp!, hvp_and_gradient! 
export markType, batch_size, onehot, chunkedonehot @@ -58,7 +128,7 @@ import .Compiler: CompilationException end end -@inline function any_active(args::Vararg{Annotation, N}) where N +@inline function any_active(args::Vararg{Annotation,N}) where {N} any(ntuple(Val(N)) do i Base.@_inline_meta arg = @inbounds args[i] @@ -74,18 +144,22 @@ end end) end -@inline function vaTypeof(args::Vararg{Any, N}) where N - return Tuple{(ntuple(Val(N)) do i - Base.@_inline_meta - Core.Typeof(args[i]) - end)...} +@inline function vaTypeof(args::Vararg{Any,N}) where {N} + return Tuple{( + ntuple(Val(N)) do i + Base.@_inline_meta + Core.Typeof(args[i]) + end + )...} end -@inline function vaEltypes(args::Type{Ty}) where {Ty <: Tuple} - return Tuple{(ntuple(Val(length(Ty.parameters))) do i - Base.@_inline_meta - eltype(Ty.parameters[i]) - end)...} +@inline function vaEltypes(args::Type{Ty}) where {Ty<:Tuple} + return Tuple{( + ntuple(Val(length(Ty.parameters))) do i + Base.@_inline_meta + eltype(Ty.parameters[i]) + end + )...} end @inline function same_or_one_helper(current, next) @@ -99,22 +173,28 @@ end end @inline same_or_one_rec(current) = current -@inline same_or_one_rec(current, arg::BatchMixedDuplicated{T, N}, args...) where {T,N} = - same_or_one_rec(same_or_one_helper(current, N), args...) -@inline same_or_one_rec(current, arg::Type{BatchMixedDuplicated{T, N}}, args...) where {T,N} = - same_or_one_rec(same_or_one_helper(current, N), args...) -@inline same_or_one_rec(current, arg::BatchDuplicatedFunc{T, N}, args...) where {T,N} = - same_or_one_rec(same_or_one_helper(current, N), args...) -@inline same_or_one_rec(current, arg::Type{BatchDuplicatedFunc{T, N}}, args...) where {T,N} = - same_or_one_rec(same_or_one_helper(current, N), args...) -@inline same_or_one_rec(current, arg::BatchDuplicated{T, N}, args...) where {T,N} = - same_or_one_rec(same_or_one_helper(current, N), args...) -@inline same_or_one_rec(current, arg::Type{BatchDuplicated{T, N}}, args...) 
where {T,N} = - same_or_one_rec(same_or_one_helper(current, N), args...) -@inline same_or_one_rec(current, arg::BatchDuplicatedNoNeed{T, N}, args...) where {T,N} = - same_or_one_rec(same_or_one_helper(current, N), args...) -@inline same_or_one_rec(current, arg::Type{BatchDuplicatedNoNeed{T, N}}, args...) where {T,N} = - same_or_one_rec(same_or_one_helper(current, N), args...) +@inline same_or_one_rec(current, arg::BatchMixedDuplicated{T,N}, args...) where {T,N} = + same_or_one_rec(same_or_one_helper(current, N), args...) +@inline same_or_one_rec( + current, + arg::Type{BatchMixedDuplicated{T,N}}, + args..., +) where {T,N} = same_or_one_rec(same_or_one_helper(current, N), args...) +@inline same_or_one_rec(current, arg::BatchDuplicatedFunc{T,N}, args...) where {T,N} = + same_or_one_rec(same_or_one_helper(current, N), args...) +@inline same_or_one_rec(current, arg::Type{BatchDuplicatedFunc{T,N}}, args...) where {T,N} = + same_or_one_rec(same_or_one_helper(current, N), args...) +@inline same_or_one_rec(current, arg::BatchDuplicated{T,N}, args...) where {T,N} = + same_or_one_rec(same_or_one_helper(current, N), args...) +@inline same_or_one_rec(current, arg::Type{BatchDuplicated{T,N}}, args...) where {T,N} = + same_or_one_rec(same_or_one_helper(current, N), args...) +@inline same_or_one_rec(current, arg::BatchDuplicatedNoNeed{T,N}, args...) where {T,N} = + same_or_one_rec(same_or_one_helper(current, N), args...) +@inline same_or_one_rec( + current, + arg::Type{BatchDuplicatedNoNeed{T,N}}, + args..., +) where {T,N} = same_or_one_rec(same_or_one_helper(current, N), args...) @inline same_or_one_rec(current, arg, args...) = same_or_one_rec(current, args...) @inline function same_or_one(defaultVal, args...) 
@@ -127,7 +207,7 @@ end end -@inline function refn_seed(x::T) where T +@inline function refn_seed(x::T) where {T} if T <: Complex return conj(x) / 2 else @@ -135,7 +215,7 @@ end end end -@inline function imfn_seed(x::T) where T +@inline function imfn_seed(x::T) where {T} if T <: Complex return im * conj(x) / 2 else @@ -143,7 +223,11 @@ end end end -@inline function seed_complex_args(seen, seen2, args::Vararg{Annotation, Nargs}) where {Nargs} +@inline function seed_complex_args( + seen, + seen2, + args::Vararg{Annotation,Nargs}, +) where {Nargs} return ntuple(Val(Nargs)) do i Base.@_inline_meta arg = args[i] @@ -151,18 +235,29 @@ end arg elseif arg isa Duplicated || arg isa DuplicatedNoNeed RT = eltype(Core.Typeof(arg)) - BatchDuplicated(arg.val, (arg.dval, make_zero(RT, seen, arg.dval), make_zero(RT, seen2, arg.dval))) + BatchDuplicated( + arg.val, + (arg.dval, make_zero(RT, seen, arg.dval), make_zero(RT, seen2, arg.dval)), + ) else - throw(ErrorException("Active Complex return does not yet support batching in combined reverse mode")) + throw( + ErrorException( + "Active Complex return does not yet support batching in combined reverse mode", + ), + ) end end end -@inline function fuse_complex_results(results, args::Vararg{Annotation, Nargs}) where {Nargs} +@inline function fuse_complex_results(results, args::Vararg{Annotation,Nargs}) where {Nargs} ntuple(Val(Nargs)) do i Base.@_inline_meta if args[i] isa Active - Compiler.recursive_add(Compiler.recursive_add(results[1][i][1], results[1][i][2], refn_seed), results[1][i][3], imfn_seed) + Compiler.recursive_add( + Compiler.recursive_add(results[1][i][1], results[1][i][2], refn_seed), + results[1][i][3], + imfn_seed, + ) else results[1][i] end @@ -229,16 +324,30 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) [`Active`](@ref) will automatically convert plain integers to floating point values, but cannot do so for integer values in tuples and structs. 
""" -@inline function autodiff(rmode::ReverseMode{ReturnPrimal, RuntimeActivity,RABI,Holomorphic, ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, RuntimeActivity, RABI<:ABI,Holomorphic, Nargs, ErrIfFuncWritten} - tt′ = vaTypeof(args...) +@inline function autodiff( + rmode::ReverseMode{ReturnPrimal,RuntimeActivity,RABI,Holomorphic,ErrIfFuncWritten}, + f::FA, + ::Type{A}, + args::Vararg{Annotation,Nargs}, +) where { + FA<:Annotation, + A<:Annotation, + ReturnPrimal, + RuntimeActivity, + RABI<:ABI, + Holomorphic, + Nargs, + ErrIfFuncWritten, +} + tt′ = vaTypeof(args...) width = same_or_one(1, args...) if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end - ModifiedBetween = Val(falses_from_args(Nargs+1)) + ModifiedBetween = Val(falses_from_args(Nargs + 1)) - tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} + tt = Tuple{map(T -> eltype(Core.Typeof(T)), args)...} FTy = Core.Typeof(f.val) @@ -251,12 +360,25 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) rt = if A isa UnionAll Compiler.primal_return_type(rmode, Val(codegen_world_age(FTy, tt)), FTy, tt) else - eltype(A) + eltype(A) end if A <: Active if (!allocatedinline(rt) || rt isa Union) && rt != Union{} - forward, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + forward, adjoint = Enzyme.Compiler.thunk( + opt_mi, + FA, + Duplicated{rt}, + tt′, + Val(API.DEM_ReverseModeGradient), + Val(width), + ModifiedBetween, + Val(ReturnPrimal), + Val(true), + RABI, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# res = forward(f, args...) 
tape = res[1] if ReturnPrimal @@ -265,7 +387,11 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) return adjoint(f, args..., tape) end end - elseif A <: Duplicated || A<: DuplicatedNoNeed || A <: BatchDuplicated || A<: BatchDuplicatedNoNeed || A <: BatchDuplicatedFunc + elseif A <: Duplicated || + A <: DuplicatedNoNeed || + A <: BatchDuplicated || + A <: BatchDuplicatedNoNeed || + A <: BatchDuplicatedFunc throw(ErrorException("Duplicated Returns not yet handled")) end @@ -277,16 +403,40 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) f = if f isa Const || f isa Active f elseif f isa Duplicated || f isa DuplicatedNoNeed - BatchDuplicated(f.val, (f.dval, make_zero(typeof(f), seen, f.dval), make_zero(typeof(f), seen2, f.dval))) + BatchDuplicated( + f.val, + ( + f.dval, + make_zero(typeof(f), seen, f.dval), + make_zero(typeof(f), seen2, f.dval), + ), + ) else - throw(ErrorException("Active Complex return does not yet support batching in combined reverse mode")) + throw( + ErrorException( + "Active Complex return does not yet support batching in combined reverse mode", + ), + ) end width = same_or_one(3, args...) args = seed_complex_args(seen, seen2, args...) - tt′ = vaTypeof(args...) - - thunk = Enzyme.Compiler.thunk(opt_mi, typeof(f), A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + tt′ = vaTypeof(args...) + + thunk = Enzyme.Compiler.thunk( + opt_mi, + typeof(f), + A, + tt′, + Val(API.DEM_ReverseModeCombined), + Val(width), + ModifiedBetween, + Val(ReturnPrimal), + Val(false), + RABI, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# results = thunk(f, args..., (rt(0), rt(1), rt(im))) @@ -305,10 +455,27 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) return (fused, results[2:end]...) 
end - throw(ErrorException("Reverse-mode Active Complex return is ambiguous and requires more information to specify the desired result. See https://enzyme.mit.edu/julia/stable/faq/#Complex-numbers for more details.")) - end - - thunk = Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + throw( + ErrorException( + "Reverse-mode Active Complex return is ambiguous and requires more information to specify the desired result. See https://enzyme.mit.edu/julia/stable/faq/#Complex-numbers for more details.", + ), + ) + end + + thunk = Enzyme.Compiler.thunk( + opt_mi, + FA, + A, + tt′, + Val(API.DEM_ReverseModeCombined), + Val(width), + ModifiedBetween, + Val(ReturnPrimal), + Val(false), + RABI, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# if A <: Active args = (args..., Compiler.default_adjoint(rt)) @@ -321,10 +488,19 @@ end Like [`autodiff`](@ref) but will try to extend f to an annotation, if needed. """ -@inline function autodiff(mode::CMode, f::F, args::Vararg{Annotation, Nargs}) where {F, CMode<:Mode, Nargs} +@inline function autodiff( + mode::CMode, + f::F, + args::Vararg{Annotation,Nargs}, +) where {F,CMode<:Mode,Nargs} autodiff(EnzymeCore.set_err_if_func_written(mode), Const(f), args...) end -@inline function autodiff(mode::CMode, f::F, ::Type{RT}, args::Vararg{Annotation, Nargs}) where {F, RT<:Annotation, CMode<:Mode, Nargs} +@inline function autodiff( + mode::CMode, + f::F, + ::Type{RT}, + args::Vararg{Annotation,Nargs}, +) where {F,RT<:Annotation,CMode<:Mode,Nargs} autodiff(EnzymeCore.set_err_if_func_written(mode), Const(f), RT, args...) end @@ -333,14 +509,23 @@ end Like [`autodiff`](@ref) but will try to guess the activity of the return value. 
""" -@inline function autodiff(mode::CMode, f::FA, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, CMode<:Mode, Nargs} - tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} - rt = if mode isa ReverseMode - Compiler.primal_return_type(mode, Val(codegen_world_age(eltype(FA), tt)), eltype(FA), tt) +@inline function autodiff( + mode::CMode, + f::FA, + args::Vararg{Annotation,Nargs}, +) where {FA<:Annotation,CMode<:Mode,Nargs} + tt = Tuple{map(T -> eltype(Core.Typeof(T)), args)...} + rt = if mode isa ReverseMode + Compiler.primal_return_type( + mode, + Val(codegen_world_age(eltype(FA), tt)), + eltype(FA), + tt, + ) else Core.Compiler.return_type(f.val, tt) end - A = guess_activity(rt, mode) + A = guess_activity(rt, mode) autodiff(mode, f, A, args...) end @@ -384,11 +569,19 @@ f(x) = x*x (6.28,) ``` """ -@inline function autodiff(::ForwardMode{ReturnPrimal, RABI, ErrIfFuncWritten, RuntimeActivity}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation} where {ReturnPrimal, RABI <: ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} +@inline function autodiff( + ::ForwardMode{ReturnPrimal,RABI,ErrIfFuncWritten,RuntimeActivity}, + f::FA, + ::Type{A}, + args::Vararg{Annotation,Nargs}, +) where { + FA<:Annotation, + A<:Annotation, +} where {ReturnPrimal,RABI<:ABI,Nargs,ErrIfFuncWritten,RuntimeActivity} if any_active(args...) throw(ErrorException("Active arguments not allowed in forward mode")) end - tt′ = vaTypeof(args...) + tt′ = vaTypeof(args...) width = same_or_one(1, args...) if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) @@ -397,27 +590,31 @@ f(x) = x*x throw(ErrorException("Active Returns not allowed in forward mode")) end if A <: DuplicatedNoNeed || A <: BatchDuplicatedNoNeed - throw(ErrorException("Return activity `DuplicatedNoNeed` is no longer now returning or avoiding the primal is passed in for Forward Mode AD.\nPlease use autodiff(Forward, ...) 
or autodiff(ForwardWithPrimal, ...)")) + throw( + ErrorException( + "Return activity `DuplicatedNoNeed` is no longer now returning or avoiding the primal is passed in for Forward Mode AD.\nPlease use autodiff(Forward, ...) or autodiff(ForwardWithPrimal, ...)", + ), + ) end RT = if A <: Duplicated && width != 1 if A isa UnionAll - BatchDuplicated{T, width} where T + BatchDuplicated{T,width} where {T} else - BatchDuplicated{eltype(A), width} + BatchDuplicated{eltype(A),width} end elseif A <: DuplicatedNoNeed && width != 1 if A isa UnionAll - BatchDuplicatedNoNeed{T, width} where T + BatchDuplicatedNoNeed{T,width} where {T} else - BatchDuplicatedNoNeed{eltype(A), width} + BatchDuplicatedNoNeed{eltype(A),width} end else A end - - ModifiedBetween = Val(falses_from_args(Nargs+1)) - - tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} + + ModifiedBetween = Val(falses_from_args(Nargs + 1)) + + tt = Tuple{map(T -> eltype(Core.Typeof(T)), args)...} opt_mi = if RABI <: NonGenABI Compiler.fspec(eltype(FA), tt′) @@ -425,8 +622,20 @@ f(x) = x*x Val(codegen_world_age(Core.Typeof(f.val), tt)) end - thunk = Enzyme.Compiler.thunk(opt_mi, FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), - ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + thunk = Enzyme.Compiler.thunk( + opt_mi, + FA, + RT, + tt′, + Val(API.DEM_ForwardMode), + Val(width), #=Mode=# + ModifiedBetween, + Val(ReturnPrimal), + Val(false), + RABI, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# thunk(f, args...) end @@ -436,16 +645,30 @@ end Same as [`autodiff`](@ref) but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. 
""" -@inline function autodiff_deferred(::ReverseMode{ReturnPrimal, RuntimeActivity, ABI,Holomorphic,ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, Nargs, ABI,Holomorphic,ErrIfFuncWritten, RuntimeActivity} - tt′ = vaTypeof(args...) +@inline function autodiff_deferred( + ::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, + f::FA, + ::Type{A}, + args::Vararg{Annotation,Nargs}, +) where { + FA<:Annotation, + A<:Annotation, + ReturnPrimal, + Nargs, + ABI, + Holomorphic, + ErrIfFuncWritten, + RuntimeActivity, +} + tt′ = vaTypeof(args...) width = same_or_one(1, args...) if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end - tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} - + tt = Tuple{map(T -> eltype(Core.Typeof(T)), args)...} + world = codegen_world_age(Core.Typeof(f.val), tt) - + if A isa UnionAll rt = Core.Compiler.return_type(f.val, tt) rt = A{rt} @@ -458,14 +681,31 @@ code, as well as high-order differentiation. error("Return type inferred to be Union{}. 
Giving up.") end - ModifiedBetween = Val(falses_from_args(Nargs+1)) - - adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), UnknownTapeType, Val(ErrIfFuncWritten), Val(RuntimeActivity)) - - thunk = Compiler.CombinedAdjointThunk{Ptr{Cvoid}, FA, rt, tt′, width, ReturnPrimal}(adjoint_ptr) + ModifiedBetween = Val(falses_from_args(Nargs + 1)) + + adjoint_ptr = Compiler.deferred_codegen( + Val(world), + FA, + Val(tt′), + Val(rt), + Val(API.DEM_ReverseModeCombined), + Val(width), + ModifiedBetween, + Val(ReturnPrimal), + Val(false), + UnknownTapeType, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# + + thunk = + Compiler.CombinedAdjointThunk{Ptr{Cvoid},FA,rt,tt′,width,ReturnPrimal}(adjoint_ptr) if rt <: Active args = (args..., Compiler.default_adjoint(eltype(rt))) - elseif A <: Duplicated || A<: DuplicatedNoNeed || A <: BatchDuplicated || A<: BatchDuplicatedNoNeed + elseif A <: Duplicated || + A <: DuplicatedNoNeed || + A <: BatchDuplicated || + A <: BatchDuplicatedNoNeed throw(ErrorException("Duplicated Returns not yet handled")) end thunk(f, args...) @@ -477,37 +717,54 @@ end Same as `autodiff(::ForwardMode, f, Activity, args...)` but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. """ -@inline function autodiff_deferred(::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten, RuntimeActivity}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {ReturnPrimal, FA<:Annotation, A<:Annotation, Nargs, ABI, ErrIfFuncWritten, RuntimeActivity} +@inline function autodiff_deferred( + ::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}, + f::FA, + ::Type{A}, + args::Vararg{Annotation,Nargs}, +) where { + ReturnPrimal, + FA<:Annotation, + A<:Annotation, + Nargs, + ABI, + ErrIfFuncWritten, + RuntimeActivity, +} if any_active(args...) 
throw(ErrorException("Active arguments not allowed in forward mode")) end - tt′ = vaTypeof(args...) + tt′ = vaTypeof(args...) width = same_or_one(1, args...) if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end if A <: DuplicatedNoNeed || A <: BatchDuplicatedNoNeed - throw(ErrorException("Return activity `DuplicatedNoNeed` is no longer now returning or avoiding the primal is passed in for Forward Mode AD.\nPlease use autodiff(Forward, ...) or autodiff(ForwardWithPrimal, ...)")) + throw( + ErrorException( + "Return activity `DuplicatedNoNeed` is no longer now returning or avoiding the primal is passed in for Forward Mode AD.\nPlease use autodiff(Forward, ...) or autodiff(ForwardWithPrimal, ...)", + ), + ) end RT = if A <: Duplicated && width != 1 if A isa UnionAll - BatchDuplicated{T, width} where T + BatchDuplicated{T,width} where {T} else - BatchDuplicated{eltype(A), width} + BatchDuplicated{eltype(A),width} end elseif A <: DuplicatedNoNeed && width != 1 if A isa UnionAll - BatchDuplicatedNoNeed{T, width} where T + BatchDuplicatedNoNeed{T,width} where {T} else - BatchDuplicatedNoNeed{eltype(A), width} + BatchDuplicatedNoNeed{eltype(A),width} end else A end - tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} - + tt = Tuple{map(T -> eltype(Core.Typeof(T)), args)...} + world = codegen_world_age(Core.Typeof(f.val), tt) - + if RT isa UnionAll rt = Core.Compiler.return_type(f.val, tt) rt = RT{rt} @@ -524,10 +781,23 @@ code, as well as high-order differentiation. 
throw(ErrorException("Active Returns not allowed in forward mode")) end - ModifiedBetween = Val(falses_from_args(Nargs+1)) - - adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), UnknownTapeType, Val(ErrIfFuncWritten), Val(RuntimeActivity)) - thunk = Compiler.ForwardModeThunk{Ptr{Cvoid}, FA, rt, tt′, width, ReturnPrimal}(adjoint_ptr) + ModifiedBetween = Val(falses_from_args(Nargs + 1)) + + adjoint_ptr = Compiler.deferred_codegen( + Val(world), + FA, + Val(tt′), + Val(rt), + Val(API.DEM_ForwardMode), + Val(width), + ModifiedBetween, + Val(ReturnPrimal), + Val(false), + UnknownTapeType, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# + thunk = Compiler.ForwardModeThunk{Ptr{Cvoid},FA,rt,tt′,width,ReturnPrimal}(adjoint_ptr) thunk(f, args...) end @@ -574,7 +844,31 @@ result, ∂v, ∂A (7.26, 2.2, [3.3]) ``` """ -@inline function autodiff_thunk(rs::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,ModifiedBetweenT,RABI, ErrIfFuncWritten}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI<:ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} +@inline function autodiff_thunk( + rs::ReverseModeSplit{ + ReturnPrimal, + ReturnShadow, + RuntimeActivity, + Width, + ModifiedBetweenT, + RABI, + ErrIfFuncWritten, + }, + ::Type{FA}, + ::Type{A}, + args::Vararg{Type{<:Annotation},Nargs}, +) where { + FA<:Annotation, + A<:Annotation, + ReturnPrimal, + ReturnShadow, + Width, + ModifiedBetweenT, + RABI<:ABI, + Nargs, + ErrIfFuncWritten, + RuntimeActivity, +} width = if Width == 0 w = same_or_one(1, args...) 
if w == 0 @@ -586,13 +880,13 @@ result, ∂v, ∂A end if ModifiedBetweenT === true - ModifiedBetween = Val(falses_from_args(Nargs+1)) + ModifiedBetween = Val(falses_from_args(Nargs + 1)) else ModifiedBetween = Val(ModifiedBetweenT) end - tt = Tuple{map(eltype, args)...} - + tt = Tuple{map(eltype, args)...} + if !(A <: Const) @assert ReturnShadow end @@ -602,7 +896,20 @@ result, ∂v, ∂A else Val(codegen_world_age(eltype(FA), tt)) end - Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + Enzyme.Compiler.thunk( + opt_mi, + FA, + A, + tt′, + Val(API.DEM_ReverseModeGradient), + Val(width), + ModifiedBetween, + Val(ReturnPrimal), + Val(false), + RABI, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# end """ @@ -620,11 +927,20 @@ end ((6.2,),) ``` """ -@inline function autodiff(f::Function, m::MMode, ::Type{A}, args::Vararg{Annotation, Nargs}) where {A<:Annotation, Nargs, MMode<:Mode} - autodiff(m, f, A, args...) +@inline function autodiff( + f::Function, + m::MMode, + ::Type{A}, + args::Vararg{Annotation,Nargs}, +) where {A<:Annotation,Nargs,MMode<:Mode} + autodiff(m, f, A, args...) end -@inline function autodiff(f::Function, m::MMode, args::Vararg{Annotation, Nargs}) where {Nargs, MMode<:Mode} - autodiff(m, f, args...) +@inline function autodiff( + f::Function, + m::MMode, + args::Vararg{Annotation,Nargs}, +) where {Nargs,MMode<:Mode} + autodiff(m, f, args...) 
end """ @@ -671,7 +987,20 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, Duplicated, Duplicated{Float (6.28,) ``` """ -@inline function autodiff_thunk(::ForwardMode{ReturnPrimal, RABI, ErrIfFuncWritten, RuntimeActivity}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {ReturnPrimal, FA<:Annotation, A<:Annotation, RABI<:ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} +@inline function autodiff_thunk( + ::ForwardMode{ReturnPrimal,RABI,ErrIfFuncWritten,RuntimeActivity}, + ::Type{FA}, + ::Type{A}, + args::Vararg{Type{<:Annotation},Nargs}, +) where { + ReturnPrimal, + FA<:Annotation, + A<:Annotation, + RABI<:ABI, + Nargs, + ErrIfFuncWritten, + RuntimeActivity, +} width = same_or_one(1, A, args...) if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) @@ -680,23 +1009,64 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, Duplicated, Duplicated{Float throw(ErrorException("Active Returns not allowed in forward mode")) end if A <: DuplicatedNoNeed || A <: BatchDuplicatedNoNeed - throw(ErrorException("Return activity `DuplicatedNoNeed` is no longer now returning or avoiding the primal is passed in for Forward Mode AD.\nPlease use autodiff(Forward, ...) or autodiff(ForwardWithPrimal, ...)")) + throw( + ErrorException( + "Return activity `DuplicatedNoNeed` is no longer now returning or avoiding the primal is passed in for Forward Mode AD.\nPlease use autodiff(Forward, ...) 
or autodiff(ForwardWithPrimal, ...)", + ), + ) end - ModifiedBetween = Val(falses_from_args(Nargs+1)) + ModifiedBetween = Val(falses_from_args(Nargs + 1)) + + tt = Tuple{map(eltype, args)...} - tt = Tuple{map(eltype, args)...} - tt′ = Tuple{args...} opt_mi = if RABI <: NonGenABI Compiler.fspec(eltype(FA), tt′) else Val(codegen_world_age(eltype(FA), tt)) end - results = Enzyme.Compiler.thunk(opt_mi, FA, A, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + results = Enzyme.Compiler.thunk( + opt_mi, + FA, + A, + tt′, + Val(API.DEM_ForwardMode), + Val(width), + ModifiedBetween, + Val(ReturnPrimal), + Val(false), + RABI, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# end -@inline function tape_type(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,ModifiedBetweenT, RABI, ErrIfFuncWritten}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} +@inline function tape_type( + ::ReverseModeSplit{ + ReturnPrimal, + ReturnShadow, + RuntimeActivity, + Width, + ModifiedBetweenT, + RABI, + ErrIfFuncWritten, + }, + ::Type{FA}, + ::Type{A}, + args::Vararg{Type{<:Annotation},Nargs}, +) where { + FA<:Annotation, + A<:Annotation, + ReturnPrimal, + ReturnShadow, + Width, + ModifiedBetweenT, + RABI<:ABI, + Nargs, + ErrIfFuncWritten, + RuntimeActivity, +} width = if Width == 0 w = same_or_one(1, args...) 
if w == 0 @@ -708,21 +1078,34 @@ end end if ModifiedBetweenT === true - ModifiedBetween = Val(falses_from_args(Nargs+1)) + ModifiedBetween = Val(falses_from_args(Nargs + 1)) else ModifiedBetween = Val(ModifiedBetweenT) end @assert ReturnShadow TT = Tuple{args...} - + primal_tt = Tuple{map(eltype, args)...} opt_mi = if RABI <: NonGenABI Compiler.fspec(eltype(FA), TT) else Val(codegen_world_age(eltype(FA), primal_tt)) end - nondef = Enzyme.Compiler.thunk(opt_mi, FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + nondef = Enzyme.Compiler.thunk( + opt_mi, + FA, + A, + TT, + Val(API.DEM_ReverseModeGradient), + Val(width), + ModifiedBetween, + Val(ReturnPrimal), + Val(false), + RABI, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# if nondef[1] isa Enzyme.Compiler.PrimalErrorThunk return Nothing else @@ -731,16 +1114,36 @@ end end end -const tape_cache = Dict{UInt, Type}() +const tape_cache = Dict{UInt,Type}() const tape_cache_lock = ReentrantLock() import .Compiler: fspec, remove_innerty, UnknownTapeType @inline function tape_type( - parent_job::Union{GPUCompiler.CompilerJob,Nothing}, ::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,ModifiedBetweenT, RABI}, - ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs} -) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs, RuntimeActivity} + parent_job::Union{GPUCompiler.CompilerJob,Nothing}, + ::ReverseModeSplit{ + ReturnPrimal, + ReturnShadow, + RuntimeActivity, + Width, + ModifiedBetweenT, + RABI, + }, + ::Type{FA}, + ::Type{A}, + args::Vararg{Type{<:Annotation},Nargs}, +) where { + FA<:Annotation, + A<:Annotation, + ReturnPrimal, + ReturnShadow, + Width, + ModifiedBetweenT, + RABI<:ABI, + Nargs, + RuntimeActivity, +} width = if Width == 0 w = same_or_one(1, args...) 
if w == 0 @@ -768,12 +1171,21 @@ import .Compiler: fspec, remove_innerty, UnknownTapeType target = Compiler.EnzymeTarget() params = Compiler.EnzymeCompilerParams( - Tuple{FA, TT.parameters...}, API.DEM_ReverseModeGradient, width, - Compiler.remove_innerty(A), true, #=abiwrap=#false, ModifiedBetweenT, - ReturnPrimal, #=ShadowInit=#false, Compiler.UnknownTapeType, RABI, #=errifwritte=#false, - RuntimeActivity + Tuple{FA,TT.parameters...}, + API.DEM_ReverseModeGradient, + width, + Compiler.remove_innerty(A), + true, + false, + ModifiedBetweenT, #=abiwrap=# + ReturnPrimal, + false, + Compiler.UnknownTapeType, + RABI, + false, #=errifwritte=# + RuntimeActivity, ) - job = Compiler.CompilerJob(mi, Compiler.CompilerConfig(target, params; kernel=false)) + job = Compiler.CompilerJob(mi, Compiler.CompilerConfig(target, params; kernel = false)) key = hash(parent_job, hash(job)) @@ -786,7 +1198,7 @@ import .Compiler: fspec, remove_innerty, UnknownTapeType if obj === nothing Compiler.JuliaContext() do ctx - _, meta = Compiler.codegen(:llvm, job; optimize=false, parent_job) + _, meta = Compiler.codegen(:llvm, job; optimize = false, parent_job) obj = meta.TapeType tape_cache[key] = obj end @@ -841,7 +1253,33 @@ result, ∂v, ∂A (7.26, 2.2, [3.3]) ``` """ -@inline function autodiff_deferred_thunk(mode::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,ModifiedBetweenT, RABI, ErrIfFuncWritten}, tt::Type{TapeType}, fa::Type{FA}, a2::Type{A2}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A2<:Annotation, TapeType, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs, ErrIfFuncWritten, RuntimeActivity} +@inline function autodiff_deferred_thunk( + mode::ReverseModeSplit{ + ReturnPrimal, + ReturnShadow, + RuntimeActivity, + Width, + ModifiedBetweenT, + RABI, + ErrIfFuncWritten, + }, + tt::Type{TapeType}, + fa::Type{FA}, + a2::Type{A2}, + args::Vararg{Type{<:Annotation},Nargs}, +) where { + FA<:Annotation, + A2<:Annotation, + TapeType, + 
ReturnPrimal, + ReturnShadow, + Width, + ModifiedBetweenT, + RABI<:ABI, + Nargs, + ErrIfFuncWritten, + RuntimeActivity, +} @assert RABI == FFIABI width = if Width == 0 w = same_or_one(1, args...) @@ -854,7 +1292,7 @@ result, ∂v, ∂A end if ModifiedBetweenT === true - ModifiedBetween = Val(falses_from_args(Nargs+1)) + ModifiedBetween = Val(falses_from_args(Nargs + 1)) else ModifiedBetween = Val(ModifiedBetweenT) end @@ -865,40 +1303,69 @@ result, ∂v, ∂A primal_tt = Tuple{map(eltype, args)...} world = codegen_world_age(eltype(FA), primal_tt) - primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(Compiler.remove_innerty(A2)), Val(API.DEM_ReverseModePrimal), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType, Val(ErrIfFuncWritten), Val(RuntimeActivity)) - adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(Compiler.remove_innerty(A2)), Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + primal_ptr = Compiler.deferred_codegen( + Val(world), + FA, + Val(TT), + Val(Compiler.remove_innerty(A2)), + Val(API.DEM_ReverseModePrimal), + Val(width), + ModifiedBetween, + Val(ReturnPrimal), + Val(false), + TapeType, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# + adjoint_ptr = Compiler.deferred_codegen( + Val(world), + FA, + Val(TT), + Val(Compiler.remove_innerty(A2)), + Val(API.DEM_ReverseModeGradient), + Val(width), + ModifiedBetween, + Val(ReturnPrimal), + Val(false), + TapeType, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# RT = if A2 <: Duplicated && width != 1 if A2 isa UnionAll - BatchDuplicated{T, width} where T + BatchDuplicated{T,width} where {T} else - BatchDuplicated{eltype(A2), width} + BatchDuplicated{eltype(A2),width} end elseif A2 <: DuplicatedNoNeed && width != 1 if A2 isa UnionAll - BatchDuplicatedNoNeed{T, width} where T + 
BatchDuplicatedNoNeed{T,width} where {T} else - BatchDuplicatedNoNeed{eltype(A2), width} + BatchDuplicatedNoNeed{eltype(A2),width} end elseif A2 <: MixedDuplicated && width != 1 if A2 isa UnionAll - BatchMixedDuplicated{T, width} where T + BatchMixedDuplicated{T,width} where {T} else - BatchMixedDuplicated{eltype(A2), width} + BatchMixedDuplicated{eltype(A2),width} end else A2 end - + rt = if RT isa UnionAll - RT{Core.Compiler.return_type(Tuple{eltype(FA), map(eltype, args)...})} + RT{Core.Compiler.return_type(Tuple{eltype(FA),map(eltype, args)...})} else @assert RT isa DataType RT end - aug_thunk = Compiler.AugmentedForwardThunk{Ptr{Cvoid}, FA, rt, TT, width, ReturnPrimal, TapeType}(primal_ptr) - adj_thunk = Compiler.AdjointThunk{Ptr{Cvoid}, FA, rt, TT, width, TapeType}(adjoint_ptr) + aug_thunk = + Compiler.AugmentedForwardThunk{Ptr{Cvoid},FA,rt,TT,width,ReturnPrimal,TapeType}( + primal_ptr, + ) + adj_thunk = Compiler.AdjointThunk{Ptr{Cvoid},FA,rt,TT,width,TapeType}(adjoint_ptr) aug_thunk, adj_thunk end @@ -911,11 +1378,11 @@ Base.@ccallable function __enzyme_double(x::Ptr{Cvoid})::Cvoid return nothing end -@inline function markType(::Type{T}, ptr::Ptr{Cvoid}) where T +@inline function markType(::Type{T}, ptr::Ptr{Cvoid}) where {T} markType(Base.unsafe_convert(Ptr{T}, ptr)) end -@inline function markType(data::Array{T}) where T +@inline function markType(data::Array{T}) where {T} GC.@preserve data markType(pointer(data)) end @@ -925,20 +1392,52 @@ end end @inline function markType(data::Ptr{Float32}) -@static if sizeof(Int) == sizeof(Int64) - Base.llvmcall(("declare void @__enzyme_float(i8* nocapture) nounwind define void @c(i64 %q) nounwind alwaysinline { %p = inttoptr i64 %q to i8* call void @__enzyme_float(i8* %p) ret void }", "c"), Cvoid, Tuple{Ptr{Float32}}, data) -else - Base.llvmcall(("declare void @__enzyme_float(i8* nocapture) nounwind define void @c(i32 %q) nounwind alwaysinline { %p = inttoptr i32 %q to i8* call void @__enzyme_float(i8* %p) ret void }", 
"c"), Cvoid, Tuple{Ptr{Float32}}, data) -end + @static if sizeof(Int) == sizeof(Int64) + Base.llvmcall( + ( + "declare void @__enzyme_float(i8* nocapture) nounwind define void @c(i64 %q) nounwind alwaysinline { %p = inttoptr i64 %q to i8* call void @__enzyme_float(i8* %p) ret void }", + "c", + ), + Cvoid, + Tuple{Ptr{Float32}}, + data, + ) + else + Base.llvmcall( + ( + "declare void @__enzyme_float(i8* nocapture) nounwind define void @c(i32 %q) nounwind alwaysinline { %p = inttoptr i32 %q to i8* call void @__enzyme_float(i8* %p) ret void }", + "c", + ), + Cvoid, + Tuple{Ptr{Float32}}, + data, + ) + end nothing end @inline function markType(data::Ptr{Float64}) -@static if sizeof(Int) == sizeof(Int64) - Base.llvmcall(("declare void @__enzyme_double(i8* nocapture) nounwind define void @c(i64 %q) nounwind alwaysinline { %p = inttoptr i64 %q to i8* call void @__enzyme_double(i8* %p) ret void }", "c"), Cvoid, Tuple{Ptr{Float64}}, data) -else - Base.llvmcall(("declare void @__enzyme_double(i8* nocapture) nounwind define void @c(i32 %q) nounwind alwaysinline { %p = inttoptr i32 %q to i8* call void @__enzyme_double(i8* %p) ret void }", "c"), Cvoid, Tuple{Ptr{Float64}}, data) -end + @static if sizeof(Int) == sizeof(Int64) + Base.llvmcall( + ( + "declare void @__enzyme_double(i8* nocapture) nounwind define void @c(i64 %q) nounwind alwaysinline { %p = inttoptr i64 %q to i8* call void @__enzyme_double(i8* %p) ret void }", + "c", + ), + Cvoid, + Tuple{Ptr{Float64}}, + data, + ) + else + Base.llvmcall( + ( + "declare void @__enzyme_double(i8* nocapture) nounwind define void @c(i32 %q) nounwind alwaysinline { %p = inttoptr i32 %q to i8* call void @__enzyme_double(i8* %p) ret void }", + "c", + ), + Cvoid, + Tuple{Ptr{Float64}}, + data, + ) + end nothing end @@ -947,24 +1446,24 @@ end ntuple(Val(N)) do i Base.@_inline_meta res = similar(x) - for idx in 1:N + for idx = 1:N @inbounds res[idx] = (i == idx) ? 
1.0 : 0.0 end return res end end @inline function onehot(x, start, endl) - ntuple(Val(endl-start+1)) do i + ntuple(Val(endl - start + 1)) do i Base.@_inline_meta res = similar(x) - for idx in 1:length(x) - @inbounds res[idx] = (i + start - 1== idx) ? 1.0 : 0.0 + for idx = 1:length(x) + @inbounds res[idx] = (i + start - 1 == idx) ? 1.0 : 0.0 end return res end end -@inline function onehot(::Type{NTuple{N, T}}) where {T, N} +@inline function onehot(::Type{NTuple{N,T}}) where {T,N} ntuple(Val(N)) do i Base.@_inline_meta ntuple(Val(N)) do idx @@ -973,11 +1472,11 @@ end end end end -@inline function onehot(x::NTuple{N, T}) where {T, N} - onehot(NTuple{N, T}) +@inline function onehot(x::NTuple{N,T}) where {T,N} + onehot(NTuple{N,T}) end -@inline function onehot(x::NTuple{N, T}, start, endl) where {T, N} - ntuple(Val(endl-start+1)) do i +@inline function onehot(x::NTuple{N,T}, start, endl) where {T,N} + ntuple(Val(endl - start + 1)) do i Base.@_inline_meta ntuple(Val(N)) do idx Base.@_inline_meta @@ -1067,21 +1566,41 @@ grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0])) ``` """ -@generated function gradient(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, f::F, x::ty_0, args::Vararg{Any, N}) where {F, ty_0, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten, N} - toemit= Expr[quote - act_0 = !(x isa Enzyme.Const) && Compiler.active_reg_inner(Core.Typeof(x), #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState +@generated function gradient( + rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, + f::F, + x::ty_0, + args::Vararg{Any,N}, +) where {F,ty_0,ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten,N} + toemit = Expr[quote + act_0 = + !(x isa Enzyme.Const) && + Compiler.active_reg_inner(Core.Typeof(x), (), nothing, Val(true)) == + Compiler.ActiveState #=justActive=# end] rargs = Union{Symbol,Expr}[:x] acts = Symbol[Symbol("act_0")] - for i in 1:N - 
argidx = quote args[$i] end + for i = 1:N + argidx = quote + args[$i] + end push!(rargs, argidx) sym = Symbol("act_$i") push!(acts, sym) - push!(toemit, quote - $sym = !($argidx isa Enzyme.Const) && Compiler.active_reg_inner(Core.Typeof($argidx), #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState - end) + push!( + toemit, + quote + $sym = + !($argidx isa Enzyme.Const) && + Compiler.active_reg_inner( + Core.Typeof($argidx), + (), + nothing, + Val(true), + ) == Compiler.ActiveState #=justActive=# + end, + ) end idx = 0 @@ -1118,7 +1637,7 @@ grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0])) $shad end end) - idx+=1 + idx += 1 end push!(toemit, quote res = autodiff(rm, f, Active, $(enz_args...)) @@ -1128,7 +1647,7 @@ grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0])) return quote Base.@_inline_meta $(toemit...) - (; derivs=($(resargs...),), val=res[2]) + (; derivs = ($(resargs...),), val = res[2]) end else return quote @@ -1166,26 +1685,31 @@ gradient!(ReverseWithPrimal, dx, f, [2.0, 3.0]) (derivs = ([3.0, 2.0],), val = 6.0) ``` """ -@inline function gradient!(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, dx::X, f::F, x::X) where {X<:Array, F, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten} +@inline function gradient!( + rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, + dx::X, + f::F, + x::X, +) where {X<:Array,F,ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} make_zero!(dx) res = autodiff(rm, f, Active, Duplicated(x, dx)) return if ReturnPrimal - (; derivs=(dx,), val=res[2]) + (; derivs = (dx,), val = res[2]) else (dx,) end end -@inline function chunkedonehot(x, ::Val{chunk}) where chunk +@inline function chunkedonehot(x, ::Val{chunk}) where {chunk} sz = length(x) num = ((sz + chunk - 1) ÷ chunk) ntuple(Val(num)) do i Base.@_inline_meta - onehot(x, (i-1)*chunk+1, i == num ? 
sz : (i*chunk) ) + onehot(x, (i - 1) * chunk + 1, i == num ? sz : (i * chunk)) end end -@inline function chunkedonehot(x::AbstractFloat, ::Val{chunk}) where chunk +@inline function chunkedonehot(x::AbstractFloat, ::Val{chunk}) where {chunk} return ((one(x),),) end @@ -1201,23 +1725,27 @@ function create_shadows(::Val{1}, x) return (onehot(x),) end -function create_shadows(::Val{chunk}, x) where chunk +function create_shadows(::Val{chunk}, x) where {chunk} return (chunkedonehot(x, Val(chunk)),) end -struct TupleArray{T, Shape, Length, N} <: AbstractArray{T,N} - data::NTuple{Length, T} +struct TupleArray{T,Shape,Length,N} <: AbstractArray{T,N} + data::NTuple{Length,T} end -TupleArray(data::NTuple{Length, T}, Shape) where {Length, T} = TupleArray{T, Shape, Length, length(Shape)}(data) - -@inline Base.eltype(::TupleArray{T}) where T = T -@inline Base.eltype(::Type{<:TupleArray{T}}) where T = T -@inline Base.size(::TupleArray{<:Any, Shape}) where Shape = Shape -@inline Base.ndims(::TupleArray{<:Any, <:Any, <:Any, N}) where N = N - -function Base.convert(::Type{Array{T, N}}, X::TupleArray{T, Shape, Length, N}) where {T, Shape, Length, N} - vals = Array{T, N}(undef, Shape...) - for i in 1:Length +TupleArray(data::NTuple{Length,T}, Shape) where {Length,T} = + TupleArray{T,Shape,Length,length(Shape)}(data) + +@inline Base.eltype(::TupleArray{T}) where {T} = T +@inline Base.eltype(::Type{<:TupleArray{T}}) where {T} = T +@inline Base.size(::TupleArray{<:Any,Shape}) where {Shape} = Shape +@inline Base.ndims(::TupleArray{<:Any,<:Any,<:Any,N}) where {N} = N + +function Base.convert( + ::Type{Array{T,N}}, + X::TupleArray{T,Shape,Length,N}, +) where {T,Shape,Length,N} + vals = Array{T,N}(undef, Shape...) 
+ for i = 1:Length @inbounds val[i] = X.data[i] end return vals @@ -1225,9 +1753,9 @@ end function Base.getindex(a::TupleArray, args::Vararg{Int,N}) where {N} start = 0 - for i in 1:N + for i = 1:N start *= size(a, N - i + 1) - start += (args[N - i + 1] - 1) + start += (args[N-i+1] - 1) end start += 1 return a.data[start] @@ -1301,10 +1829,16 @@ grad = gradient(Forward, f, [2.0, 3.0, 4.0]) ([3.0 2.0 0.0; 0.0 1.0 1.0],) ``` """ -@inline function gradient(fm::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity}, f, x; chunk::CS=nothing, shadows=create_shadows(chunk, x)) where {ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity, CS} +@inline function gradient( + fm::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}, + f, + x; + chunk::CS = nothing, + shadows = create_shadows(chunk, x), +) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity,CS} if length(shadows[1]) == 0 return if ReturnPrimal - (; derivs=(x,), val=f(x.val)) + (; derivs = (x,), val = f(x.val)) else (x,) end @@ -1331,9 +1865,9 @@ grad = gradient(Forward, f, [2.0, 3.0, 4.0]) if ReturnPrimal rp = autodiff(fm, f, Duplicated, Duplicated(x, shadows[1][1])) dres1 = rp[1] - fm2 = ForwardMode{#=ReturnPrimal=#false, ABI, ErrIfFuncWritten,RuntimeActivity}() + fm2 = ForwardMode{false,ABI,ErrIfFuncWritten,RuntimeActivity}() #=ReturnPrimal=# - res = ntuple(length(shadows[1])-1) do i + res = ntuple(length(shadows[1]) - 1) do i autodiff(fm2, f, Duplicated, Duplicated(x, shadows[1][i+1]))[1] end gres = if x isa AbstractFloat @@ -1359,9 +1893,16 @@ grad = gradient(Forward, f, [2.0, 3.0, 4.0]) gres = if x isa AbstractFloat dres1[1] else - fm2 = ForwardMode{#=ReturnPrimal=#false, ABI, ErrIfFuncWritten,RuntimeActivity}() - tmp = ntuple(length(shadows[1])-1) do i - values(autodiff(fm2, f, BatchDuplicated, BatchDuplicated(x, shadows[1][i+1]))[1]) + fm2 = ForwardMode{false,ABI,ErrIfFuncWritten,RuntimeActivity}() #=ReturnPrimal=# + tmp = ntuple(length(shadows[1]) - 1) do i + values( + autodiff( + 
fm2, + f, + BatchDuplicated, + BatchDuplicated(x, shadows[1][i+1]), + )[1], + ) end tupleconcat(dres1, tmp...) end @@ -1397,7 +1938,7 @@ grad = gradient(Forward, f, [2.0, 3.0, 4.0]) cols end if ReturnPrimal - (; derivs=(res,), val=gradtup[2]) + (; derivs = (res,), val = gradtup[2]) else (res,) end @@ -1466,7 +2007,13 @@ In the future, when this function is extended to handle non-array return types, this function will retun an AbstractArray of shape `size(output)` of values of the input type. ``` """ -@inline function jacobian(::ReverseMode{ReturnPrimal,RuntimeActivity, RABI, Holomorphic, ErrIfFuncWritten}, f::F, x::X; n_outs::OutType=nothing, chunk::CT=nothing) where {ReturnPrimal, F, X, RABI<:ABI, ErrIfFuncWritten, RuntimeActivity, OutType, CT, Holomorphic} +@inline function jacobian( + ::ReverseMode{ReturnPrimal,RuntimeActivity,RABI,Holomorphic,ErrIfFuncWritten}, + f::F, + x::X; + n_outs::OutType = nothing, + chunk::CT = nothing, +) where {ReturnPrimal,F,X,RABI<:ABI,ErrIfFuncWritten,RuntimeActivity,OutType,CT,Holomorphic} if n_outs == nothing res = if f isa Const @@ -1475,43 +2022,57 @@ this function will retun an AbstractArray of shape `size(output)` of values of t f(x) end jac = if res isa AbstractArray - jacobian(ReverseMode{false,RuntimeActivity,RABI, Holomorphic, ErrIfFuncWritten}(), f, x; n_outs=Val(size(res)), chunk) + jacobian( + ReverseMode{false,RuntimeActivity,RABI,Holomorphic,ErrIfFuncWritten}(), + f, + x; + n_outs = Val(size(res)), + chunk, + ) elseif res isa AbstractFloat - gradient(ReverseMode{false,RuntimeActivity,RABI, Holomorphic, ErrIfFuncWritten}(), f, x) + gradient( + ReverseMode{false,RuntimeActivity,RABI,Holomorphic,ErrIfFuncWritten}(), + f, + x, + ) else - throw(AssertionError("Unsupported return type of function for reverse-mode jacobian, $(Core.Typeof(res))")) + throw( + AssertionError( + "Unsupported return type of function for reverse-mode jacobian, $(Core.Typeof(res))", + ), + ) end return if ReturnPrimal - (; derivs=jac, val=res) + 
(; derivs = jac, val = res) else jac end else - @assert !Holomorphic + @assert !Holomorphic n_out_val = if length(Compiler.element(n_outs)) == 0 0 else prod(Compiler.element(n_outs)) end - + if chunk == Val(0) throw(ErrorException("Cannot differentiate with a batch size of 0")) end - - XT = Core.Typeof(x) - MD = Compiler.active_reg_inner(XT, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState - tt = Tuple{XT} + + XT = Core.Typeof(x) + MD = Compiler.active_reg_inner(XT, (), nothing, Val(true)) == Compiler.ActiveState #=justActive=# + tt = Tuple{XT} rt = if f isa Const Core.Compiler.return_type(f.val, tt) else Core.Compiler.return_type(f, tt) end - + ModifiedBetween = Val((false, false)) FRT = Core.Typeof(f) FA = Const{FRT} - + opt_mi = if RABI <: NonGenABI Compiler.fspec(FRT, tt′) else @@ -1519,8 +2080,21 @@ this function will retun an AbstractArray of shape `size(output)` of values of t end if chunk == Val(1) || chunk == nothing - tt′ = MD ? Tuple{MixedDuplicated{XT}} : Tuple{Duplicated{XT}} - primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + tt′ = MD ? Tuple{MixedDuplicated{XT}} : Tuple{Duplicated{XT}} + primal, adjoint = Enzyme.Compiler.thunk( + opt_mi, + FA, + DuplicatedNoNeed{rt}, + tt′, + Val(API.DEM_ReverseModeGradient), + Val(1), + ModifiedBetween, + Val(false), + Val(false), + RABI, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# tmp = ntuple(Val(n_out_val)) do i Base.@_inline_meta z = make_zero(x) @@ -1536,18 +2110,46 @@ this function will retun an AbstractArray of shape `size(output)` of values of t rows, outshape else chunksize = Compiler.element(chunk) - tt′ = MD ? 
Tuple{BatchMixedDuplicated{XT, chunksize}} : Tuple{BatchDuplicated{XT, chunksize}} - primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#chunk, ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) - + tt′ = + MD ? Tuple{BatchMixedDuplicated{XT,chunksize}} : + Tuple{BatchDuplicated{XT,chunksize}} + primal, adjoint = Enzyme.Compiler.thunk( + opt_mi, + FA, + BatchDuplicatedNoNeed{rt}, + tt′, + Val(API.DEM_ReverseModeGradient), + chunk, + ModifiedBetween, + Val(false), + Val(false), + RABI, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# + num = ((n_out_val + chunksize - 1) ÷ chunksize) - + if num * chunksize == n_out_val last_size = chunksize primal2, adjoint2 = primal, adjoint else - last_size = n_out_val - (num-1)*chunksize - tt′ = Tuple{BatchDuplicated{Core.Typeof(x), last_size}} - primal2, adjoint2 = Enzyme.Compiler.thunk(opt_mi, FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) + last_size = n_out_val - (num - 1) * chunksize + tt′ = Tuple{BatchDuplicated{Core.Typeof(x),last_size}} + primal2, adjoint2 = Enzyme.Compiler.thunk( + opt_mi, + FA, + BatchDuplicatedNoNeed{rt}, + tt′, + Val(API.DEM_ReverseModeGradient), + Val(last_size), + ModifiedBetween, + Val(false), + Val(false), + RABI, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=ShadowInit=# end tmp = ntuple(num) do i @@ -1557,18 +2159,29 @@ this function will retun an AbstractArray of shape `size(output)` of values of t z = make_zero(x) MD ? Ref(z) : z end - res = (i == num ? primal2 : primal)(Const(f), MD ? BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx)) + res = (i == num ? primal2 : primal)( + Const(f), + MD ? 
BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx), + ) tape = res[1] j = 0 for shadow in res[3] j += 1 - @inbounds shadow[(i-1)*chunksize+j] += Compiler.default_adjoint(eltype(typeof(shadow))) + @inbounds shadow[(i-1)*chunksize+j] += + Compiler.default_adjoint(eltype(typeof(shadow))) end - (i == num ? adjoint2 : adjoint)(Const(f), MD ? BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx), tape) - return MD ? (ntuple(Val(i == num ? last_size : chunksize)) do idx - Base.@_inline_meta - dx[idx][] - end) : dx, (i == 1 ? size(res[3][1]) : nothing) + (i == num ? adjoint2 : adjoint)( + Const(f), + MD ? BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx), + tape, + ) + return MD ? ( + ntuple(Val(i == num ? last_size : chunksize)) do idx + Base.@_inline_meta + dx[idx][] + end + ) : dx, + (i == 1 ? size(res[3][1]) : nothing) end rows = tupleconcat(map(first, tmp)...) outshape = tmp[1][2] @@ -1581,7 +2194,10 @@ this function will retun an AbstractArray of shape `size(output)` of values of t st3 = if length(outshape) == 1 && length(inshape) == 1 transpose(st2) else - transp = ( ((length(inshape)+1):(length(inshape)+length(outshape)))... , (1:length(inshape))... 
) + transp = ( + ((length(inshape)+1):(length(inshape)+length(outshape)))..., + (1:length(inshape))..., + ) PermutedDimsArray(st2, transp) end @@ -1590,14 +2206,14 @@ this function will retun an AbstractArray of shape `size(output)` of values of t reshape(collect(rows), outshape) end if ReturnPrimal - # TODO optimize away redundant fwd pass - (; derivs=(res,), val=if f isa Enzyme.Const - f.val(x) - else - f(x) - end) + # TODO optimize away redundant fwd pass + (; derivs = (res,), val = if f isa Enzyme.Const + f.val(x) + else + f(x) + end) else - (res,) + (res,) end end end @@ -1624,7 +2240,7 @@ hvp(f, [2.0, 3.0], [5.0, 2.7]) 16.201003759768003 ``` """ -@inline function hvp(f::F, x::X, v::X) where {F, X} +@inline function hvp(f::F, x::X, v::X) where {F,X} res = make_zero(x) hvp!(res, f, x, v) return res @@ -1657,9 +2273,16 @@ res 16.201003759768003 ``` """ -@inline function hvp!(res::X, f::F, x::X, v::X) where {F, X} +@inline function hvp!(res::X, f::F, x::X, v::X) where {F,X} grad = make_zero(x) - Enzyme.autodiff(Forward, gradient!, Const(Reverse), DuplicatedNoNeed(grad, res), Const(f), Duplicated(x, v)) + Enzyme.autodiff( + Forward, + gradient!, + Const(Reverse), + DuplicatedNoNeed(grad, res), + Const(f), + Duplicated(x, v), + ) return nothing end @@ -1693,8 +2316,15 @@ grad 1.920340573300732 ``` """ -@inline function hvp_and_gradient!(res::X, grad::X, f::F, x::X, v::X) where {F, X} - Enzyme.autodiff(Forward, gradient!, Const(Reverse), Duplicated(grad, res), Const(f), Duplicated(x, v)) +@inline function hvp_and_gradient!(res::X, grad::X, f::F, x::X, v::X) where {F,X} + Enzyme.autodiff( + Forward, + gradient!, + Const(Reverse), + Duplicated(grad, res), + Const(f), + Duplicated(x, v), + ) return nothing end @@ -1732,7 +2362,7 @@ Enzyme.autodiff(Forward, sort, Duplicated, BatchDuplicated(x, (dx,))) """ macro import_frule(args...) return _import_frule(args...) 
-end +end function _import_rrule end # defined in EnzymeChainRulesCoreExt extension diff --git a/src/absint.jl b/src/absint.jl index b84657aadbd..585b1625a39 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -1,9 +1,8 @@ # Abstractly interpret julia from LLVM # Return (bool if could interpret, julia object interpreted to) -function absint(arg::LLVM.Value, partial::Bool=false) - if isa(arg, LLVM.BitCastInst) || - isa(arg, LLVM.AddrSpaceCastInst) +function absint(arg::LLVM.Value, partial::Bool = false) + if isa(arg, LLVM.BitCastInst) || isa(arg, LLVM.AddrSpaceCastInst) return absint(operands(arg)[1], partial) end if isa(arg, ConstantExpr) @@ -18,12 +17,17 @@ function absint(arg::LLVM.Value, partial::Bool=false) nm = LLVM.name(fn) end for (fname, ty) in ( - ("jl_box_int64", Int64), ("ijl_box_int64", Int64), - ("jl_box_uint64", UInt64), ("ijl_box_uint64", UInt64), - ("jl_box_int32", Int32), ("ijl_box_int32", Int32), - ("jl_box_uint32", UInt32), ("ijl_box_uint32", UInt32), - ("jl_box_char", Char), ("ijl_box_char", Char), - ) + ("jl_box_int64", Int64), + ("ijl_box_int64", Int64), + ("jl_box_uint64", UInt64), + ("ijl_box_uint64", UInt64), + ("jl_box_int32", Int32), + ("ijl_box_int32", Int32), + ("jl_box_uint32", UInt32), + ("ijl_box_uint32", UInt32), + ("jl_box_char", Char), + ("ijl_box_char", Char), + ) if nm == fname v = first(operands(arg)) if isa(v, ConstantInt) @@ -39,7 +43,8 @@ function absint(arg::LLVM.Value, partial::Bool=false) return absint(operands(arg)[1], partial) end if nm == "jl_typeof" || nm == "ijl_typeof" - return abs_typeof(operands(arg)[1], partial) + vals = abs_typeof(operands(arg)[1], partial) + return (vals[1], vals[2]) end if LLVM.callconv(arg) == 37 || nm == "julia.call" index = 1 @@ -54,11 +59,11 @@ function absint(arg::LLVM.Value, partial::Bool=false) legal, Ty = absint(operands(arg)[index], partial) unionalls = [] for sarg in operands(arg)[index+1:end-1] - slegal , foundv = absint(sarg, partial) + slegal, foundv = absint(sarg, partial) if 
slegal push!(found, foundv) elseif partial - foundv = TypeVar(Symbol("sarg"*string(sarg))) + foundv = TypeVar(Symbol("sarg" * string(sarg))) push!(found, foundv) push!(unionalls, foundv) else @@ -80,7 +85,7 @@ function absint(arg::LLVM.Value, partial::Bool=false) found = [] legal = true for sarg in operands(arg)[index:end-1] - slegal , foundv = absint(sarg, partial) + slegal, foundv = absint(sarg, partial) if slegal push!(found, foundv) else @@ -107,25 +112,28 @@ function absint(arg::LLVM.Value, partial::Bool=false) end end - if isa(arg, GlobalVariable) + if isa(arg, GlobalVariable) gname = LLVM.name(arg) for (k, v) in JuliaGlobalNameMap - if gname == k || gname == "ejl_"*k + if gname == k || gname == "ejl_" * k return (true, v) end end for (k, v) in JuliaEnzymeNameMap - if gname == k || gname == "ejl_"*k + if gname == k || gname == "ejl_" * k return (true, v) end end end - if isa(arg, LLVM.LoadInst) && value_type(arg) == LLVM.PointerType(LLVM.StructType(LLVMType[]), Tracked) + if isa(arg, LLVM.LoadInst) && + value_type(arg) == LLVM.PointerType(LLVM.StructType(LLVMType[]), Tracked) ptr = operands(arg)[1] ce = ptr while isa(ce, ConstantExpr) - if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || opcode(ce) == LLVM.API.LLVMBitCast || opcode(ce) == LLVM.API.LLVMIntToPtr + if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || + opcode(ce) == LLVM.API.LLVMBitCast || + opcode(ce) == LLVM.API.LLVMIntToPtr ce = operands(ce)[1] else break @@ -149,9 +157,21 @@ function absint(arg::LLVM.Value, partial::Bool=false) return (false, nothing) end -function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Type},Tuple{Bool, Nothing}} - if isa(arg, LLVM.BitCastInst) || - isa(arg, LLVM.AddrSpaceCastInst) +function actual_size(@nospecialize(typ2)) + if typ2 <: Array || typ2 <: AbstractString + return sizeof(Int) + elseif Base.isconcretetype(typ2) + return sizeof(typ2) + else + return sizeof(Int) + end +end + +function abs_typeof( + arg::LLVM.Value, + partial::Bool = false, 
+)::Union{Tuple{Bool,Type,GPUCompiler.ArgumentCC},Tuple{Bool,Nothing,Nothing}} + if isa(arg, LLVM.BitCastInst) || isa(arg, LLVM.AddrSpaceCastInst) return abs_typeof(operands(arg)[1], partial) end if isa(arg, ConstantExpr) @@ -160,7 +180,7 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ end end - if isa(arg, LLVM.CallInst) + if isa(arg, LLVM.CallInst) fn = LLVM.called_operand(arg) nm = "" if isa(fn, LLVM.Function) @@ -170,27 +190,36 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ if nm == "julia.pointer_from_objref" return abs_typeof(operands(arg)[1], partial) end - + for (fname, ty) in ( - ("jl_box_int64", Int64), ("ijl_box_int64", Int64), - ("jl_box_uint64", UInt64), ("ijl_box_uint64", UInt64), - ("jl_box_int32", Int32), ("ijl_box_int32", Int32), - ("jl_box_uint32", UInt32), ("ijl_box_uint32", UInt32), - ("jl_box_float32", Float32), ("ijl_box_float32", Float32), - ("jl_box_char", Char), ("ijl_box_char", Char), - ("jl_specializations_get_linfo", Core.MethodInstance), - ("ijl_specializations_get_linfo", Core.MethodInstance), - ) + ("jl_box_int64", Int64), + ("ijl_box_int64", Int64), + ("jl_box_uint64", UInt64), + ("ijl_box_uint64", UInt64), + ("jl_box_int32", Int32), + ("ijl_box_int32", Int32), + ("jl_box_uint32", UInt32), + ("ijl_box_uint32", UInt32), + ("jl_box_float32", Float32), + ("ijl_box_float32", Float32), + ("jl_box_char", Char), + ("ijl_box_char", Char), + ("jl_specializations_get_linfo", Core.MethodInstance), + ("ijl_specializations_get_linfo", Core.MethodInstance), + ) if nm == fname - return (true, ty) + return (true, ty, GPUCompiler.MUT_REF) end end - - # Type tag is arg 3 - if nm == "julia.gc_alloc_obj" || nm == "jl_gc_alloc_typed" || nm == "ijl_gc_alloc_typed" - return absint(operands(arg)[3], partial) + + # Type tag is arg 3 + if nm == "julia.gc_alloc_obj" || + nm == "jl_gc_alloc_typed" || + nm == "ijl_gc_alloc_typed" + vals = absint(operands(arg)[3], partial) + return (vals[1], 
vals[2], vals[1] ? GPUCompiler.BITS_REF : nothing) end - # Type tag is arg 1 + # Type tag is arg 1 if nm == "jl_alloc_array_1d" || nm == "ijl_alloc_array_1d" || nm == "jl_alloc_array_2d" || @@ -199,11 +228,13 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ nm == "ijl_alloc_array_3d" || nm == "jl_new_array" || nm == "ijl_new_array" - return absint(operands(arg)[1], partial) + vals = absint(operands(arg)[1], partial) + return (vals[1], vals[2], vals[1] ? GPUCompiler.MUT_REF : nothing) end if nm == "jl_new_structt" || nm == "ijl_new_structt" - return absint(operands(arg)[1], partial) + vals = absint(operands(arg)[1], partial) + return (vals[1], vals[2], vals[1] ? GPUCompiler.MUT_REF : nothing) end if LLVM.callconv(arg) == 37 || nm == "julia.call" @@ -213,14 +244,15 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ nm = LLVM.name(fn) index += 1 end - - if nm == "jl_f_isdefined" || nm == "ijl_f_isdefined" - return (true, Bool) - end + + if nm == "jl_f_isdefined" || nm == "ijl_f_isdefined" + return (true, Bool, GPUCompiler.MUT_REF) + end if nm == "jl_new_structv" || nm == "ijl_new_structv" @assert index == 2 - return absint(operands(arg)[index], partial) + vals = absint(operands(arg)[index], partial) + return (vals[1], vals[2], vals[1] ? 
GPUCompiler.MUT_REF : nothing) end if nm == "jl_f_tuple" || nm == "ijl_f_tuple" @@ -229,11 +261,11 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ unionalls = [] legal = true for sarg in operands(arg)[index:end-1] - slegal , foundv = abs_typeof(sarg, partial) + slegal, foundv, _ = abs_typeof(sarg, partial) if slegal push!(found, foundv) elseif partial - foundv = TypeVar(Symbol("sarg"*string(sarg))) + foundv = TypeVar(Symbol("sarg" * string(sarg))) push!(found, foundv) push!(unionalls, foundv) else @@ -246,7 +278,7 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ for u in unionalls res = UnionAll(u, res) end - return (true, res) + return (true, res, GPUCompiler.BITS_REF) end end end @@ -261,16 +293,26 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ end if nm == "jl_array_copy" || nm == "ijl_array_copy" - legal, RT = abs_typeof(operands(arg)[1], partial) + legal, RT, _ = abs_typeof(operands(arg)[1], partial) if legal @assert RT <: Array + return (legal, RT, GPUCompiler.MUT_REF) end - return (legal, RT) + return (legal, RT, nothing) end _, RT = enzyme_custom_extract_mi(arg, false) if RT !== nothing - return (true, RT) + llrt, sret, returnRoots = get_return_info(RT) + if sret !== nothing + if llrt == RT + return (true, RT, GPUCompiler.BITS_VALUE) + elseif llrt == Ptr{RT} + return (true, RT, GPUCompiler.MUT_REF) + elseif llrt == Any + return (true, RT, GPUCompiler.BITS_REF) + end + end end end @@ -279,15 +321,16 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ offset = nothing error = false while true - if isa(larg, LLVM.BitCastInst) || - isa(larg, LLVM.AddrSpaceCastInst) - larg = operands(larg)[1] - continue + if isa(larg, LLVM.BitCastInst) || isa(larg, LLVM.AddrSpaceCastInst) + larg = operands(larg)[1] + continue end - if offset === nothing && isa(larg, LLVM.GetElementPtrInst) && all(x->isa(x, LLVM.ConstantInt), 
operands(larg)[2:end]) - b = LLVM.IRBuilder() + if offset === nothing && + isa(larg, LLVM.GetElementPtrInst) && + all(x -> isa(x, LLVM.ConstantInt), operands(larg)[2:end]) + b = LLVM.IRBuilder() position!(b, larg) - offty = LLVM.IntType(8*sizeof(Int)) + offty = LLVM.IntType(8 * sizeof(Int)) offset = API.EnzymeComputeByteOffsetOfGEP(b, larg, offty) @assert isa(offset, LLVM.ConstantInt) offset = convert(Int, offset) @@ -302,154 +345,148 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ end if !error - if isa(larg, LLVM.Argument) - f = LLVM.Function(LLVM.API.LLVMGetParamParent(larg)) - idx = only([i for (i, v) in enumerate(LLVM.parameters(f)) if v == larg]) - typ, byref = enzyme_extract_parm_type(f, idx, #=error=#false) + legal, typ, byref = abs_typeof(larg) + if legal && (byref == GPUCompiler.MUT_REF || byref == GPUCompiler.BITS_REF) @static if VERSION < v"1.11-" - if typ !== nothing && typ <: Array && Base.isconcretetype(typ) + if typ <: Array && Base.isconcretetype(typ) T = eltype(typ) if offset === nothing || offset == 0 - return (true, Ptr{T}) + return (true, Ptr{T}, GPUCompiler.BITS_VALUE) else - return (true, Int) + return (true, Int, GPUCompiler.BITS_VALUE) end end end - if typ !== nothing && byref == GPUCompiler.BITS_REF + if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF + dl = LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(arg)))) if offset === nothing - return (true, typ) - else - function llsz(ty) - if isa(ty, LLVM.PointerType) - return sizeof(Ptr{Cvoid}) - elseif isa(ty, LLVM.IntegerType) - return LLVM.width(ty) / 8 + byref = GPUCompiler.BITS_VALUE + legal = true + typ2 = typ + while actual_size(typ2) != sizeof(dl, value_type(arg)) + if fieldcount(typ2) > 0 + typ2 = fieldtype(typ, 1) + if !Base.allocatedinline(typ2) + if byref != GPUCompiler.BITS_VALUE + legal = false + break + end + byref = GPUCompiler.MUT_REF + continue + end end - error("Unknown llvm type to size: "*string(ty)) + legal = false + 
break + end + if legal + return (true, typ2, byref) end + else @assert Base.isconcretetype(typ) - for i in 1:fieldcount(typ) + for i = 1:fieldcount(typ) if fieldoffset(typ, i) == offset - subT = fieldtype(typ, i) + subT = fieldtype(typ, i) fsize = if i == fieldcount(typ) sizeof(typ) else - fieldoffset(typ, i+1) + fieldoffset(typ, i + 1) end - offset - if fsize == llsz(value_type(larg)) - if Base.isconcretetype(subT) && is_concrete_tuple(subT) && length(subT.parameters) == 1 + if fsize == sizeof(dl, value_type(arg)) + if Base.isconcretetype(subT) && + is_concrete_tuple(subT) && + length(subT.parameters) == 1 subT = subT.parameters[1] end - return (true, subT) + if Base.allocatedinline(subT) + return (true, subT, GPUCompiler.BITS_VALUE) + else + return (true, subT, GPUCompiler.MUT_REF) + end end end end end end - else - legal, RT = abs_typeof(larg) - if legal - if RT <: Array && Base.isconcretetype(RT) - @static if VERSION < v"1.11-" - T = eltype(RT) - - if offset == 0 - return (true, Ptr{T}) - end - - return (true, Int) - end - end - if RT <: Ptr && Base.isconcretetype(RT) - return (true, eltype(RT)) - end - end + elseif legal && if typ <: Ptr && Base.isconcretetype(typ) + return (true, eltype(typ), GPUCompiler.BITS_VALUE) + end end end end - + if isa(arg, LLVM.ExtractValueInst) larg = operands(arg)[1] indptrs = LLVM.API.LLVMGetIndices(arg) numind = LLVM.API.LLVMGetNumIndices(arg) - offset = Cuint[unsafe_load(indptrs, i) for i in 1:numind] - if isa(larg, LLVM.Argument) || isa(larg, LLVM.ExtractValueInst) - typ, byref = if isa(larg, LLVM.Argument) - f = LLVM.Function(LLVM.API.LLVMGetParamParent(larg)) - idx = only([i for (i, v) in enumerate(LLVM.parameters(f)) if v == larg]) - enzyme_extract_parm_type(f, idx, #=error=#false) - else - found, typ = abs_typeof(larg, partial) - if !found - return (false, nothing) - end - (typ, GPUCompiler.BITS_VALUE) - end - if typ !== nothing && byref == GPUCompiler.BITS_VALUE - for ind in offset - @assert Base.isconcretetype(typ) - cnt 
= 0 - for i in 1:fieldcount(typ) - styp = fieldtype(typ, i) - if isghostty(styp) - continue - end - if cnt == ind - typ = styp - break - end - cnt+=1 + offset = Cuint[unsafe_load(indptrs, i) for i = 1:numind] + found, typ, byref = abs_typeof(larg, partial) + if !found + return (false, nothing, nothing) + end + if byref == GPUCompiler.BITS_VALUE + for ind in offset + @assert Base.isconcretetype(typ) + cnt = 0 + for i = 1:fieldcount(typ) + styp = fieldtype(typ, i) + if isghostty(styp) + continue end + if cnt == ind + typ = styp + break + end + cnt += 1 end - return (true, typ) + end + if Base.allocatedinline(typ) + return (true, typ, GPUCompiler.BITS_VALUE) + else + return (true, typ, GPUCompiler.MUT_REF) end end end - + if isa(arg, LLVM.Argument) f = LLVM.Function(LLVM.API.LLVMGetParamParent(arg)) idx = only([i for (i, v) in enumerate(LLVM.parameters(f)) if v == arg]) - typ, byref = enzyme_extract_parm_type(f, idx, #=error=#false) + typ, byref = enzyme_extract_parm_type(f, idx, false) #=error=# if typ !== nothing - if byref == GPUCompiler.BITS_REF - typ = Ptr{typ} - end - return (true, typ) + return (true, typ, byref) end end legal, val = absint(arg, partial) - if legal - return (true, Core.Typeof(val)) - end - return (false, nothing) -end - -function abs_cstring(arg::LLVM.Value)::Tuple{Bool,String} - if isa(arg, ConstantExpr) - ce = arg - while isa(ce, ConstantExpr) - if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || opcode(ce) == LLVM.API.LLVMBitCast || opcode(ce) == LLVM.API.LLVMIntToPtr - ce = operands(ce)[1] - elseif opcode(ce) == LLVM.API.LLVMGetElementPtr - if all(x -> isa(x, LLVM.ConstantInt) && convert(UInt, x) == 0, operands(ce)[2:end]) - ce = operands(ce)[1] - else - break - end - else - break - end - end - if isa(ce, LLVM.GlobalVariable) - ce = LLVM.initializer(ce) - if (isa(ce, LLVM.ConstantArray) || isa(ce, LLVM.ConstantDataArray)) && eltype(value_type(ce)) == LLVM.IntType(8) - return (true, String(map((x)->convert(UInt8, x), collect(ce)[1:(end-1)]))) - 
end - - end - end - return (false, "") + if legal + return (true, Core.Typeof(val), GPUCompiler.BITS_REF) + end + return (false, nothing, nothing) end +# +# function abs_cstring(arg::LLVM.Value)::Tuple{Bool,String} +# if isa(arg, ConstantExpr) +# ce = arg +# while isa(ce, ConstantExpr) +# if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || opcode(ce) == LLVM.API.LLVMBitCast || opcode(ce) == LLVM.API.LLVMIntToPtr +# ce = operands(ce)[1] +# elseif opcode(ce) == LLVM.API.LLVMGetElementPtr +# if all(x -> isa(x, LLVM.ConstantInt) && convert(UInt, x) == 0, operands(ce)[2:end]) +# ce = operands(ce)[1] +# else +# break +# end +# else +# break +# end +# end +# if isa(ce, LLVM.GlobalVariable) +# ce = LLVM.initializer(ce) +# if (isa(ce, LLVM.ConstantArray) || isa(ce, LLVM.ConstantDataArray)) && eltype(value_type(ce)) == LLVM.IntType(8) +# return (true, String(map((x)->convert(UInt8, x), collect(ce)[1:(end-1)]))) +# end +# +# end +# end +# return (false, "") +# end diff --git a/src/api.jl b/src/api.jl index 9e446dcf579..2861eba86ff 100644 --- a/src/api.jl +++ b/src/api.jl @@ -20,51 +20,56 @@ struct IntList data::Ptr{Int64} size::Csize_t end -IntList() = IntList(Ptr{Int64}(0),0) - -@cenum(CConcreteType, - DT_Anything = 0, - DT_Integer = 1, - DT_Pointer = 2, - DT_Half = 3, - DT_Float = 4, - DT_Double = 5, - DT_Unknown = 6, - DT_FP80 = 7, - DT_BFloat16 = 8 +IntList() = IntList(Ptr{Int64}(0), 0) + +@cenum( + CConcreteType, + DT_Anything = 0, + DT_Integer = 1, + DT_Pointer = 2, + DT_Half = 3, + DT_Float = 4, + DT_Double = 5, + DT_Unknown = 6, + DT_FP80 = 7, + DT_BFloat16 = 8 ) function EnzymeConcreteTypeIsFloat(cc::CConcreteType) - if cc == DT_Half - return LLVM.HalfType() - elseif cc == DT_Float - return LLVM.FloatType() - elseif cc == DT_Double - return LLVM.DoubleType() - elseif cc == DT_FP80 - return LLVM.X86FP80Type() - elseif cc == DT_BFloat16 - return LLVM.BFloatType() - else - return nothing - end -end - -@cenum(CValueType, - VT_None = 0, - VT_Primal = 1, - VT_Shadow = 2, - 
VT_Both = 3 -) - -function EnzymeBitcodeReplacement(mod, NotToReplace, found) + if cc == DT_Half + return LLVM.HalfType() + elseif cc == DT_Float + return LLVM.FloatType() + elseif cc == DT_Double + return LLVM.DoubleType() + elseif cc == DT_FP80 + return LLVM.X86FP80Type() + elseif cc == DT_BFloat16 + return LLVM.BFloatType() + else + return nothing + end +end + +@cenum(CValueType, VT_None = 0, VT_Primal = 1, VT_Shadow = 2, VT_Both = 3) + +function EnzymeBitcodeReplacement(mod, NotToReplace, found) foundSize = Ref{Csize_t}(0) foundP = Ref{Ptr{Cstring}}(C_NULL) - res = ccall((:EnzymeBitcodeReplacement, libEnzymeBCLoad), UInt8, (LLVM.API.LLVMModuleRef, Ptr{Cstring}, Csize_t, Ptr{Ptr{Cstring}}, Ptr{Csize_t}), mod, NotToReplace, length(NotToReplace), foundP, foundSize) + res = ccall( + (:EnzymeBitcodeReplacement, libEnzymeBCLoad), + UInt8, + (LLVM.API.LLVMModuleRef, Ptr{Cstring}, Csize_t, Ptr{Ptr{Cstring}}, Ptr{Csize_t}), + mod, + NotToReplace, + length(NotToReplace), + foundP, + foundSize, + ) foundNum = foundSize[] if foundNum != 0 foundP = foundP[] - for i in 1:foundNum + for i = 1:foundNum str = unsafe_load(foundP, i) push!(found, Base.unsafe_string(str)) Libc.free(str) @@ -72,34 +77,75 @@ function EnzymeBitcodeReplacement(mod, NotToReplace, found) end Libc.free(foundP) end - return res + return res end struct EnzymeTypeTree end const CTypeTreeRef = Ptr{EnzymeTypeTree} EnzymeNewTypeTree() = ccall((:EnzymeNewTypeTree, libEnzyme), CTypeTreeRef, ()) -EnzymeNewTypeTreeCT(T, ctx) = ccall((:EnzymeNewTypeTreeCT, libEnzyme), CTypeTreeRef, (CConcreteType, LLVMContextRef), T, ctx) -EnzymeNewTypeTreeTR(tt) = ccall((:EnzymeNewTypeTreeTR, libEnzyme), CTypeTreeRef, (CTypeTreeRef,), tt) +EnzymeNewTypeTreeCT(T, ctx) = ccall( + (:EnzymeNewTypeTreeCT, libEnzyme), + CTypeTreeRef, + (CConcreteType, LLVMContextRef), + T, + ctx, +) +EnzymeNewTypeTreeTR(tt) = + ccall((:EnzymeNewTypeTreeTR, libEnzyme), CTypeTreeRef, (CTypeTreeRef,), tt) EnzymeFreeTypeTree(tt) = 
ccall((:EnzymeFreeTypeTree, libEnzyme), Cvoid, (CTypeTreeRef,), tt) -EnzymeSetTypeTree(dst, src) = ccall((:EnzymeSetTypeTree, libEnzyme), UInt8, (CTypeTreeRef, CTypeTreeRef), dst, src) -EnzymeMergeTypeTree(dst, src) = ccall((:EnzymeMergeTypeTree, libEnzyme), UInt8, (CTypeTreeRef, CTypeTreeRef), dst, src) -function EnzymeCheckedMergeTypeTree(dst, src) +EnzymeSetTypeTree(dst, src) = + ccall((:EnzymeSetTypeTree, libEnzyme), UInt8, (CTypeTreeRef, CTypeTreeRef), dst, src) +EnzymeMergeTypeTree(dst, src) = + ccall((:EnzymeMergeTypeTree, libEnzyme), UInt8, (CTypeTreeRef, CTypeTreeRef), dst, src) +function EnzymeCheckedMergeTypeTree(dst, src) legal = Ref{UInt8}(0) - res = ccall((:EnzymeCheckedMergeTypeTree, libEnzyme), UInt8, (CTypeTreeRef, CTypeTreeRef, Ptr{UInt8}), dst, src, legal) + res = ccall( + (:EnzymeCheckedMergeTypeTree, libEnzyme), + UInt8, + (CTypeTreeRef, CTypeTreeRef, Ptr{UInt8}), + dst, + src, + legal, + ) return res != 0, legal[] != 0 end -EnzymeTypeTreeOnlyEq(dst, x) = ccall((:EnzymeTypeTreeOnlyEq, libEnzyme), Cvoid, (CTypeTreeRef, Int64), dst, x) -EnzymeTypeTreeLookupEq(dst, x, dl) = ccall((:EnzymeTypeTreeLookupEq, libEnzyme), Cvoid, (CTypeTreeRef, Int64, Cstring), dst, x, dl) -EnzymeTypeTreeCanonicalizeInPlace(dst, x, dl) = ccall((:EnzymeTypeTreeCanonicalizeInPlace, libEnzyme), Cvoid, (CTypeTreeRef, Int64, Cstring), dst, x, dl) -EnzymeTypeTreeData0Eq(dst) = ccall((:EnzymeTypeTreeData0Eq, libEnzyme), Cvoid, (CTypeTreeRef,), dst) -EnzymeTypeTreeInner0(dst) = ccall((:EnzymeTypeTreeInner0, libEnzyme), CConcreteType, (CTypeTreeRef,), dst) -EnzymeTypeTreeShiftIndiciesEq(dst, dl, offset, maxSize, addOffset) = - ccall((:EnzymeTypeTreeShiftIndiciesEq, libEnzyme), Cvoid, (CTypeTreeRef, Cstring, Int64, Int64, UInt64), - dst, dl, offset, maxSize, addOffset) +EnzymeTypeTreeOnlyEq(dst, x) = + ccall((:EnzymeTypeTreeOnlyEq, libEnzyme), Cvoid, (CTypeTreeRef, Int64), dst, x) +EnzymeTypeTreeLookupEq(dst, x, dl) = ccall( + (:EnzymeTypeTreeLookupEq, libEnzyme), + Cvoid, + 
(CTypeTreeRef, Int64, Cstring), + dst, + x, + dl, +) +EnzymeTypeTreeCanonicalizeInPlace(dst, x, dl) = ccall( + (:EnzymeTypeTreeCanonicalizeInPlace, libEnzyme), + Cvoid, + (CTypeTreeRef, Int64, Cstring), + dst, + x, + dl, +) +EnzymeTypeTreeData0Eq(dst) = + ccall((:EnzymeTypeTreeData0Eq, libEnzyme), Cvoid, (CTypeTreeRef,), dst) +EnzymeTypeTreeInner0(dst) = + ccall((:EnzymeTypeTreeInner0, libEnzyme), CConcreteType, (CTypeTreeRef,), dst) +EnzymeTypeTreeShiftIndiciesEq(dst, dl, offset, maxSize, addOffset) = ccall( + (:EnzymeTypeTreeShiftIndiciesEq, libEnzyme), + Cvoid, + (CTypeTreeRef, Cstring, Int64, Int64, UInt64), + dst, + dl, + offset, + maxSize, + addOffset, +) -EnzymeTypeTreeToString(tt) = ccall((:EnzymeTypeTreeToString, libEnzyme), Cstring, (CTypeTreeRef,), tt) +EnzymeTypeTreeToString(tt) = + ccall((:EnzymeTypeTreeToString, libEnzyme), Cstring, (CTypeTreeRef,), tt) EnzymeStringFree(str) = ccall((:EnzymeStringFree, libEnzyme), Cvoid, (Cstring,), str) struct CFnTypeInfo @@ -109,35 +155,65 @@ struct CFnTypeInfo known_values::Ptr{IntList} end -SetMD(v::Union{LLVM.Instruction, LLVM.GlobalVariable}, kind::String, node::LLVM.Metadata) = ccall((:EnzymeSetStringMD, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef, Cstring, LLVM.API.LLVMValueRef), v, kind, LLVM.Value(node)) +SetMD(v::Union{LLVM.Instruction,LLVM.GlobalVariable}, kind::String, node::LLVM.Metadata) = + ccall( + (:EnzymeSetStringMD, libEnzyme), + Cvoid, + (LLVM.API.LLVMValueRef, Cstring, LLVM.API.LLVMValueRef), + v, + kind, + LLVM.Value(node), + ) @static if !isdefined(LLVM, :ValueMetadataDict) -Base.haskey(md::LLVM.InstructionMetadataDict, kind::String) = - ccall((:EnzymeGetStringMD, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef, Cstring), md.inst, kind) != C_NULL - -function Base.getindex(md::LLVM.InstructionMetadataDict, kind::String) - objref = ccall((:EnzymeGetStringMD, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef, Cstring), md.inst, kind) != C_NULL - objref == C_NULL && throw(KeyError(kind)) - return 
LLVM.Metadata(LLVM.MetadataAsValue(objref)) - end - -Base.setindex!(md::LLVM.InstructionMetadataDict, node::LLVM.Metadata, kind::String) = - ccall((:EnzymeSetStringMD, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef, Cstring, LLVM.API.LLVMValueRef), md.inst, kind, LLVM.Value(node)) -end + Base.haskey(md::LLVM.InstructionMetadataDict, kind::String) = + ccall( + (:EnzymeGetStringMD, libEnzyme), + Cvoid, + (LLVM.API.LLVMValueRef, Cstring), + md.inst, + kind, + ) != C_NULL + + function Base.getindex(md::LLVM.InstructionMetadataDict, kind::String) + objref = + ccall( + (:EnzymeGetStringMD, libEnzyme), + Cvoid, + (LLVM.API.LLVMValueRef, Cstring), + md.inst, + kind, + ) != C_NULL + objref == C_NULL && throw(KeyError(kind)) + return LLVM.Metadata(LLVM.MetadataAsValue(objref)) + end -@cenum(CDIFFE_TYPE, - DFT_OUT_DIFF = 0, # add differential to an output struct - DFT_DUP_ARG = 1, # duplicate the argument and store differential inside - DFT_CONSTANT = 2, # no differential - DFT_DUP_NONEED = 3 # duplicate this argument and store differential inside, - # but don't need the forward + Base.setindex!(md::LLVM.InstructionMetadataDict, node::LLVM.Metadata, kind::String) = + ccall( + (:EnzymeSetStringMD, libEnzyme), + Cvoid, + (LLVM.API.LLVMValueRef, Cstring, LLVM.API.LLVMValueRef), + md.inst, + kind, + LLVM.Value(node), + ) +end + +@cenum( + CDIFFE_TYPE, + DFT_OUT_DIFF = 0, # add differential to an output struct + DFT_DUP_ARG = 1, # duplicate the argument and store differential inside + DFT_CONSTANT = 2, # no differential + DFT_DUP_NONEED = 3 # duplicate this argument and store differential inside, + # but don't need the forward ) -@cenum(CDerivativeMode, - DEM_ForwardMode = 0, - DEM_ReverseModePrimal = 1, - DEM_ReverseModeGradient = 2, - DEM_ReverseModeCombined = 3 +@cenum( + CDerivativeMode, + DEM_ForwardMode = 0, + DEM_ReverseModePrimal = 1, + DEM_ReverseModeGradient = 2, + DEM_ReverseModeCombined = 3 ) # Create the derivative function itself. 
@@ -155,31 +231,133 @@ end # pass # \p AtomicAdd is whether to perform all adjoint updates to memory in an atomic way # \p PostOpt is whether to perform basic optimization of the function after synthesis -function EnzymeCreatePrimalAndGradient(logic, todiff, retType, constant_args, TA, - returnValue, dretUsed, mode, runtimeActivity, width, additionalArg, - forceAnonymousTape, typeInfo, - uncacheable_args, augmented, atomicAdd) +function EnzymeCreatePrimalAndGradient( + logic, + todiff, + retType, + constant_args, + TA, + returnValue, + dretUsed, + mode, + runtimeActivity, + width, + additionalArg, + forceAnonymousTape, + typeInfo, + uncacheable_args, + augmented, + atomicAdd, +) freeMemory = true - ccall((:EnzymeCreatePrimalAndGradient, libEnzyme), LLVMValueRef, - (EnzymeLogicRef, LLVMValueRef, LLVM.API.LLVMBuilderRef, LLVMValueRef, CDIFFE_TYPE, Ptr{CDIFFE_TYPE}, Csize_t, - EnzymeTypeAnalysisRef, UInt8, UInt8, CDerivativeMode, UInt8, Cuint, UInt8, LLVMTypeRef, UInt8, CFnTypeInfo, - Ptr{UInt8}, Csize_t, EnzymeAugmentedReturnPtr, UInt8), - logic, C_NULL, C_NULL, todiff, retType, constant_args, length(constant_args), TA, returnValue, - dretUsed, mode, runtimeActivity, width, freeMemory, additionalArg, forceAnonymousTape, typeInfo, uncacheable_args, length(uncacheable_args), - augmented, atomicAdd) -end - -function EnzymeCreateForwardDiff(logic, todiff, retType, constant_args, TA, - returnValue, mode, runtimeActivity, width, additionalArg, typeInfo, - uncacheable_args) + ccall( + (:EnzymeCreatePrimalAndGradient, libEnzyme), + LLVMValueRef, + ( + EnzymeLogicRef, + LLVMValueRef, + LLVM.API.LLVMBuilderRef, + LLVMValueRef, + CDIFFE_TYPE, + Ptr{CDIFFE_TYPE}, + Csize_t, + EnzymeTypeAnalysisRef, + UInt8, + UInt8, + CDerivativeMode, + UInt8, + Cuint, + UInt8, + LLVMTypeRef, + UInt8, + CFnTypeInfo, + Ptr{UInt8}, + Csize_t, + EnzymeAugmentedReturnPtr, + UInt8, + ), + logic, + C_NULL, + C_NULL, + todiff, + retType, + constant_args, + length(constant_args), + TA, + returnValue, + 
dretUsed, + mode, + runtimeActivity, + width, + freeMemory, + additionalArg, + forceAnonymousTape, + typeInfo, + uncacheable_args, + length(uncacheable_args), + augmented, + atomicAdd, + ) +end + +function EnzymeCreateForwardDiff( + logic, + todiff, + retType, + constant_args, + TA, + returnValue, + mode, + runtimeActivity, + width, + additionalArg, + typeInfo, + uncacheable_args, +) freeMemory = true aug = C_NULL - ccall((:EnzymeCreateForwardDiff, libEnzyme), LLVMValueRef, - (EnzymeLogicRef, LLVMValueRef, LLVM.API.LLVMBuilderRef, LLVMValueRef, CDIFFE_TYPE, Ptr{CDIFFE_TYPE}, Csize_t, - EnzymeTypeAnalysisRef, UInt8, CDerivativeMode, UInt8, UInt8, Cuint, LLVMTypeRef, CFnTypeInfo, - Ptr{UInt8}, Csize_t, EnzymeAugmentedReturnPtr), - logic, C_NULL, C_NULL, todiff, retType, constant_args, length(constant_args), TA, returnValue, - mode, freeMemory, runtimeActivity, width, additionalArg, typeInfo, uncacheable_args, length(uncacheable_args), aug) + ccall( + (:EnzymeCreateForwardDiff, libEnzyme), + LLVMValueRef, + ( + EnzymeLogicRef, + LLVMValueRef, + LLVM.API.LLVMBuilderRef, + LLVMValueRef, + CDIFFE_TYPE, + Ptr{CDIFFE_TYPE}, + Csize_t, + EnzymeTypeAnalysisRef, + UInt8, + CDerivativeMode, + UInt8, + UInt8, + Cuint, + LLVMTypeRef, + CFnTypeInfo, + Ptr{UInt8}, + Csize_t, + EnzymeAugmentedReturnPtr, + ), + logic, + C_NULL, + C_NULL, + todiff, + retType, + constant_args, + length(constant_args), + TA, + returnValue, + mode, + freeMemory, + runtimeActivity, + width, + additionalArg, + typeInfo, + uncacheable_args, + length(uncacheable_args), + aug, + ) end # Create an augmented forward pass. 
@@ -193,16 +371,61 @@ end # \p forceAnonymousTape forces the tape to be an i8* rather than the true tape structure # \p AtomicAdd is whether to perform all adjoint updates to memory in an atomic way # \p PostOpt is whether to perform basic optimization of the function after synthesis -function EnzymeCreateAugmentedPrimal(logic, todiff, retType, constant_args, TA, returnUsed, - shadowReturnUsed, - typeInfo, uncacheable_args, forceAnonymousTape, runtimeActivity, width, atomicAdd) - ccall((:EnzymeCreateAugmentedPrimal, libEnzyme), EnzymeAugmentedReturnPtr, - (EnzymeLogicRef, LLVMValueRef, LLVM.API.LLVMBuilderRef, LLVMValueRef, CDIFFE_TYPE, Ptr{CDIFFE_TYPE}, Csize_t, - EnzymeTypeAnalysisRef, UInt8, UInt8, - CFnTypeInfo, Ptr{UInt8}, Csize_t, UInt8, UInt8, Cuint, UInt8), - logic, C_NULL, C_NULL, todiff, retType, constant_args, length(constant_args), TA, returnUsed, +function EnzymeCreateAugmentedPrimal( + logic, + todiff, + retType, + constant_args, + TA, + returnUsed, + shadowReturnUsed, + typeInfo, + uncacheable_args, + forceAnonymousTape, + runtimeActivity, + width, + atomicAdd, +) + ccall( + (:EnzymeCreateAugmentedPrimal, libEnzyme), + EnzymeAugmentedReturnPtr, + ( + EnzymeLogicRef, + LLVMValueRef, + LLVM.API.LLVMBuilderRef, + LLVMValueRef, + CDIFFE_TYPE, + Ptr{CDIFFE_TYPE}, + Csize_t, + EnzymeTypeAnalysisRef, + UInt8, + UInt8, + CFnTypeInfo, + Ptr{UInt8}, + Csize_t, + UInt8, + UInt8, + Cuint, + UInt8, + ), + logic, + C_NULL, + C_NULL, + todiff, + retType, + constant_args, + length(constant_args), + TA, + returnUsed, shadowReturnUsed, - typeInfo, uncacheable_args, length(uncacheable_args), forceAnonymousTape, runtimeActivity, width, atomicAdd) + typeInfo, + uncacheable_args, + length(uncacheable_args), + forceAnonymousTape, + runtimeActivity, + width, + atomicAdd, + ) end # typedef uint8_t (*CustomRuleType)(int /*direction*/, CTypeTreeRef /*return*/, @@ -213,7 +436,15 @@ const CustomRuleType = Ptr{Cvoid} function CreateTypeAnalysis(logic, rulenames, rules) @assert 
length(rulenames) == length(rules) - ccall((:CreateTypeAnalysis, libEnzyme), EnzymeTypeAnalysisRef, (EnzymeLogicRef, Ptr{Cstring}, Ptr{CustomRuleType}, Csize_t), logic, rulenames, rules, length(rules)) + ccall( + (:CreateTypeAnalysis, libEnzyme), + EnzymeTypeAnalysisRef, + (EnzymeLogicRef, Ptr{Cstring}, Ptr{CustomRuleType}, Csize_t), + logic, + rulenames, + rules, + length(rules), + ) end function ClearTypeAnalysis(ta) @@ -225,67 +456,422 @@ function FreeTypeAnalysis(ta) end function EnzymeAnalyzeTypes(ta, CTI, F) - ccall((:EnzymeAnalyzeTypes, libEnzyme), EnzymeTypeAnalyzerRef, (EnzymeTypeAnalysisRef, CFnTypeInfo, LLVMValueRef), ta, CTI, F) + ccall( + (:EnzymeAnalyzeTypes, libEnzyme), + EnzymeTypeAnalyzerRef, + (EnzymeTypeAnalysisRef, CFnTypeInfo, LLVMValueRef), + ta, + CTI, + F, + ) end - + const CustomShadowAlloc = Ptr{Cvoid} const CustomShadowFree = Ptr{Cvoid} -EnzymeRegisterAllocationHandler(name, ahandle, fhandle) = ccall((:EnzymeRegisterAllocationHandler, libEnzyme), Cvoid, (Cstring, CustomShadowAlloc, CustomShadowFree), name, ahandle, fhandle) +EnzymeRegisterAllocationHandler(name, ahandle, fhandle) = ccall( + (:EnzymeRegisterAllocationHandler, libEnzyme), + Cvoid, + (Cstring, CustomShadowAlloc, CustomShadowFree), + name, + ahandle, + fhandle, +) const CustomAugmentedForwardPass = Ptr{Cvoid} const CustomForwardPass = Ptr{Cvoid} const CustomReversePass = Ptr{Cvoid} -EnzymeRegisterCallHandler(name, fwdhandle, revhandle) = ccall((:EnzymeRegisterCallHandler, libEnzyme), Cvoid, (Cstring, CustomAugmentedForwardPass, CustomReversePass), name, fwdhandle, revhandle) -EnzymeRegisterFwdCallHandler(name, fwdhandle) = ccall((:EnzymeRegisterFwdCallHandler, libEnzyme), Cvoid, (Cstring, CustomForwardPass), name, fwdhandle) +EnzymeRegisterCallHandler(name, fwdhandle, revhandle) = ccall( + (:EnzymeRegisterCallHandler, libEnzyme), + Cvoid, + (Cstring, CustomAugmentedForwardPass, CustomReversePass), + name, + fwdhandle, + revhandle, +) +EnzymeRegisterFwdCallHandler(name, 
fwdhandle) = ccall( + (:EnzymeRegisterFwdCallHandler, libEnzyme), + Cvoid, + (Cstring, CustomForwardPass), + name, + fwdhandle, +) -EnzymeInsertValue(B::LLVM.IRBuilder, v::LLVM.Value, v2::LLVM.Value, insts::Vector{Cuint}, name="") = LLVM.Value(ccall((:EnzymeInsertValue, libEnzyme), LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVMValueRef, LLVMValueRef, Ptr{Cuint}, Int64, Cstring), B, v, v2, insts, length(insts), name)) +EnzymeInsertValue( + B::LLVM.IRBuilder, + v::LLVM.Value, + v2::LLVM.Value, + insts::Vector{Cuint}, + name = "", +) = LLVM.Value( + ccall( + (:EnzymeInsertValue, libEnzyme), + LLVMValueRef, + (LLVM.API.LLVMBuilderRef, LLVMValueRef, LLVMValueRef, Ptr{Cuint}, Int64, Cstring), + B, + v, + v2, + insts, + length(insts), + name, + ), +) const CustomDiffUse = Ptr{Cvoid} -EnzymeRegisterDiffUseCallHandler(name, handle) = ccall((:EnzymeRegisterDiffUseCallHandler, libEnzyme), Cvoid, (Cstring, CustomDiffUse), name, handle) -EnzymeSetCalledFunction(ci::LLVM.CallInst, fn::LLVM.Function, toremove) = ccall((:EnzymeSetCalledFunction, libEnzyme), Cvoid, (LLVMValueRef, LLVMValueRef, Ptr{Int64}, Int64), ci, fn, toremove, length(toremove)) -EnzymeCloneFunctionWithoutReturnOrArgs(fn::LLVM.Function, keepret, args) = ccall((:EnzymeCloneFunctionWithoutReturnOrArgs, libEnzyme), LLVMValueRef, (LLVMValueRef,UInt8,Ptr{Int64}, Int64), fn, keepret, args, length(args)) -EnzymeGetShadowType(width, T) = ccall((:EnzymeGetShadowType, libEnzyme), LLVMTypeRef, (UInt64,LLVMTypeRef), width, T) - -EnzymeGradientUtilsReplaceAWithB(gutils, a, b) = ccall((:EnzymeGradientUtilsReplaceAWithB, libEnzyme), Cvoid, (EnzymeGradientUtilsRef,LLVMValueRef, LLVMValueRef), gutils, a, b) -EnzymeGradientUtilsErase(gutils, a) = ccall((:EnzymeGradientUtilsErase, libEnzyme), Cvoid, (EnzymeGradientUtilsRef,LLVMValueRef), gutils, a) -EnzymeGradientUtilsEraseWithPlaceholder(gutils, a, orig, erase) = ccall((:EnzymeGradientUtilsEraseWithPlaceholder, libEnzyme), Cvoid, (EnzymeGradientUtilsRef,LLVMValueRef, LLVMValueRef, 
UInt8), gutils, a, orig, erase) -EnzymeGradientUtilsGetMode(gutils) = ccall((:EnzymeGradientUtilsGetMode, libEnzyme), CDerivativeMode, (EnzymeGradientUtilsRef,), gutils) -EnzymeGradientUtilsGetWidth(gutils) = ccall((:EnzymeGradientUtilsGetWidth, libEnzyme), UInt64, (EnzymeGradientUtilsRef,), gutils) -EnzymeGradientUtilsGetRuntimeActivity(gutils) = ccall((:EnzymeGradientUtilsGetRuntimeActivity, libEnzyme), UInt8, (EnzymeGradientUtilsRef,), gutils) != 0 -EnzymeGradientUtilsNewFromOriginal(gutils, val) = ccall((:EnzymeGradientUtilsNewFromOriginal, libEnzyme), LLVMValueRef, (EnzymeGradientUtilsRef, LLVMValueRef), gutils, val) -EnzymeGradientUtilsSetDebugLocFromOriginal(gutils, val, orig) = ccall((:EnzymeGradientUtilsSetDebugLocFromOriginal, libEnzyme), Cvoid, (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef), gutils, val, orig) -EnzymeGradientUtilsLookup(gutils, val, B) = ccall((:EnzymeGradientUtilsLookup, libEnzyme), LLVMValueRef, (EnzymeGradientUtilsRef, LLVMValueRef, LLVM.API.LLVMBuilderRef), gutils, val, B) -EnzymeGradientUtilsInvertPointer(gutils, val, B) = ccall((:EnzymeGradientUtilsInvertPointer, libEnzyme), LLVMValueRef, (EnzymeGradientUtilsRef, LLVMValueRef, LLVM.API.LLVMBuilderRef), gutils, val, B) -EnzymeGradientUtilsDiffe(gutils, val, B) = ccall((:EnzymeGradientUtilsDiffe, libEnzyme), LLVMValueRef, (EnzymeGradientUtilsRef, LLVMValueRef, LLVM.API.LLVMBuilderRef), gutils, val, B) -EnzymeGradientUtilsAddToDiffe(gutils, val, diffe, B, T) = ccall((:EnzymeGradientUtilsAddToDiffe, libEnzyme), Cvoid, (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef, LLVM.API.LLVMBuilderRef, LLVMTypeRef), gutils, val, diffe, B, T) -function EnzymeGradientUtilsAddToInvertedPointerDiffeTT(gutils, orig, origVal, vd, size, origptr, prediff, B, align, premask) - ccall((:EnzymeGradientUtilsAddToInvertedPointerDiffeTT, libEnzyme), Cvoid, (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef, CTypeTreeRef, Cuint, LLVMValueRef, LLVMValueRef, LLVM.API.LLVMBuilderRef, Cuint, LLVMValueRef), 
gutils, orig, origVal, vd, size, origptr, prediff, B, align, premask) -end - -EnzymeGradientUtilsSetDiffe(gutils, val, diffe, B) = ccall((:EnzymeGradientUtilsSetDiffe, libEnzyme), Cvoid, (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef, LLVM.API.LLVMBuilderRef), gutils, val, diffe, B) -EnzymeGradientUtilsIsConstantValue(gutils, val) = ccall((:EnzymeGradientUtilsIsConstantValue, libEnzyme), UInt8, (EnzymeGradientUtilsRef, LLVMValueRef), gutils, val) -EnzymeGradientUtilsIsConstantInstruction(gutils, val) = ccall((:EnzymeGradientUtilsIsConstantInstruction, libEnzyme), UInt8, (EnzymeGradientUtilsRef, LLVMValueRef), gutils, val) -EnzymeGradientUtilsAllocationBlock(gutils) = ccall((:EnzymeGradientUtilsAllocationBlock, libEnzyme), LLVM.API.LLVMBasicBlockRef, (EnzymeGradientUtilsRef,), gutils) - -EnzymeGradientUtilsTypeAnalyzer(gutils) = ccall((:EnzymeGradientUtilsTypeAnalyzer, libEnzyme), EnzymeTypeAnalyzerRef, (EnzymeGradientUtilsRef,), gutils) - -EnzymeGradientUtilsAllocAndGetTypeTree(gutils, val) = ccall((:EnzymeGradientUtilsAllocAndGetTypeTree, libEnzyme), CTypeTreeRef, (EnzymeGradientUtilsRef,LLVMValueRef), gutils, val) - -EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, size) = ccall((:EnzymeGradientUtilsGetUncacheableArgs, libEnzyme), UInt8, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, UInt64), gutils, orig, uncacheable, size) - -EnzymeGradientUtilsGetDiffeType(gutils, op, isforeign) = ccall((:EnzymeGradientUtilsGetDiffeType, libEnzyme), CDIFFE_TYPE, (EnzymeGradientUtilsRef,LLVMValueRef, UInt8), gutils, op, isforeign) - -EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode) = ccall((:EnzymeGradientUtilsGetReturnDiffeType, libEnzyme), CDIFFE_TYPE, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, Ptr{UInt8}, CDerivativeMode), gutils, orig, needsPrimalP, needsShadowP, mode) - -EnzymeGradientUtilsSubTransferHelper(gutils, mode, secretty, intrinsic, dstAlign, srcAlign, offset, dstConstant, origdst, srcConstant, 
origsrc, length, isVolatile, MTI, allowForward, shadowsLookedUp) = ccall((:EnzymeGradientUtilsSubTransferHelper, libEnzyme), - Cvoid, - ( EnzymeGradientUtilsRef, CDerivativeMode, LLVMTypeRef, UInt64, UInt64, UInt64, UInt64, UInt8, LLVMValueRef, UInt8, LLVMValueRef, LLVMValueRef, LLVMValueRef, LLVMValueRef, UInt8, UInt8), - gutils, mode, secretty, intrinsic, dstAlign, srcAlign, offset, dstConstant, origdst, srcConstant, origsrc, length, isVolatile, MTI, allowForward, shadowsLookedUp) - -EnzymeGradientUtilsCallWithInvertedBundles(gutils, func, funcTy, argvs, argc, orig, valTys, valCnt, B, lookup) = ccall((:EnzymeGradientUtilsCallWithInvertedBundles, libEnzyme), LLVMValueRef, (EnzymeGradientUtilsRef,LLVMValueRef, LLVMTypeRef, Ptr{LLVMValueRef}, UInt64, LLVMValueRef, Ptr{CValueType}, UInt64, LLVM.API.LLVMBuilderRef, UInt8), gutils, func, funcTy, argvs, argc, orig, valTys, valCnt, B, lookup) - -function sub_transfer(gutils, mode, secretty, intrinsic, dstAlign, srcAlign, offset, dstConstant, origdst, srcConstant, origsrc, length, isVolatile, MTI, allowForward, shadowsLookedUp) +EnzymeRegisterDiffUseCallHandler(name, handle) = ccall( + (:EnzymeRegisterDiffUseCallHandler, libEnzyme), + Cvoid, + (Cstring, CustomDiffUse), + name, + handle, +) +EnzymeSetCalledFunction(ci::LLVM.CallInst, fn::LLVM.Function, toremove) = ccall( + (:EnzymeSetCalledFunction, libEnzyme), + Cvoid, + (LLVMValueRef, LLVMValueRef, Ptr{Int64}, Int64), + ci, + fn, + toremove, + length(toremove), +) +EnzymeCloneFunctionWithoutReturnOrArgs(fn::LLVM.Function, keepret, args) = ccall( + (:EnzymeCloneFunctionWithoutReturnOrArgs, libEnzyme), + LLVMValueRef, + (LLVMValueRef, UInt8, Ptr{Int64}, Int64), + fn, + keepret, + args, + length(args), +) +EnzymeGetShadowType(width, T) = + ccall((:EnzymeGetShadowType, libEnzyme), LLVMTypeRef, (UInt64, LLVMTypeRef), width, T) + +EnzymeGradientUtilsReplaceAWithB(gutils, a, b) = ccall( + (:EnzymeGradientUtilsReplaceAWithB, libEnzyme), + Cvoid, + (EnzymeGradientUtilsRef, 
LLVMValueRef, LLVMValueRef), + gutils, + a, + b, +) +EnzymeGradientUtilsErase(gutils, a) = ccall( + (:EnzymeGradientUtilsErase, libEnzyme), + Cvoid, + (EnzymeGradientUtilsRef, LLVMValueRef), + gutils, + a, +) +EnzymeGradientUtilsEraseWithPlaceholder(gutils, a, orig, erase) = ccall( + (:EnzymeGradientUtilsEraseWithPlaceholder, libEnzyme), + Cvoid, + (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef, UInt8), + gutils, + a, + orig, + erase, +) +EnzymeGradientUtilsGetMode(gutils) = ccall( + (:EnzymeGradientUtilsGetMode, libEnzyme), + CDerivativeMode, + (EnzymeGradientUtilsRef,), + gutils, +) +EnzymeGradientUtilsGetWidth(gutils) = ccall( + (:EnzymeGradientUtilsGetWidth, libEnzyme), + UInt64, + (EnzymeGradientUtilsRef,), + gutils, +) +EnzymeGradientUtilsGetRuntimeActivity(gutils) = + ccall( + (:EnzymeGradientUtilsGetRuntimeActivity, libEnzyme), + UInt8, + (EnzymeGradientUtilsRef,), + gutils, + ) != 0 +EnzymeGradientUtilsNewFromOriginal(gutils, val) = ccall( + (:EnzymeGradientUtilsNewFromOriginal, libEnzyme), + LLVMValueRef, + (EnzymeGradientUtilsRef, LLVMValueRef), + gutils, + val, +) +EnzymeGradientUtilsSetDebugLocFromOriginal(gutils, val, orig) = ccall( + (:EnzymeGradientUtilsSetDebugLocFromOriginal, libEnzyme), + Cvoid, + (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef), + gutils, + val, + orig, +) +EnzymeGradientUtilsLookup(gutils, val, B) = ccall( + (:EnzymeGradientUtilsLookup, libEnzyme), + LLVMValueRef, + (EnzymeGradientUtilsRef, LLVMValueRef, LLVM.API.LLVMBuilderRef), + gutils, + val, + B, +) +EnzymeGradientUtilsInvertPointer(gutils, val, B) = ccall( + (:EnzymeGradientUtilsInvertPointer, libEnzyme), + LLVMValueRef, + (EnzymeGradientUtilsRef, LLVMValueRef, LLVM.API.LLVMBuilderRef), + gutils, + val, + B, +) +EnzymeGradientUtilsDiffe(gutils, val, B) = ccall( + (:EnzymeGradientUtilsDiffe, libEnzyme), + LLVMValueRef, + (EnzymeGradientUtilsRef, LLVMValueRef, LLVM.API.LLVMBuilderRef), + gutils, + val, + B, +) +EnzymeGradientUtilsAddToDiffe(gutils, val, diffe, B, 
T) = ccall( + (:EnzymeGradientUtilsAddToDiffe, libEnzyme), + Cvoid, + ( + EnzymeGradientUtilsRef, + LLVMValueRef, + LLVMValueRef, + LLVM.API.LLVMBuilderRef, + LLVMTypeRef, + ), + gutils, + val, + diffe, + B, + T, +) +function EnzymeGradientUtilsAddToInvertedPointerDiffeTT( + gutils, + orig, + origVal, + vd, + size, + origptr, + prediff, + B, + align, + premask, +) + ccall( + (:EnzymeGradientUtilsAddToInvertedPointerDiffeTT, libEnzyme), + Cvoid, + ( + EnzymeGradientUtilsRef, + LLVMValueRef, + LLVMValueRef, + CTypeTreeRef, + Cuint, + LLVMValueRef, + LLVMValueRef, + LLVM.API.LLVMBuilderRef, + Cuint, + LLVMValueRef, + ), + gutils, + orig, + origVal, + vd, + size, + origptr, + prediff, + B, + align, + premask, + ) +end + +EnzymeGradientUtilsSetDiffe(gutils, val, diffe, B) = ccall( + (:EnzymeGradientUtilsSetDiffe, libEnzyme), + Cvoid, + (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef, LLVM.API.LLVMBuilderRef), + gutils, + val, + diffe, + B, +) +EnzymeGradientUtilsIsConstantValue(gutils, val) = ccall( + (:EnzymeGradientUtilsIsConstantValue, libEnzyme), + UInt8, + (EnzymeGradientUtilsRef, LLVMValueRef), + gutils, + val, +) +EnzymeGradientUtilsIsConstantInstruction(gutils, val) = ccall( + (:EnzymeGradientUtilsIsConstantInstruction, libEnzyme), + UInt8, + (EnzymeGradientUtilsRef, LLVMValueRef), + gutils, + val, +) +EnzymeGradientUtilsAllocationBlock(gutils) = ccall( + (:EnzymeGradientUtilsAllocationBlock, libEnzyme), + LLVM.API.LLVMBasicBlockRef, + (EnzymeGradientUtilsRef,), + gutils, +) + +EnzymeGradientUtilsTypeAnalyzer(gutils) = ccall( + (:EnzymeGradientUtilsTypeAnalyzer, libEnzyme), + EnzymeTypeAnalyzerRef, + (EnzymeGradientUtilsRef,), + gutils, +) + +EnzymeGradientUtilsAllocAndGetTypeTree(gutils, val) = ccall( + (:EnzymeGradientUtilsAllocAndGetTypeTree, libEnzyme), + CTypeTreeRef, + (EnzymeGradientUtilsRef, LLVMValueRef), + gutils, + val, +) + +EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, size) = ccall( + (:EnzymeGradientUtilsGetUncacheableArgs, 
libEnzyme), + UInt8, + (EnzymeGradientUtilsRef, LLVMValueRef, Ptr{UInt8}, UInt64), + gutils, + orig, + uncacheable, + size, +) + +EnzymeGradientUtilsGetDiffeType(gutils, op, isforeign) = ccall( + (:EnzymeGradientUtilsGetDiffeType, libEnzyme), + CDIFFE_TYPE, + (EnzymeGradientUtilsRef, LLVMValueRef, UInt8), + gutils, + op, + isforeign, +) + +EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode) = + ccall( + (:EnzymeGradientUtilsGetReturnDiffeType, libEnzyme), + CDIFFE_TYPE, + (EnzymeGradientUtilsRef, LLVMValueRef, Ptr{UInt8}, Ptr{UInt8}, CDerivativeMode), + gutils, + orig, + needsPrimalP, + needsShadowP, + mode, + ) + +EnzymeGradientUtilsSubTransferHelper( + gutils, + mode, + secretty, + intrinsic, + dstAlign, + srcAlign, + offset, + dstConstant, + origdst, + srcConstant, + origsrc, + length, + isVolatile, + MTI, + allowForward, + shadowsLookedUp, +) = ccall( + (:EnzymeGradientUtilsSubTransferHelper, libEnzyme), + Cvoid, + ( + EnzymeGradientUtilsRef, + CDerivativeMode, + LLVMTypeRef, + UInt64, + UInt64, + UInt64, + UInt64, + UInt8, + LLVMValueRef, + UInt8, + LLVMValueRef, + LLVMValueRef, + LLVMValueRef, + LLVMValueRef, + UInt8, + UInt8, + ), + gutils, + mode, + secretty, + intrinsic, + dstAlign, + srcAlign, + offset, + dstConstant, + origdst, + srcConstant, + origsrc, + length, + isVolatile, + MTI, + allowForward, + shadowsLookedUp, +) + +EnzymeGradientUtilsCallWithInvertedBundles( + gutils, + func, + funcTy, + argvs, + argc, + orig, + valTys, + valCnt, + B, + lookup, +) = ccall( + (:EnzymeGradientUtilsCallWithInvertedBundles, libEnzyme), + LLVMValueRef, + ( + EnzymeGradientUtilsRef, + LLVMValueRef, + LLVMTypeRef, + Ptr{LLVMValueRef}, + UInt64, + LLVMValueRef, + Ptr{CValueType}, + UInt64, + LLVM.API.LLVMBuilderRef, + UInt8, + ), + gutils, + func, + funcTy, + argvs, + argc, + orig, + valTys, + valCnt, + B, + lookup, +) + +function sub_transfer( + gutils, + mode, + secretty, + intrinsic, + dstAlign, + srcAlign, + offset, + dstConstant, + 
origdst, + srcConstant, + origsrc, + length, + isVolatile, + MTI, + allowForward, + shadowsLookedUp, +) GC.@preserve secretty begin if secretty === nothing secretty = Base.unsafe_convert(LLVMTypeRef, C_NULL) @@ -293,15 +879,37 @@ function sub_transfer(gutils, mode, secretty, intrinsic, dstAlign, srcAlign, off secretty = Base.unsafe_convert(LLVMTypeRef, secretty) end - EnzymeGradientUtilsSubTransferHelper(gutils, mode, secretty, intrinsic, dstAlign, srcAlign, offset, dstConstant, origdst, srcConstant, origsrc, length, isVolatile, MTI, allowForward, shadowsLookedUp) + EnzymeGradientUtilsSubTransferHelper( + gutils, + mode, + secretty, + intrinsic, + dstAlign, + srcAlign, + offset, + dstConstant, + origdst, + srcConstant, + origsrc, + length, + isVolatile, + MTI, + allowForward, + shadowsLookedUp, + ) end end -function CreateLogic(postOpt=false) +function CreateLogic(postOpt = false) ccall((:CreateEnzymeLogic, libEnzyme), EnzymeLogicRef, (UInt8,), postOpt) end -EnzymeLogicErasePreprocessedFunctions(logic) = ccall((:EnzymeLogicErasePreprocessedFunctions, libEnzyme), Cvoid, (EnzymeLogicRef,), logic) +EnzymeLogicErasePreprocessedFunctions(logic) = ccall( + (:EnzymeLogicErasePreprocessedFunctions, libEnzyme), + Cvoid, + (EnzymeLogicRef,), + logic, +) function ClearLogic(logic) ccall((:ClearEnzymeLogic, libEnzyme), Cvoid, (EnzymeLogicRef,), logic) @@ -313,22 +921,43 @@ end function EnzymeExtractReturnInfo(ret, data, existed) @assert length(data) == length(existed) - ccall((:EnzymeExtractReturnInfo, libEnzyme), - Cvoid, (EnzymeAugmentedReturnPtr, Ptr{Int64}, Ptr{UInt8}, Csize_t), - ret, data, existed, length(data)) + ccall( + (:EnzymeExtractReturnInfo, libEnzyme), + Cvoid, + (EnzymeAugmentedReturnPtr, Ptr{Int64}, Ptr{UInt8}, Csize_t), + ret, + data, + existed, + length(data), + ) end function EnzymeExtractFunctionFromAugmentation(ret) - ccall((:EnzymeExtractFunctionFromAugmentation, libEnzyme), LLVMValueRef, (EnzymeAugmentedReturnPtr,), ret) + ccall( + 
(:EnzymeExtractFunctionFromAugmentation, libEnzyme), + LLVMValueRef, + (EnzymeAugmentedReturnPtr,), + ret, + ) end function EnzymeExtractTapeTypeFromAugmentation(ret) - ccall((:EnzymeExtractTapeTypeFromAugmentation, libEnzyme), LLVMTypeRef, (EnzymeAugmentedReturnPtr,), ret) + ccall( + (:EnzymeExtractTapeTypeFromAugmentation, libEnzyme), + LLVMTypeRef, + (EnzymeAugmentedReturnPtr,), + ret, + ) end function EnzymeExtractUnderlyingTapeTypeFromAugmentation(ret) - ccall((:EnzymeExtractUnderlyingTapeTypeFromAugmentation, libEnzyme), LLVMTypeRef, (EnzymeAugmentedReturnPtr,), ret) + ccall( + (:EnzymeExtractUnderlyingTapeTypeFromAugmentation, libEnzyme), + LLVMTypeRef, + (EnzymeAugmentedReturnPtr,), + ret, + ) end import Libdl @@ -598,28 +1227,44 @@ function EnzymeRemoveTrivialAtomicIncrements(func) end function EnzymeAddAttributorLegacyPass(PM) - ccall((:EnzymeAddAttributorLegacyPass, libEnzyme),Cvoid,(LLVM.API.LLVMPassManagerRef,), PM) -end - -@cenum(ErrorType, - ET_NoDerivative = 0, - ET_NoShadow = 1, - ET_IllegalTypeAnalysis = 2, - ET_NoType = 3, - ET_IllegalFirstPointer = 4, - ET_InternalError = 5, - ET_TypeDepthExceeded = 6, - ET_MixedActivityError = 7, - ET_IllegalReplaceFicticiousPHIs = 8, - ET_GetIndexError = 9 + ccall( + (:EnzymeAddAttributorLegacyPass, libEnzyme), + Cvoid, + (LLVM.API.LLVMPassManagerRef,), + PM, + ) +end + +@cenum( + ErrorType, + ET_NoDerivative = 0, + ET_NoShadow = 1, + ET_IllegalTypeAnalysis = 2, + ET_NoType = 3, + ET_IllegalFirstPointer = 4, + ET_InternalError = 5, + ET_TypeDepthExceeded = 6, + ET_MixedActivityError = 7, + ET_IllegalReplaceFicticiousPHIs = 8, + ET_GetIndexError = 9 ) function EnzymeTypeAnalyzerToString(typeanalyzer) - ccall((:EnzymeTypeAnalyzerToString, libEnzyme), Cstring, (EnzymeTypeAnalyzerRef,), typeanalyzer) + ccall( + (:EnzymeTypeAnalyzerToString, libEnzyme), + Cstring, + (EnzymeTypeAnalyzerRef,), + typeanalyzer, + ) end function EnzymeGradientUtilsInvertedPointersToString(gutils) - 
ccall((:EnzymeGradientUtilsInvertedPointersToString, libEnzyme), Cstring, (Ptr{Cvoid},), gutils) + ccall( + (:EnzymeGradientUtilsInvertedPointersToString, libEnzyme), + Cstring, + (Ptr{Cvoid},), + gutils, + ) end function EnzymeSetHandler(handler) @@ -694,60 +1339,162 @@ function __init__() end function moveBefore(i1, i2, BR) - ccall((:EnzymeMoveBefore, libEnzyme),Cvoid,(LLVM.API.LLVMValueRef,LLVM.API.LLVMValueRef, LLVM.API.LLVMBuilderRef), i1, i2, BR) + ccall( + (:EnzymeMoveBefore, libEnzyme), + Cvoid, + (LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMBuilderRef), + i1, + i2, + BR, + ) end function EnzymeCloneFunctionDISubprogramInto(i1, i2) - ccall((:EnzymeCloneFunctionDISubprogramInto, libEnzyme),Cvoid,(LLVM.API.LLVMValueRef,LLVM.API.LLVMValueRef), i1, i2) + ccall( + (:EnzymeCloneFunctionDISubprogramInto, libEnzyme), + Cvoid, + (LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef), + i1, + i2, + ) end function EnzymeCopyMetadata(i1, i2) - ccall((:EnzymeCopyMetadata, libEnzyme),Cvoid,(LLVM.API.LLVMValueRef,LLVM.API.LLVMValueRef), i1, i2) + ccall( + (:EnzymeCopyMetadata, libEnzyme), + Cvoid, + (LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef), + i1, + i2, + ) end function SetMustCache!(i1) - ccall((:EnzymeSetMustCache, libEnzyme),Cvoid,(LLVM.API.LLVMValueRef,), i1) + ccall((:EnzymeSetMustCache, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef,), i1) end function SetForMemSet!(i1) - ccall((:EnzymeSetForMemSet, libEnzyme),Cvoid,(LLVM.API.LLVMValueRef,), i1) + ccall((:EnzymeSetForMemSet, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef,), i1) end function HasFromStack(i1) - ccall((:EnzymeHasFromStack, libEnzyme),UInt8,(LLVM.API.LLVMValueRef,), i1) != 0 + ccall((:EnzymeHasFromStack, libEnzyme), UInt8, (LLVM.API.LLVMValueRef,), i1) != 0 end function AddPreserveNVVMPass!(pm, i8) - ccall((:AddPreserveNVVMPass, libEnzyme),Cvoid,(LLVM.API.LLVMPassManagerRef,UInt8), pm, i8) + ccall( + (:AddPreserveNVVMPass, libEnzyme), + Cvoid, + (LLVM.API.LLVMPassManagerRef, UInt8), + pm, + i8, + ) 
end function EnzymeReplaceFunctionImplementation(mod) - ccall((:EnzymeReplaceFunctionImplementation, libEnzyme),Cvoid,(LLVM.API.LLVMModuleRef,), mod) + ccall( + (:EnzymeReplaceFunctionImplementation, libEnzyme), + Cvoid, + (LLVM.API.LLVMModuleRef,), + mod, + ) end function EnzymeDumpModuleRef(mod) - ccall((:EnzymeDumpModuleRef, libEnzyme),Cvoid,(LLVM.API.LLVMModuleRef,), mod) -end - -EnzymeComputeByteOffsetOfGEP(B, V, T) = LLVM.Value(ccall((:EnzymeComputeByteOffsetOfGEP, libEnzyme), LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMTypeRef), B, V, T)) - -EnzymeAllocaType(al) = LLVM.LLVMType(ccall((:EnzymeAllocaType, libEnzyme), LLVM.API.LLVMTypeRef, (LLVM.API.LLVMValueRef,), al)) - -EnzymeAttributeKnownFunctions(f) = ccall((:EnzymeAttributeKnownFunctions, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef,), f) + ccall((:EnzymeDumpModuleRef, libEnzyme), Cvoid, (LLVM.API.LLVMModuleRef,), mod) +end + +EnzymeComputeByteOffsetOfGEP(B, V, T) = LLVM.Value( + ccall( + (:EnzymeComputeByteOffsetOfGEP, libEnzyme), + LLVM.API.LLVMValueRef, + (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMTypeRef), + B, + V, + T, + ), +) -EnzymeAnonymousAliasScopeDomain(str, ctx) = LLVM.Metadata(ccall((:EnzymeAnonymousAliasScopeDomain, libEnzyme), LLVM.API.LLVMMetadataRef, (Cstring,LLVMContextRef), str, ctx)) -EnzymeAnonymousAliasScope(dom::LLVM.Metadata, str) = LLVM.Metadata(ccall((:EnzymeAnonymousAliasScope, libEnzyme), LLVM.API.LLVMMetadataRef, (LLVM.API.LLVMMetadataRef,Cstring), dom.ref, str)) -EnzymeFixupJuliaCallingConvention(f) = ccall((:EnzymeFixupJuliaCallingConvention, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef,), f) -EnzymeFixupBatchedJuliaCallingConvention(f) = ccall((:EnzymeFixupBatchedJuliaCallingConvention, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef,), f) +EnzymeAllocaType(al) = LLVM.LLVMType( + ccall( + (:EnzymeAllocaType, libEnzyme), + LLVM.API.LLVMTypeRef, + (LLVM.API.LLVMValueRef,), + al, + ), +) -e_extract_value!(builder, AggVal, 
Index, Name::String="") = - GC.@preserve Index begin - LLVM.Value(ccall((:EnzymeBuildExtractValue, libEnzyme), LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Ptr{Cuint}, Cuint, Cstring), builder, AggVal, Index, length(Index), Name)) - end +EnzymeAttributeKnownFunctions(f) = + ccall((:EnzymeAttributeKnownFunctions, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef,), f) + +EnzymeAnonymousAliasScopeDomain(str, ctx) = LLVM.Metadata( + ccall( + (:EnzymeAnonymousAliasScopeDomain, libEnzyme), + LLVM.API.LLVMMetadataRef, + (Cstring, LLVMContextRef), + str, + ctx, + ), +) +EnzymeAnonymousAliasScope(dom::LLVM.Metadata, str) = LLVM.Metadata( + ccall( + (:EnzymeAnonymousAliasScope, libEnzyme), + LLVM.API.LLVMMetadataRef, + (LLVM.API.LLVMMetadataRef, Cstring), + dom.ref, + str, + ), +) +EnzymeFixupJuliaCallingConvention(f) = ccall( + (:EnzymeFixupJuliaCallingConvention, libEnzyme), + Cvoid, + (LLVM.API.LLVMValueRef,), + f, +) +EnzymeFixupBatchedJuliaCallingConvention(f) = ccall( + (:EnzymeFixupBatchedJuliaCallingConvention, libEnzyme), + Cvoid, + (LLVM.API.LLVMValueRef,), + f, +) -e_insert_value!(builder, AggVal, EltVal, Index, Name::String="") = - GC.@preserve Index begin - LLVM.Value(ccall((:EnzymeBuildInsertValue, libEnzyme), LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef, Ptr{Cuint}, Cuint, Cstring), builder, AggVal, EltVal, Index, length(Index), Name)) - end +e_extract_value!(builder, AggVal, Index, Name::String = "") = GC.@preserve Index begin + LLVM.Value( + ccall( + (:EnzymeBuildExtractValue, libEnzyme), + LLVM.API.LLVMValueRef, + (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Ptr{Cuint}, Cuint, Cstring), + builder, + AggVal, + Index, + length(Index), + Name, + ), + ) +end + +e_insert_value!(builder, AggVal, EltVal, Index, Name::String = "") = + GC.@preserve Index begin + LLVM.Value( + ccall( + (:EnzymeBuildInsertValue, libEnzyme), + LLVM.API.LLVMValueRef, + ( + LLVM.API.LLVMBuilderRef, + 
LLVM.API.LLVMValueRef, + LLVM.API.LLVMValueRef, + Ptr{Cuint}, + Cuint, + Cstring, + ), + builder, + AggVal, + EltVal, + Index, + length(Index), + Name, + ), + ) + end end diff --git a/src/compiler.jl b/src/compiler.jl index ce51a6e7f53..f3680cc4c05 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1,12 +1,32 @@ module Compiler import ..Enzyme -import Enzyme: Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, - BatchDuplicatedNoNeed, - BatchDuplicatedFunc, - Annotation, guess_activity, eltype, - API, TypeTree, typetree, TypeTreeTable, only!, shift!, data0!, merge!, to_md, to_fullmd, - TypeAnalysis, FnTypeInfo, Logic, allocatedinline, ismutabletype +import Enzyme: + Const, + Active, + Duplicated, + DuplicatedNoNeed, + BatchDuplicated, + BatchDuplicatedNoNeed, + BatchDuplicatedFunc, + Annotation, + guess_activity, + eltype, + API, + TypeTree, + typetree, + TypeTreeTable, + only!, + shift!, + data0!, + merge!, + to_md, + to_fullmd, + TypeAnalysis, + FnTypeInfo, + Logic, + allocatedinline, + ismutabletype using Enzyme import EnzymeCore @@ -45,12 +65,11 @@ include("gradientutils.jl") # Julia function to LLVM stem and arity const cmplx_known_ops = -Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( - typeof(Base.inv) => (:cmplx_inv, 1, nothing), - typeof(Base.sqrt) => (:cmplx_sqrt, 1, nothing), - ) -const known_ops = -Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( + Dict{DataType,Tuple{Symbol,Int,Union{Nothing,Tuple{Symbol,DataType}}}}( + typeof(Base.inv) => (:cmplx_inv, 1, nothing), + typeof(Base.sqrt) => (:cmplx_sqrt, 1, nothing), + ) +const known_ops = Dict{DataType,Tuple{Symbol,Int,Union{Nothing,Tuple{Symbol,DataType}}}}( typeof(Base.cbrt) => (:cbrt, 1, nothing), typeof(Base.rem2pi) => (:jl_rem2pi, 2, nothing), typeof(Base.sqrt) => (:sqrt, 1, nothing), @@ -85,7 +104,7 @@ Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( typeof(Base.tanh) => (:tanh, 1, nothing), 
typeof(Base.ldexp) => (:ldexp, 2, nothing), typeof(Base.FastMath.tanh_fast) => (:tanh, 1, nothing), - typeof(Base.fma_emulated) => (:fma, 3, nothing) + typeof(Base.fma_emulated) => (:fma, 3, nothing), ) @inline function find_math_method(@nospecialize(func), sparam_vals) if func ∈ keys(known_ops) @@ -118,7 +137,7 @@ Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( return name, toinject, T end end - end + end if func ∈ keys(cmplx_known_ops) name, arity, toinject = cmplx_known_ops[func] @@ -146,67 +165,123 @@ const nofreefns = Set{String}(( "pcre2_jit_stack_assign_8", "pcre2_match_context_create_8", "pcre2_jit_stack_create_8", - "ijl_gc_enable_finalizers_internal", "jl_gc_enable_finalizers_internal", + "ijl_gc_enable_finalizers_internal", + "jl_gc_enable_finalizers_internal", "pcre2_match_data_create_from_pattern_8", - "ijl_gc_run_pending_finalizers", "jl_gc_run_pending_finalizers", - "ijl_typeassert", "jl_typeassert", - "ijl_f_isdefined", "jl_f_isdefined", - "ijl_field_index", "jl_field_index", - "ijl_specializations_get_linfo", "jl_specializations_get_linfo", - "ijl_gf_invoke_lookup_worlds", "jl_gf_invoke_lookup_worlds", - "ijl_gc_get_total_bytes", "jl_gc_get_total_bytes", - "ijl_array_grow_at", "jl_array_grow_at", - "ijl_try_substrtod", "jl_try_substrtod", + "ijl_gc_run_pending_finalizers", + "jl_gc_run_pending_finalizers", + "ijl_typeassert", + "jl_typeassert", + "ijl_f_isdefined", + "jl_f_isdefined", + "ijl_field_index", + "jl_field_index", + "ijl_specializations_get_linfo", + "jl_specializations_get_linfo", + "ijl_gf_invoke_lookup_worlds", + "jl_gf_invoke_lookup_worlds", + "ijl_gc_get_total_bytes", + "jl_gc_get_total_bytes", + "ijl_array_grow_at", + "jl_array_grow_at", + "ijl_try_substrtod", + "jl_try_substrtod", "jl_f__apply_iterate", - "ijl_field_index", "jl_field_index", - "julia.call", "julia.call2", - "ijl_tagged_gensym", "jl_tagged_gensym", - "ijl_array_ptr_copy", "jl_array_ptr_copy", - "ijl_array_copy", "jl_array_copy", - 
"ijl_get_nth_field_checked", "ijl_get_nth_field_checked", - "jl_array_del_end","ijl_array_del_end", - "jl_get_world_counter", "ijl_get_world_counter", - "memhash32_seed", "memhash_seed", - "ijl_module_parent", "jl_module_parent", + "ijl_field_index", + "jl_field_index", + "julia.call", + "julia.call2", + "ijl_tagged_gensym", + "jl_tagged_gensym", + "ijl_array_ptr_copy", + "jl_array_ptr_copy", + "ijl_array_copy", + "jl_array_copy", + "ijl_get_nth_field_checked", + "ijl_get_nth_field_checked", + "jl_array_del_end", + "ijl_array_del_end", + "jl_get_world_counter", + "ijl_get_world_counter", + "memhash32_seed", + "memhash_seed", + "ijl_module_parent", + "jl_module_parent", "julia.safepoint", - "ijl_set_task_tid", "jl_set_task_tid", - "ijl_get_task_tid", "jl_get_task_tid", + "ijl_set_task_tid", + "jl_set_task_tid", + "ijl_get_task_tid", + "jl_get_task_tid", "julia.get_pgcstack_or_new", - "ijl_global_event_loop", "jl_global_event_loop", - "ijl_gf_invoke_lookup", "jl_gf_invoke_lookup", - "ijl_f_typeassert", "jl_f_typeassert", - "ijl_type_unionall", "jl_type_unionall", - "jl_gc_queue_root", "gpu_report_exception", "gpu_signal_exception", - "julia.ptls_states", "julia.write_barrier", "julia.typeof", - "jl_backtrace_from_here", "ijl_backtrace_from_here", - "jl_box_int64", "jl_box_int32", - "ijl_box_int64", "ijl_box_int32", - "jl_box_uint64", "jl_box_uint32", - "ijl_box_uint64", "ijl_box_uint32", - "ijl_box_char", "jl_box_char", + "ijl_global_event_loop", + "jl_global_event_loop", + "ijl_gf_invoke_lookup", + "jl_gf_invoke_lookup", + "ijl_f_typeassert", + "jl_f_typeassert", + "ijl_type_unionall", + "jl_type_unionall", + "jl_gc_queue_root", + "gpu_report_exception", + "gpu_signal_exception", + "julia.ptls_states", + "julia.write_barrier", + "julia.typeof", + "jl_backtrace_from_here", + "ijl_backtrace_from_here", + "jl_box_int64", + "jl_box_int32", + "ijl_box_int64", + "ijl_box_int32", + "jl_box_uint64", + "jl_box_uint32", + "ijl_box_uint64", + "ijl_box_uint32", + 
"ijl_box_char", + "jl_box_char", "ijl_subtype", - "jl_subtype", "julia.get_pgcstack", "jl_in_threaded_region", - "jl_object_id_", "jl_object_id", "ijl_object_id_", "ijl_object_id", + "jl_subtype", + "julia.get_pgcstack", + "jl_in_threaded_region", + "jl_object_id_", + "jl_object_id", + "ijl_object_id_", + "ijl_object_id", "jl_breakpoint", - "llvm.julia.gc_preserve_begin","llvm.julia.gc_preserve_end", + "llvm.julia.gc_preserve_begin", + "llvm.julia.gc_preserve_end", "jl_get_ptls_states", "ijl_get_ptls_states", "jl_f_fieldtype", "jl_symbol_n", - "jl_stored_inline", "ijl_stored_inline", - "jl_f_apply_type", "jl_f_issubtype", - "jl_isa", "ijl_isa", - "jl_matching_methods", "ijl_matching_methods", - "jl_excstack_state", "ijl_excstack_state", - "jl_current_exception", "ijl_current_exception", + "jl_stored_inline", + "ijl_stored_inline", + "jl_f_apply_type", + "jl_f_issubtype", + "jl_isa", + "ijl_isa", + "jl_matching_methods", + "ijl_matching_methods", + "jl_excstack_state", + "ijl_excstack_state", + "jl_current_exception", + "ijl_current_exception", "memhash_seed", - "jl_f__typevar", "ijl_f__typevar", - "jl_f_isa", "ijl_f_isa", - "jl_set_task_threadpoolid", "ijl_set_task_threadpoolid", - "jl_types_equal", "ijl_types_equal", - "jl_invoke", "ijl_invoke", - "jl_apply_generic", "ijl_apply_generic", - "jl_egal__unboxed", "julia.pointer_from_objref", "_platform_memcmp", + "jl_f__typevar", + "ijl_f__typevar", + "jl_f_isa", + "ijl_f_isa", + "jl_set_task_threadpoolid", + "ijl_set_task_threadpoolid", + "jl_types_equal", + "ijl_types_equal", + "jl_invoke", + "ijl_invoke", + "jl_apply_generic", + "ijl_apply_generic", + "jl_egal__unboxed", + "julia.pointer_from_objref", + "_platform_memcmp", "memcmp", "julia.except_enter", "jl_array_grow_end", @@ -233,53 +308,96 @@ const nofreefns = Set{String}(( const inactivefns = Set{String}(( "pcre2_match_data_create_from_pattern_8", - "ijl_typeassert", "jl_typeassert", - "ijl_f_isdefined", "jl_f_isdefined", - "ijl_field_index", "jl_field_index", 
- "ijl_specializations_get_linfo", "jl_specializations_get_linfo", - "ijl_gf_invoke_lookup_worlds", "jl_gf_invoke_lookup_worlds", - "ijl_gc_get_total_bytes", "jl_gc_get_total_bytes", - "ijl_try_substrtod", "jl_try_substrtod", - "ijl_tagged_gensym", "jl_tagged_gensym", - "jl_get_world_counter", "ijl_get_world_counter", - "memhash32_seed", "memhash_seed", - "ijl_module_parent", "jl_module_parent", + "ijl_typeassert", + "jl_typeassert", + "ijl_f_isdefined", + "jl_f_isdefined", + "ijl_field_index", + "jl_field_index", + "ijl_specializations_get_linfo", + "jl_specializations_get_linfo", + "ijl_gf_invoke_lookup_worlds", + "jl_gf_invoke_lookup_worlds", + "ijl_gc_get_total_bytes", + "jl_gc_get_total_bytes", + "ijl_try_substrtod", + "jl_try_substrtod", + "ijl_tagged_gensym", + "jl_tagged_gensym", + "jl_get_world_counter", + "ijl_get_world_counter", + "memhash32_seed", + "memhash_seed", + "ijl_module_parent", + "jl_module_parent", "julia.safepoint", - "ijl_set_task_tid", "jl_set_task_tid", - "ijl_get_task_tid", "jl_get_task_tid", + "ijl_set_task_tid", + "jl_set_task_tid", + "ijl_get_task_tid", + "jl_get_task_tid", "julia.get_pgcstack_or_new", - "ijl_global_event_loop", "jl_global_event_loop", - "ijl_gf_invoke_lookup", "jl_gf_invoke_lookup", - "ijl_f_typeassert", "jl_f_typeassert", - "ijl_type_unionall", "jl_type_unionall", - "jl_gc_queue_root", "gpu_report_exception", "gpu_signal_exception", - "julia.ptls_states", "julia.write_barrier", "julia.typeof", - "jl_backtrace_from_here", "ijl_backtrace_from_here", - "jl_box_int64", "jl_box_int32", - "ijl_box_int64", "ijl_box_int32", - "jl_box_uint64", "jl_box_uint32", - "ijl_box_uint64", "ijl_box_uint32", - "ijl_box_char", "jl_box_char", + "ijl_global_event_loop", + "jl_global_event_loop", + "ijl_gf_invoke_lookup", + "jl_gf_invoke_lookup", + "ijl_f_typeassert", + "jl_f_typeassert", + "ijl_type_unionall", + "jl_type_unionall", + "jl_gc_queue_root", + "gpu_report_exception", + "gpu_signal_exception", + "julia.ptls_states", + 
"julia.write_barrier", + "julia.typeof", + "jl_backtrace_from_here", + "ijl_backtrace_from_here", + "jl_box_int64", + "jl_box_int32", + "ijl_box_int64", + "ijl_box_int32", + "jl_box_uint64", + "jl_box_uint32", + "ijl_box_uint64", + "ijl_box_uint32", + "ijl_box_char", + "jl_box_char", "ijl_subtype", - "jl_subtype", "julia.get_pgcstack", "jl_in_threaded_region", - "jl_object_id_", "jl_object_id", "ijl_object_id_", "ijl_object_id", + "jl_subtype", + "julia.get_pgcstack", + "jl_in_threaded_region", + "jl_object_id_", + "jl_object_id", + "ijl_object_id_", + "ijl_object_id", "jl_breakpoint", - "llvm.julia.gc_preserve_begin","llvm.julia.gc_preserve_end", + "llvm.julia.gc_preserve_begin", + "llvm.julia.gc_preserve_end", "jl_get_ptls_states", "ijl_get_ptls_states", "jl_f_fieldtype", "jl_symbol_n", - "jl_stored_inline", "ijl_stored_inline", - "jl_f_apply_type", "jl_f_issubtype", - "jl_isa", "ijl_isa", - "jl_matching_methods", "ijl_matching_methods", - "jl_excstack_state", "ijl_excstack_state", - "jl_current_exception", "ijl_current_exception", + "jl_stored_inline", + "ijl_stored_inline", + "jl_f_apply_type", + "jl_f_issubtype", + "jl_isa", + "ijl_isa", + "jl_matching_methods", + "ijl_matching_methods", + "jl_excstack_state", + "ijl_excstack_state", + "jl_current_exception", + "ijl_current_exception", "memhash_seed", - "jl_f__typevar", "ijl_f__typevar", - "jl_f_isa", "ijl_f_isa", - "jl_set_task_threadpoolid", "ijl_set_task_threadpoolid", - "jl_types_equal", "ijl_types_equal", + "jl_f__typevar", + "ijl_f__typevar", + "jl_f_isa", + "ijl_f_isa", + "jl_set_task_threadpoolid", + "ijl_set_task_threadpoolid", + "jl_types_equal", + "ijl_types_equal", "jl_string_to_array", "ijl_string_to_array", "jl_alloc_string", @@ -292,13 +410,11 @@ const inactivefns = Set{String}(( "uv_os_homedir", "jl_array_to_string", "ijl_array_to_string", - "pcre2_jit_compile_8" + "pcre2_jit_compile_8", # "jl_" )) -const activefns = Set{String}(( - "jl_", -)) +const activefns = Set{String}(("jl_",)) const 
inactiveglobs = Set{String}(( "ijl_boxed_uint8_cache", @@ -322,7 +438,7 @@ struct Merger{seen,worldT,justActive,UnionSret,AbstractIsMixed} world::worldT end -@inline element(::Val{T}) where T = T +@inline element(::Val{T}) where {T} = T # From https://github.com/JuliaLang/julia/blob/81813164963f38dcd779d65ecd222fad8d7ed437/src/cgutils.cpp#L570 @inline function isghostty(ty) @@ -338,7 +454,9 @@ end return false end -@inline function (c::Merger{seen,worldT,justActive,UnionSret,AbstractIsMixed})(f::Int) where {seen,worldT,justActive,UnionSret,AbstractIsMixed} +@inline function (c::Merger{seen,worldT,justActive,UnionSret,AbstractIsMixed})( + f::Int, +) where {seen,worldT,justActive,UnionSret,AbstractIsMixed} T = element(first(seen)) reftype = ismutabletype(T) || (T isa UnionAll && !AbstractIsMixed) @@ -353,7 +471,14 @@ end return Val(AnyState) end - sub = active_reg_inner(subT, seen, c.world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed)) + sub = active_reg_inner( + subT, + seen, + c.world, + Val(justActive), + Val(UnionSret), + Val(AbstractIsMixed), + ) if sub == AnyState Val(AnyState) @@ -374,9 +499,9 @@ end end end -@inline forcefold(::Val{RT}) where RT = RT +@inline forcefold(::Val{RT}) where {RT} = RT -@inline function forcefold(::Val{ty}, ::Val{sty}, C::Vararg{Any, N}) where {ty, sty, N} +@inline function forcefold(::Val{ty}, ::Val{sty}, C::Vararg{Any,N}) where {ty,sty,N} if sty == AnyState || sty == ty return forcefold(Val(ty), C...) 
end @@ -387,50 +512,92 @@ end end end -@inline ptreltype(::Type{Ptr{T}}) where T = T +@inline ptreltype(::Type{Ptr{T}}) where {T} = T @inline ptreltype(::Type{Core.LLVMPtr{T,N}}) where {T,N} = T @inline ptreltype(::Type{Core.LLVMPtr{T} where N}) where {T} = T -@inline ptreltype(::Type{Base.RefValue{T}}) where T = T +@inline ptreltype(::Type{Base.RefValue{T}}) where {T} = T @inline ptreltype(::Type{Array{T,N}}) where {T,N} = T -@inline ptreltype(::Type{Array{T, N} where N}) where {T} = T -@inline ptreltype(::Type{Complex{T}}) where T = T -@inline ptreltype(::Type{Tuple{Vararg{T}}}) where T = T -@inline ptreltype(::Type{IdDict{K, V}}) where {K, V} = V -@inline ptreltype(::Type{IdDict{K, V} where K}) where {V} = V +@inline ptreltype(::Type{Array{T,N} where N}) where {T} = T +@inline ptreltype(::Type{Complex{T}}) where {T} = T +@inline ptreltype(::Type{Tuple{Vararg{T}}}) where {T} = T +@inline ptreltype(::Type{IdDict{K,V}}) where {K,V} = V +@inline ptreltype(::Type{IdDict{K,V} where K}) where {V} = V @inline is_arrayorvararg_ty(::Type) = false @inline is_arrayorvararg_ty(::Type{Array{T,N}}) where {T,N} = true -@inline is_arrayorvararg_ty(::Type{Array{T, N} where N}) where {T} = true -@inline is_arrayorvararg_ty(::Type{Tuple{Vararg{T2}}}) where T2 = true -@inline is_arrayorvararg_ty(::Type{Ptr{T}}) where T = true +@inline is_arrayorvararg_ty(::Type{Array{T,N} where N}) where {T} = true +@inline is_arrayorvararg_ty(::Type{Tuple{Vararg{T2}}}) where {T2} = true +@inline is_arrayorvararg_ty(::Type{Ptr{T}}) where {T} = true @inline is_arrayorvararg_ty(::Type{Core.LLVMPtr{T,N}}) where {T,N} = true @inline is_arrayorvararg_ty(::Type{Core.LLVMPtr{T,N} where N}) where {T} = true -@inline is_arrayorvararg_ty(::Type{Base.RefValue{T}}) where T = true -@inline is_arrayorvararg_ty(::Type{IdDict{K, V}}) where {K, V} = true -@inline is_arrayorvararg_ty(::Type{IdDict{K, V} where K}) where {V} = true +@inline is_arrayorvararg_ty(::Type{Base.RefValue{T}}) where {T} = true +@inline 
is_arrayorvararg_ty(::Type{IdDict{K,V}}) where {K,V} = true +@inline is_arrayorvararg_ty(::Type{IdDict{K,V} where K}) where {V} = true -@inline function datatype_fieldcount(t::Type{T}) where T +@inline function datatype_fieldcount(t::Type{T}) where {T} return Base.datatype_fieldcount(t) end -@inline function staticInTup(::Val{T}, tup::NTuple{N, Val}) where {T, N} +@inline function staticInTup(::Val{T}, tup::NTuple{N,Val}) where {T,N} any(ntuple(Val(N)) do i Base.@_inline_meta Val(T) == tup[i] end) end -@inline function active_reg_recur(::Type{ST}, seen::Seen, world, ::Val{justActive}, ::Val{UnionSret}, ::Val{AbstractIsMixed}) where {ST, Seen, justActive, UnionSret, AbstractIsMixed} +@inline function active_reg_recur( + ::Type{ST}, + seen::Seen, + world, + ::Val{justActive}, + ::Val{UnionSret}, + ::Val{AbstractIsMixed}, +) where {ST,Seen,justActive,UnionSret,AbstractIsMixed} if ST isa Union - return forcefold(Val(active_reg_recur(ST.a, seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed))), Val(active_reg_recur(ST.b, seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed)))) + return forcefold( + Val( + active_reg_recur( + ST.a, + seen, + world, + Val(justActive), + Val(UnionSret), + Val(AbstractIsMixed), + ), + ), + Val( + active_reg_recur( + ST.b, + seen, + world, + Val(justActive), + Val(UnionSret), + Val(AbstractIsMixed), + ), + ), + ) end - return active_reg_inner(ST, seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed)) + return active_reg_inner( + ST, + seen, + world, + Val(justActive), + Val(UnionSret), + Val(AbstractIsMixed), + ) end @inline is_vararg_tup(x) = false -@inline is_vararg_tup(::Type{Tuple{Vararg{T2}}}) where T2 = true - -@inline function active_reg_inner(::Type{T}, seen::ST, world::Union{Nothing, UInt}, ::Val{justActive}=Val(false), ::Val{UnionSret}=Val(false), ::Val{AbstractIsMixed}=Val(false))::ActivityState where {ST,T, justActive, UnionSret, AbstractIsMixed} +@inline 
is_vararg_tup(::Type{Tuple{Vararg{T2}}}) where {T2} = true + +@inline function active_reg_inner( + ::Type{T}, + seen::ST, + world::Union{Nothing,UInt}, + ::Val{justActive} = Val(false), + ::Val{UnionSret} = Val(false), + ::Val{AbstractIsMixed} = Val(false), +)::ActivityState where {ST,T,justActive,UnionSret,AbstractIsMixed} if T === Any if AbstractIsMixed return MixedState @@ -444,7 +611,14 @@ end end if T <: Complex && !(T isa UnionAll) - return active_reg_inner(ptreltype(T), seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed)) + return active_reg_inner( + ptreltype(T), + seen, + world, + Val(justActive), + Val(UnionSret), + Val(AbstractIsMixed), + ) end if T <: BigFloat @@ -455,12 +629,24 @@ end return ActiveState end - if T <: Ptr || T <: Core.LLVMPtr || T <: Base.RefValue || T <: Array || is_arrayorvararg_ty(T) + if T <: Ptr || + T <: Core.LLVMPtr || + T <: Base.RefValue || + T <: Array || + is_arrayorvararg_ty(T) if justActive return AnyState end - if is_arrayorvararg_ty(T) && active_reg_inner(ptreltype(T), seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed)) == AnyState + if is_arrayorvararg_ty(T) && + active_reg_inner( + ptreltype(T), + seen, + world, + Val(justActive), + Val(UnionSret), + Val(AbstractIsMixed), + ) == AnyState return AnyState else if AbstractIsMixed && is_vararg_tup(T) @@ -482,10 +668,22 @@ end inactivety = if typeof(world) === Nothing EnzymeCore.EnzymeRules.inactive_type(T) else - inmi = GPUCompiler.methodinstance(typeof(EnzymeCore.EnzymeRules.inactive_type), Tuple{Type{T}}, world) - args = Any[EnzymeCore.EnzymeRules.inactive_type, T]; + inmi = GPUCompiler.methodinstance( + typeof(EnzymeCore.EnzymeRules.inactive_type), + Tuple{Type{T}}, + world, + ) + args = Any[EnzymeCore.EnzymeRules.inactive_type, T] GC.@preserve T begin - ccall(:jl_invoke, Any, (Any, Ptr{Any}, Cuint, Any), EnzymeCore.EnzymeRules.inactive_type, args, length(args), inmi) + ccall( + :jl_invoke, + Any, + (Any, Ptr{Any}, Cuint, Any), + 
EnzymeCore.EnzymeRules.inactive_type, + args, + length(args), + inmi, + ) end end @@ -516,19 +714,28 @@ end # if sret union, the data is stored in a stack memory location and is therefore # not unique'd preventing the boxing of the union in the default case if UnionSret && is_sret_union(T) - return active_reg_recur(T, seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed)) + return active_reg_recur( + T, + seen, + world, + Val(justActive), + Val(UnionSret), + Val(AbstractIsMixed), + ) else if justActive return AnyState end - if active_reg_inner(T.a, seen, world, Val(justActive), Val(UnionSret)) != AnyState + if active_reg_inner(T.a, seen, world, Val(justActive), Val(UnionSret)) != + AnyState if AbstractIsMixed return MixedState else return DupState end end - if active_reg_inner(T.b, seen, world, Val(justActive), Val(UnionSret)) != AnyState + if active_reg_inner(T.b, seen, world, Val(justActive), Val(UnionSret)) != + AnyState if AbstractIsMixed return MixedState else @@ -562,17 +769,19 @@ end end nT = if T <: Tuple && T != Tuple && !(T isa UnionAll) - Tuple{(ntuple(length(T.parameters)) do i - Base.@_inline_meta - sT = T.parameters[i] - if sT isa TypeVar - Any - elseif sT isa Core.TypeofVararg - Any - else - sT + Tuple{( + ntuple(length(T.parameters)) do i + Base.@_inline_meta + sT = T.parameters[i] + if sT isa TypeVar + Any + elseif sT isa Core.TypeofVararg + Any + else + sT + end end - end)...} + )...} else T end @@ -583,40 +792,48 @@ end seen2 = (Val(nT), seen...) - fty = Merger{seen2,typeof(world),justActive, UnionSret, AbstractIsMixed}(world) + fty = Merger{seen2,typeof(world),justActive,UnionSret,AbstractIsMixed}(world) ty = forcefold(Val(AnyState), ntuple(fty, Val(fieldcount(nT)))...) 
return ty end -@inline @generated function active_reg_nothrow(::Type{T}, ::Val{world}) where {T, world} +@inline @generated function active_reg_nothrow(::Type{T}, ::Val{world}) where {T,world} return active_reg_inner(T, (), world) end -Base.@pure @inline function active_reg(::Type{T}, world::Union{Nothing, UInt}=nothing)::Bool where {T} +Base.@pure @inline function active_reg( + ::Type{T}, + world::Union{Nothing,UInt} = nothing, +)::Bool where {T} seen = () # check if it could contain an active - if active_reg_inner(T, seen, world, #=justActive=#Val(true)) == ActiveState - state = active_reg_inner(T, seen, world, #=justActive=#Val(false)) + if active_reg_inner(T, seen, world, Val(true)) == ActiveState #=justActive=# + state = active_reg_inner(T, seen, world, Val(false)) #=justActive=# if state == ActiveState return true end @assert state == MixedState - throw(AssertionError(string(T)*" has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information")) + throw( + AssertionError( + string(T) * + " has mixed internal activity types. 
See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information", + ), + ) else return false end end -@inline function guaranteed_const(::Type{T}) where T +@inline function guaranteed_const(::Type{T}) where {T} rt = active_reg_nothrow(T, Val(nothing)) res = rt == AnyState return res end -@inline function guaranteed_const_nongen(::Type{T}, world) where T +@inline function guaranteed_const_nongen(::Type{T}, world) where {T} rt = active_reg_inner(T, (), world) res = rt == AnyState return res @@ -624,12 +841,13 @@ end # check if a value is guaranteed to be not contain active[register] data # (aka not either mixed or active) -@inline function guaranteed_nonactive(::Type{T}) where T +@inline function guaranteed_nonactive(::Type{T}) where {T} rt = Enzyme.Compiler.active_reg_nothrow(T, Val(nothing)) return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState end -@inline Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where {T} = guess_activity(T, convert(API.CDerivativeMode, mode)) +@inline Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where {T} = + guess_activity(T, convert(API.CDerivativeMode, mode)) @inline function Enzyme.guess_activity(::Type{T}, Mode::API.CDerivativeMode) where {T} ActReg = active_reg_inner(T, (), nothing) @@ -650,54 +868,72 @@ end end # User facing interface -abstract type AbstractThunk{FA, RT, TT, Width} end +abstract type AbstractThunk{FA,RT,TT,Width} end -struct CombinedAdjointThunk{PT, FA, RT, TT, Width, ReturnPrimal} <: AbstractThunk{FA, RT, TT, Width} +struct CombinedAdjointThunk{PT,FA,RT,TT,Width,ReturnPrimal} <: AbstractThunk{FA,RT,TT,Width} adjoint::PT end -struct ForwardModeThunk{PT, FA, RT, TT, Width, ReturnPrimal} <: AbstractThunk{FA, RT, TT, Width} +struct ForwardModeThunk{PT,FA,RT,TT,Width,ReturnPrimal} <: AbstractThunk{FA,RT,TT,Width} adjoint::PT end -struct AugmentedForwardThunk{PT, FA, RT, TT, Width, ReturnPrimal, TapeType} <: AbstractThunk{FA, RT, TT, Width} +struct 
AugmentedForwardThunk{PT,FA,RT,TT,Width,ReturnPrimal,TapeType} <: + AbstractThunk{FA,RT,TT,Width} primal::PT end -struct AdjointThunk{PT, FA, RT, TT, Width, TapeType} <: AbstractThunk{FA, RT, TT, Width} +struct AdjointThunk{PT,FA,RT,TT,Width,TapeType} <: AbstractThunk{FA,RT,TT,Width} adjoint::PT end -struct PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal} <: AbstractThunk{FA, RT, TT, Width} +struct PrimalErrorThunk{PT,FA,RT,TT,Width,ReturnPrimal} <: AbstractThunk{FA,RT,TT,Width} adjoint::PT end -@inline return_type(::AbstractThunk{FA, RT}) where {FA, RT} = RT -@inline return_type(::Type{AugmentedForwardThunk{PT, FA, RT, TT, Width, ReturnPrimal, TapeType}}) where {PT, FA, RT, TT, Width, ReturnPrimal, TapeType} = RT - -@inline EnzymeRules.tape_type(::Type{AugmentedForwardThunk{PT, FA, RT, TT, Width, ReturnPrimal, TapeType}}) where {PT, FA, RT, TT, Width, ReturnPrimal, TapeType} = TapeType -@inline EnzymeRules.tape_type(::AugmentedForwardThunk{PT, FA, RT, TT, Width, ReturnPrimal, TapeType}) where {PT, FA, RT, TT, Width, ReturnPrimal, TapeType} = TapeType -@inline EnzymeRules.tape_type(::Type{AdjointThunk{PT, FA, RT, TT, Width, TapeType}}) where {PT, FA, RT, TT, Width, TapeType} = TapeType -@inline EnzymeRules.tape_type(::AdjointThunk{PT, FA, RT, TT, Width, TapeType}) where {PT, FA, RT, TT, Width, TapeType} = TapeType +@inline return_type(::AbstractThunk{FA,RT}) where {FA,RT} = RT +@inline return_type( + ::Type{AugmentedForwardThunk{PT,FA,RT,TT,Width,ReturnPrimal,TapeType}}, +) where {PT,FA,RT,TT,Width,ReturnPrimal,TapeType} = RT + +@inline EnzymeRules.tape_type( + ::Type{AugmentedForwardThunk{PT,FA,RT,TT,Width,ReturnPrimal,TapeType}}, +) where {PT,FA,RT,TT,Width,ReturnPrimal,TapeType} = TapeType +@inline EnzymeRules.tape_type( + ::AugmentedForwardThunk{PT,FA,RT,TT,Width,ReturnPrimal,TapeType}, +) where {PT,FA,RT,TT,Width,ReturnPrimal,TapeType} = TapeType +@inline EnzymeRules.tape_type( + ::Type{AdjointThunk{PT,FA,RT,TT,Width,TapeType}}, +) where 
{PT,FA,RT,TT,Width,TapeType} = TapeType +@inline EnzymeRules.tape_type( + ::AdjointThunk{PT,FA,RT,TT,Width,TapeType}, +) where {PT,FA,RT,TT,Width,TapeType} = TapeType using .JIT -declare_allocobj!(mod) = get_function!(mod, "julia.gc_alloc_obj") do - T_jlvalue = LLVM.StructType(LLVM.LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - T_ppjlvalue = LLVM.PointerType(LLVM.PointerType(T_jlvalue)) - T_size_t = convert(LLVM.LLVMType, Int) +declare_allocobj!(mod) = + get_function!(mod, "julia.gc_alloc_obj") do + T_jlvalue = LLVM.StructType(LLVM.LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + T_ppjlvalue = LLVM.PointerType(LLVM.PointerType(T_jlvalue)) + T_size_t = convert(LLVM.LLVMType, Int) - LLVM.FunctionType(T_prjlvalue, [T_ppjlvalue, T_size_t, T_prjlvalue]) -end -function emit_allocobj!(B, tag::LLVM.Value, Size::LLVM.Value, needs_workaround::Bool, name::String="") + LLVM.FunctionType(T_prjlvalue, [T_ppjlvalue, T_size_t, T_prjlvalue]) + end +function emit_allocobj!( + B, + tag::LLVM.Value, + Size::LLVM.Value, + needs_workaround::Bool, + name::String = "", +) curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) - T_jlvalue = LLVM.StructType(LLVMType[]) + T_jlvalue = LLVM.StructType(LLVMType[]) T_pjlvalue = LLVM.PointerType(T_jlvalue) T_ppjlvalue = LLVM.PointerType(T_pjlvalue) @@ -705,13 +941,13 @@ function emit_allocobj!(B, tag::LLVM.Value, Size::LLVM.Value, needs_workaround:: T_pint8 = LLVM.PointerType(T_int8) pgcstack = reinsert_gcmarker!(fn, B) - ct = inbounds_gep!(B, + ct = inbounds_gep!( + B, T_pjlvalue, bitcast!(B, pgcstack, T_ppjlvalue), - [LLVM.ConstantInt(current_task_offset())]) - ptls_field = inbounds_gep!(B, - T_pjlvalue, - ct, [LLVM.ConstantInt(current_ptls_offset())]) + [LLVM.ConstantInt(current_task_offset())], + ) + ptls_field = inbounds_gep!(B, T_pjlvalue, ct, [LLVM.ConstantInt(current_ptls_offset())]) T_ppint8 = LLVM.PointerType(T_pint8) ptls = load!(B, T_pint8, bitcast!(B, ptls_field, T_ppint8)) @@ 
-732,12 +968,12 @@ function emit_allocobj!(B, tag::LLVM.Value, Size::LLVM.Value, needs_workaround:: return call!(B, alty, alloc_obj, [ct, Size, tag], name) end -function emit_allocobj!(B, T::DataType, name::String="") +function emit_allocobj!(B, T::DataType, name::String = "") curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) - T_jlvalue = LLVM.StructType(LLVMType[]) + T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue_UT = LLVM.PointerType(T_jlvalue) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) @@ -746,14 +982,15 @@ function emit_allocobj!(B, T::DataType, name::String="") T_size_t = convert(LLVM.LLVMType, UInt) Size = LLVM.ConstantInt(T_size_t, sizeof(T)) - emit_allocobj!(B, tag, Size, #=needs_workaround=#false, name) -end -declare_pointerfromobjref!(mod) = get_function!(mod, "julia.pointer_from_objref") do - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Derived) - T_pjlvalue = LLVM.PointerType(T_jlvalue) - LLVM.FunctionType(T_pjlvalue, [T_prjlvalue]) + emit_allocobj!(B, tag, Size, false, name) #=needs_workaround=# end +declare_pointerfromobjref!(mod) = + get_function!(mod, "julia.pointer_from_objref") do + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Derived) + T_pjlvalue = LLVM.PointerType(T_jlvalue) + LLVM.FunctionType(T_pjlvalue, [T_prjlvalue]) + end function emit_pointerfromobjref!(B, T) curent_bb = position(B) fn = LLVM.parent(curent_bb) @@ -762,21 +999,27 @@ function emit_pointerfromobjref!(B, T) return call!(B, fty, func, [T]) end -declare_writebarrier!(mod) = get_function!(mod, "julia.write_barrier") do - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - LLVM.FunctionType(LLVM.VoidType(), [T_prjlvalue]; vararg=true) -end -declare_apply_generic!(mod) = get_function!(mod, "ijl_apply_generic") do - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - 
LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, LLVM.PointerType(T_prjlvalue), LLVM.Int32Type()]) -end -declare_juliacall!(mod) = get_function!(mod, "julia.call") do - T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - LLVM.FunctionType(T_prjlvalue, [T_prjlvalue]; vararg=true) -end +declare_writebarrier!(mod) = + get_function!(mod, "julia.write_barrier") do + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + LLVM.FunctionType(LLVM.VoidType(), [T_prjlvalue]; vararg = true) + end +declare_apply_generic!(mod) = + get_function!(mod, "ijl_apply_generic") do + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + LLVM.FunctionType( + T_prjlvalue, + [T_prjlvalue, LLVM.PointerType(T_prjlvalue), LLVM.Int32Type()], + ) + end +declare_juliacall!(mod) = + get_function!(mod, "julia.call") do + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + LLVM.FunctionType(T_prjlvalue, [T_prjlvalue]; vararg = true) + end function emit_jl!(B::LLVM.IRBuilder, val::LLVM.Value)::LLVM.Value curent_bb = position(B) @@ -804,9 +1047,15 @@ function emit_getfield!(B::LLVM.IRBuilder, val::LLVM.Value, fld::LLVM.Value)::LL args = [val, fld] - julia_call, FT = get_function!(mod, "julia.call", - LLVM.FunctionType(T_prjlvalue, - [LLVM.PointerType(gen_FT), T_prjlvalue]; vararg=true)) + julia_call, FT = get_function!( + mod, + "julia.call", + LLVM.FunctionType( + T_prjlvalue, + [LLVM.PointerType(gen_FT), T_prjlvalue]; + vararg = true, + ), + ) res = call!(B, FT, julia_call, LLVM.Value[inv, args...]) return res end @@ -818,7 +1067,7 @@ function emit_nthfield!(B::LLVM.IRBuilder, val::LLVM.Value, fld::LLVM.Value)::LL mod = LLVM.parent(fn) T_jlvalue = LLVM.StructType(LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) T_size_t = convert(LLVM.LLVMType, Int) gen_FT = 
LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_size_t]) @@ -882,9 +1131,15 @@ function emit_apply_generic!(B::LLVM.IRBuilder, args)::LLVM.Value inv, _ = get_function!(mod, "ijl_apply_generic", gen_FT) # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) @julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) - julia_call, FT = get_function!(mod, "julia.call", - LLVM.FunctionType(T_prjlvalue, - [LLVM.PointerType(gen_FT), T_prjlvalue]; vararg=true)) + julia_call, FT = get_function!( + mod, + "julia.call", + LLVM.FunctionType( + T_prjlvalue, + [LLVM.PointerType(gen_FT), T_prjlvalue]; + vararg = true, + ), + ) res = call!(B, FT, julia_call, LLVM.Value[inv, args...]) return res end @@ -900,13 +1155,20 @@ function emit_invoke!(B::LLVM.IRBuilder, args)::LLVM.Value T_int32 = LLVM.Int32Type() # {} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32, {} addrspace(10)*)* @ijl_invoke - gen_FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_pprjlvalue, T_int32, T_prjlvalue]) + gen_FT = + LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_pprjlvalue, T_int32, T_prjlvalue]) inv = get_function!(mod, "ijl_invoke", gen_FT) # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) 
@julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) - julia_call, FT = get_function!(mod, "julia.call2", - LLVM.FunctionType(T_prjlvalue, - [LLVM.PointerType(generic_FT), T_prjlvalue]; vararg=true)) + julia_call, FT = get_function!( + mod, + "julia.call2", + LLVM.FunctionType( + T_prjlvalue, + [LLVM.PointerType(generic_FT), T_prjlvalue]; + vararg = true, + ), + ) res = call!(B, FT, julia_call, [inv, args...]) return res end @@ -920,13 +1182,13 @@ function emit_svec!(B, args)::LLVM.Value sz = convert(LLVMType, Csize_t) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - LLVM.FunctionType(T_prjlvalue, [sz]; vararg=true) - + LLVM.FunctionType(T_prjlvalue, [sz]; vararg = true) + sz = convert(LLVMType, Csize_t) call!(B, fty, fn, [LLVM.ConstantInt(sz, length(args)), args...]) end -AnyArray(Length::Int) = NamedTuple{ntuple(i->Symbol(i), Val(Length)),NTuple{Length,Any}} +AnyArray(Length::Int) = NamedTuple{ntuple(i -> Symbol(i), Val(Length)),NTuple{Length,Any}} struct EnzymeRuntimeException <: Base.Exception msg::Cstring @@ -953,12 +1215,27 @@ end function Base.showerror(io::IO, ece::EnzymeRuntimeActivityError) println(io, "Constant memory is stored (or returned) to a differentiable variable.") - println(io, "As a result, Enzyme cannot provably ensure correctness and throws this error.") - println(io, "This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).") - println(io, "If Enzyme should be able to prove this use non-differentable, open an issue!"); - println(io, "To work around this issue, either:"); - println(io, " a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or") - println(io, " b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). 
This will maintain correctness, but may slightly reduce performance.") + println( + io, + "As a result, Enzyme cannot provably ensure correctness and throws this error.", + ) + println( + io, + "This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).", + ) + println( + io, + "If Enzyme should be able to prove this use non-differentable, open an issue!", + ) + println(io, "To work around this issue, either:") + println( + io, + " a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or", + ) + println( + io, + " b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.", + ) msg = Base.unsafe_string(ece.msg) print(io, msg, '\n') end @@ -992,7 +1269,7 @@ function Base.showerror(io::IO, ece::EnzymeNoDerivativeError) print(io, msg, '\n') end -const JuliaEnzymeNameMap = Dict{String, Any}( +const JuliaEnzymeNameMap = Dict{String,Any}( "enz_val_true" => Val(true), "enz_val_false" => Val(false), "enz_val_1" => Val(1), @@ -1007,7 +1284,7 @@ const JuliaEnzymeNameMap = Dict{String, Any}( "enz_no_derivative_exc" => EnzymeNoDerivativeError, ) -const JuliaGlobalNameMap = Dict{String, Any}( +const JuliaGlobalNameMap = Dict{String,Any}( "jl_type_type" => Type, "jl_any_type" => Any, "jl_datatype_type" => DataType, @@ -1015,12 +1292,10 @@ const JuliaGlobalNameMap = Dict{String, Any}( "jl_symbol_type" => Symbol, "jl_simplevector_type" => Core.SimpleVector, "jl_nothing_type" => Nothing, - "jl_tvar_type" => TypeVar, "jl_typeofbottom_type" => Core.TypeofBottom, "jl_bottom_type" => Union{}, "jl_unionall_type" => UnionAll, - "jl_uniontype_type" => Union, "jl_emptytuple_type" => Tuple{}, "jl_emptytuple" => (), @@ -1052,46 +1327,31 @@ const JuliaGlobalNameMap = Dict{String, Any}( "jl_ref_type" => Ref, "jl_pointer_typename" => Ptr, 
"jl_voidpointer_type" => Ptr{Nothing}, - "jl_abstractarray_type" => AbstractArray, - "jl_densearray_type" => DenseArray, - "jl_array_type" => Array, - - "jl_array_any_type" => Array{Any, 1}, - - "jl_array_symbol_type" => Array{Symbol, 1}, - - "jl_array_uint8_type" => Array{UInt8, 1}, + "jl_array_any_type" => Array{Any,1}, + "jl_array_symbol_type" => Array{Symbol,1}, + "jl_array_uint8_type" => Array{UInt8,1}, # "jl_array_uint32_type" => Array{UInt32, 1}, - "jl_array_int32_type" => Array{Int32, 1}, - - + "jl_array_int32_type" => Array{Int32,1}, "jl_expr_type" => Expr, - "jl_method_type" => Method, "jl_method_instance_type" => Core.MethodInstance, "jl_code_instance_type" => Core.CodeInstance, "jl_const_type" => Core.Const, "jl_llvmpointer_type" => Core.LLVMPtr, - - "jl_namedtuple_type" => NamedTuple, - "jl_task_type" => Task, - "jl_uint8pointer_type" => Ptr{UInt8}, - "jl_nothing" => nothing, - "jl_anytuple_type" => Tuple, "jl_vararg_type" => Core.TypeofVararg, "jl_opaque_closure_type" => Core.OpaqueClosure, - "jl_array_uint64_type" => Array{UInt64, 1}, - "jl_binding_type" => Core.Binding + "jl_array_uint64_type" => Array{UInt64,1}, + "jl_binding_type" => Core.Binding, ) include("absint.jl") @@ -1104,7 +1364,7 @@ function emit_apply_type!(B::LLVM.IRBuilder, Ty, args)::LLVM.Value legal = true found = [] for arg in args - slegal , foundv = absint(arg) + slegal, foundv = absint(arg) if slegal push!(found, foundv) else @@ -1127,10 +1387,21 @@ function emit_apply_type!(B::LLVM.IRBuilder, Ty, args)::LLVM.Value Ty = unsafe_to_llvm(B, Ty) # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) 
@julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) - julia_call, FT = get_function!(mod, "julia.call", - LLVM.FunctionType(T_prjlvalue, - [LLVM.PointerType(generic_FT), T_prjlvalue]; vararg=true)) - tag = call!(B, FT, julia_call, LLVM.Value[f_apply_type, LLVM.PointerNull(T_prjlvalue), Ty, args...]) + julia_call, FT = get_function!( + mod, + "julia.call", + LLVM.FunctionType( + T_prjlvalue, + [LLVM.PointerType(generic_FT), T_prjlvalue]; + vararg = true, + ), + ) + tag = call!( + B, + FT, + julia_call, + LLVM.Value[f_apply_type, LLVM.PointerNull(T_prjlvalue), Ty, args...], + ) return tag end @@ -1142,7 +1413,7 @@ function emit_tuple!(B, args)::LLVM.Value legal = true found = [] for arg in args - slegal , foundv = absint(arg) + slegal, foundv = absint(arg) if slegal push!(found, foundv) else @@ -1164,10 +1435,21 @@ function emit_tuple!(B, args)::LLVM.Value f_apply_type, _ = get_function!(mod, "jl_f_tuple", generic_FT) # %5 = call nonnull {}* ({}* ({}*, {}**, i32)*, {}*, ...) 
@julia.call({}* ({}*, {}**, i32)* @jl_f_apply_type, {}* null, {}* inttoptr (i64 139640605802128 to {}*), {}* %4, {}* inttoptr (i64 139640590432896 to {}*)) - julia_call, FT = get_function!(mod, "julia.call", - LLVM.FunctionType(T_prjlvalue, - [LLVM.PointerType(generic_FT), T_prjlvalue]; vararg=true)) - tag = call!(B, FT, julia_call, LLVM.Value[f_apply_type, LLVM.PointerNull(T_prjlvalue), args...]) + julia_call, FT = get_function!( + mod, + "julia.call", + LLVM.FunctionType( + T_prjlvalue, + [LLVM.PointerType(generic_FT), T_prjlvalue]; + vararg = true, + ), + ) + tag = call!( + B, + FT, + julia_call, + LLVM.Value[f_apply_type, LLVM.PointerNull(T_prjlvalue), args...], + ) return tag end @@ -1176,14 +1458,14 @@ function emit_jltypeof!(B::LLVM.IRBuilder, arg::LLVM.Value)::LLVM.Value fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) - legal, val = abs_typeof(arg) + legal, val, byref = abs_typeof(arg) if legal return unsafe_to_llvm(B, val) end T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue]; vararg=true) + FT = LLVM.FunctionType(T_prjlvalue, [T_prjlvalue]; vararg = true) fn, _ = get_function!(mod, "jl_typeof", FT) call!(B, FT, fn, [arg]) end @@ -1206,43 +1488,65 @@ function emit_methodinstance!(B::LLVM.IRBuilder, func, args)::LLVM.Value meth = only(methods(func)) tag = emit_apply_type!(B, Tuple, primalvaltys) -# TT = meth.sig -# while TT isa UnionAll -# TT = TT.body -# end -# parms = TT.parameters -# -# tosv = primalvaltys -# if length(parms) > 0 && typeof(parms[end]) == Core.TypeofVararg -# tosv = LLVM.Value[tosv[1:length(parms)-1]..., emit_apply_type!(B, Tuple, tosv[length(parms):end])] -# end -# sv = emit_svec!(B, tosv[2:end]) -# + # TT = meth.sig + # while TT isa UnionAll + # TT = TT.body + # end + # parms = TT.parameters + # + # tosv = primalvaltys + # if length(parms) > 0 && typeof(parms[end]) == Core.TypeofVararg + # tosv = LLVM.Value[tosv[1:length(parms)-1]..., 
emit_apply_type!(B, Tuple, tosv[length(parms):end])] + # end + # sv = emit_svec!(B, tosv[2:end]) + # meth = unsafe_to_llvm(B, meth) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - worlds, FT = get_function!(mod, "jl_gf_invoke_lookup_worlds", - LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue, sizeT, psizeT, psizeT])) + worlds, FT = get_function!( + mod, + "jl_gf_invoke_lookup_worlds", + LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue, sizeT, psizeT, psizeT]), + ) EB = LLVM.IRBuilder() position!(EB, first(LLVM.instructions(LLVM.entry(fn)))) minworld = alloca!(EB, sizeT) maxworld = alloca!(EB, sizeT) store!(B, LLVM.ConstantInt(sizeT, 0), minworld) store!(B, LLVM.ConstantInt(sizeT, -1), maxworld) - methodmatch = call!(B, FT, worlds, LLVM.Value[tag, unsafe_to_llvm(B, nothing), LLVM.ConstantInt(sizeT, world), minworld, maxworld]) + methodmatch = call!( + B, + FT, + worlds, + LLVM.Value[ + tag, + unsafe_to_llvm(B, nothing), + LLVM.ConstantInt(sizeT, world), + minworld, + maxworld, + ], + ) # emit_jl!(B, methodmatch) # emit_jl!(B, emit_jltypeof!(B, methodmatch)) offset = 1 - AT = LLVM.ArrayType(T_prjlvalue, offset+1) + AT = LLVM.ArrayType(T_prjlvalue, offset + 1) methodmatch = addrspacecast!(B, methodmatch, LLVM.PointerType(T_jlvalue, Derived)) methodmatch = bitcast!(B, methodmatch, LLVM.PointerType(AT, Derived)) - gep = LLVM.inbounds_gep!(B, AT, methodmatch, LLVM.Value[LLVM.ConstantInt(0), LLVM.ConstantInt(offset)]) + gep = LLVM.inbounds_gep!( + B, + AT, + methodmatch, + LLVM.Value[LLVM.ConstantInt(0), LLVM.ConstantInt(offset)], + ) sv = LLVM.load!(B, T_prjlvalue, gep) - fn, FT = get_function!(mod, "jl_specializations_get_linfo", - LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue, T_prjlvalue])) + fn, FT = get_function!( + mod, + "jl_specializations_get_linfo", + LLVM.FunctionType(T_prjlvalue, [T_prjlvalue, T_prjlvalue, T_prjlvalue]), + ) mi = call!(B, FT, fn, [meth, tag, sv]) @@ -1259,86 +1563,99 @@ 
end function get_array_struct() -@static if VERSION < v"1.11-" -# JL_EXTENSION typedef struct { -# JL_DATA_TYPE -# void *data; -# #ifdef STORE_ARRAY_LEN (just true new newer versions) -# size_t length; -# #endif -# jl_array_flags_t flags; -# uint16_t elsize; // element size including alignment (dim 1 memory stride) -# uint32_t offset; // for 1-d only. does not need to get big. -# size_t nrows; -# union { -# // 1d -# size_t maxsize; -# // Nd -# size_t ncols; -# }; -# // other dim sizes go here for ndims > 2 -# -# // followed by alignment padding and inline data, or owner pointer -# } jl_array_t; - - i8 = LLVM.IntType(8) - ptrty = LLVM.PointerType(i8, 13) - sizeT = LLVM.IntType(8*sizeof(Csize_t)) - arrayFlags = LLVM.IntType(16) - elsz = LLVM.IntType(16) - off = LLVM.IntType(32) - nrows = LLVM.IntType(8*sizeof(Csize_t)) - - return LLVM.StructType([ptrty, sizeT, arrayFlags, elsz, off, nrows]; packed=true) -else -# JL_EXTENSION typedef struct { -# JL_DATA_TYPE -# size_t length; -# void *ptr; -# // followed by padding and inline data, or owner pointer -# #ifdef _P64 -# // union { -# // jl_value_t *owner; -# // T inl[]; -# // }; -# #else -# // -# // jl_value_t *owner; -# // size_t padding[1]; -# // T inl[]; -# #endif -# } jl_genericmemory_t; -# -# JL_EXTENSION typedef struct { -# JL_DATA_TYPE -# void *ptr_or_offset; -# jl_genericmemory_t *mem; -# } jl_genericmemoryref_t; -# -# JL_EXTENSION typedef struct { -# JL_DATA_TYPE -# jl_genericmemoryref_t ref; -# size_t dimsize[]; // length for 1-D, otherwise length is mem->length -# } jl_array_t; - i8 = LLVM.IntType(8) - ptrty = LLVM.PointerType(i8, 10) - sizeT = LLVM.IntType(8*sizeof(Csize_t)) - return LLVM.StructType([ptrty, sizeT]; packed=true) -end + @static if VERSION < v"1.11-" + # JL_EXTENSION typedef struct { + # JL_DATA_TYPE + # void *data; + # #ifdef STORE_ARRAY_LEN (just true new newer versions) + # size_t length; + # #endif + # jl_array_flags_t flags; + # uint16_t elsize; // element size including alignment (dim 1 
memory stride) + # uint32_t offset; // for 1-d only. does not need to get big. + # size_t nrows; + # union { + # // 1d + # size_t maxsize; + # // Nd + # size_t ncols; + # }; + # // other dim sizes go here for ndims > 2 + # + # // followed by alignment padding and inline data, or owner pointer + # } jl_array_t; + + i8 = LLVM.IntType(8) + ptrty = LLVM.PointerType(i8, 13) + sizeT = LLVM.IntType(8 * sizeof(Csize_t)) + arrayFlags = LLVM.IntType(16) + elsz = LLVM.IntType(16) + off = LLVM.IntType(32) + nrows = LLVM.IntType(8 * sizeof(Csize_t)) + + return LLVM.StructType([ptrty, sizeT, arrayFlags, elsz, off, nrows]; packed = true) + else + # JL_EXTENSION typedef struct { + # JL_DATA_TYPE + # size_t length; + # void *ptr; + # // followed by padding and inline data, or owner pointer + # #ifdef _P64 + # // union { + # // jl_value_t *owner; + # // T inl[]; + # // }; + # #else + # // + # // jl_value_t *owner; + # // size_t padding[1]; + # // T inl[]; + # #endif + # } jl_genericmemory_t; + # + # JL_EXTENSION typedef struct { + # JL_DATA_TYPE + # void *ptr_or_offset; + # jl_genericmemory_t *mem; + # } jl_genericmemoryref_t; + # + # JL_EXTENSION typedef struct { + # JL_DATA_TYPE + # jl_genericmemoryref_t ref; + # size_t dimsize[]; // length for 1-D, otherwise length is mem->length + # } jl_array_t; + i8 = LLVM.IntType(8) + ptrty = LLVM.PointerType(i8, 10) + sizeT = LLVM.IntType(8 * sizeof(Csize_t)) + return LLVM.StructType([ptrty, sizeT]; packed = true) + end end function get_array_data(B, array) i8 = LLVM.IntType(8) ptrty = LLVM.PointerType(i8, 13) - array = LLVM.pointercast!(B, array, LLVM.PointerType(ptrty, LLVM.addrspace(LLVM.value_type(array)))) + array = LLVM.pointercast!( + B, + array, + LLVM.PointerType(ptrty, LLVM.addrspace(LLVM.value_type(array))), + ) return LLVM.load!(B, ptrty, array) end function get_array_elsz(B, array) ST = get_array_struct() elsz = LLVM.IntType(16) - array = LLVM.pointercast!(B, array, LLVM.PointerType(ST, LLVM.addrspace(LLVM.value_type(array)))) - 
v = inbounds_gep!(B, ST, array, LLVM.Value[LLVM.ConstantInt(Int32(0)), LLVM.ConstantInt(Int32(3))]) + array = LLVM.pointercast!( + B, + array, + LLVM.PointerType(ST, LLVM.addrspace(LLVM.value_type(array))), + ) + v = inbounds_gep!( + B, + ST, + array, + LLVM.Value[LLVM.ConstantInt(Int32(0)), LLVM.ConstantInt(Int32(3))], + ) return LLVM.load!(B, elsz, v) end @@ -1351,13 +1668,16 @@ function get_array_len(B, array) end for (fname, num) in ( - ("jl_alloc_array_1d", 1), ("ijl_alloc_array_1d", 1), - ("jl_alloc_array_2d", 2), ("jl_alloc_array_2d", 2), - ("jl_alloc_array_2d", 3), ("jl_alloc_array_2d", 3), - ) + ("jl_alloc_array_1d", 1), + ("ijl_alloc_array_1d", 1), + ("jl_alloc_array_2d", 2), + ("jl_alloc_array_2d", 2), + ("jl_alloc_array_2d", 3), + ("jl_alloc_array_2d", 3), + ) if nm == fname res = operands(array)[2] - for i in 2:num + for i = 2:num res = mul!(B, res, operands(array)[1+i]) end return res @@ -1365,17 +1685,35 @@ function get_array_len(B, array) end end ST = get_array_struct() - array = LLVM.pointercast!(B, array, LLVM.PointerType(ST, LLVM.addrspace(LLVM.value_type(array)))) - v = inbounds_gep!(B, ST, array, LLVM.Value[LLVM.ConstantInt(Int32(0)), LLVM.ConstantInt(Int32(1))]) - sizeT = LLVM.IntType(8*sizeof(Csize_t)) + array = LLVM.pointercast!( + B, + array, + LLVM.PointerType(ST, LLVM.addrspace(LLVM.value_type(array))), + ) + v = inbounds_gep!( + B, + ST, + array, + LLVM.Value[LLVM.ConstantInt(Int32(0)), LLVM.ConstantInt(Int32(1))], + ) + sizeT = LLVM.IntType(8 * sizeof(Csize_t)) return LLVM.load!(B, sizeT, v) end function get_array_nrows(B, array) ST = get_array_struct() - array = LLVM.pointercast!(B, array, LLVM.PointerType(ST, LLVM.addrspace(LLVM.value_type(array)))) - v = inbounds_gep!(B, ST, array, LLVM.Value[LLVM.ConstantInt(Int32(0)), LLVM.ConstantInt(Int32(5))]) - nrows = LLVM.IntType(8*sizeof(Csize_t)) + array = LLVM.pointercast!( + B, + array, + LLVM.PointerType(ST, LLVM.addrspace(LLVM.value_type(array))), + ) + v = inbounds_gep!( + B, + ST, + 
array, + LLVM.Value[LLVM.ConstantInt(Int32(0)), LLVM.ConstantInt(Int32(5))], + ) + nrows = LLVM.IntType(8 * sizeof(Csize_t)) return LLVM.load!(B, nrows, v) end @@ -1391,7 +1729,14 @@ function permit_inlining!(f::LLVM.Function) if isa(inst, LLVM.LoadInst) md = metadata(inst) if haskey(md, LLVM.MD_tbaa) - modified = LLVM.Metadata(ccall((:EnzymeMakeNonConstTBAA, API.libEnzyme), LLVM.API.LLVMMetadataRef, (LLVM.API.LLVMMetadataRef,), md[LLVM.MD_tbaa])) + modified = LLVM.Metadata( + ccall( + (:EnzymeMakeNonConstTBAA, API.libEnzyme), + LLVM.API.LLVMMetadataRef, + (LLVM.API.LLVMMetadataRef,), + md[LLVM.MD_tbaa], + ), + ) setindex!(md, modified, LLVM.MD_tbaa) end if haskey(md, LLVM.MD_invariant_load) @@ -1406,11 +1751,15 @@ struct Tape{TapeTy,ShadowTy,ResT} shadow_return::ShadowTy end -function emit_gc_preserve_begin(B::LLVM.IRBuilder, args=LLVM.Value[]) +function emit_gc_preserve_begin(B::LLVM.IRBuilder, args = LLVM.Value[]) curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) - func, FT = get_function!(mod, "llvm.julia.gc_preserve_begin", LLVM.FunctionType(LLVM.TokenType(), vararg=true)) + func, FT = get_function!( + mod, + "llvm.julia.gc_preserve_begin", + LLVM.FunctionType(LLVM.TokenType(), vararg = true), + ) token = call!(B, FT, func, args) return token @@ -1421,7 +1770,11 @@ function emit_gc_preserve_end(B::LLVM.IRBuilder, token) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) - func, FT = get_function!(mod, "llvm.julia.gc_preserve_end", LLVM.FunctionType(LLVM.VoidType(), [LLVM.TokenType()])) + func, FT = get_function!( + mod, + "llvm.julia.gc_preserve_end", + LLVM.FunctionType(LLVM.VoidType(), [LLVM.TokenType()]), + ) call!(B, FT, func, [token]) return @@ -1440,20 +1793,29 @@ function allocate_sret!(gutils::API.EnzymeGradientUtilsRef, N) allocate_sret!(B, N) end -@inline function EnzymeCore.make_zero(x::FT)::FT where {FT <: AbstractFloat} +@inline function EnzymeCore.make_zero(x::FT)::FT where {FT<:AbstractFloat} return Base.zero(x) end 
-@inline function EnzymeCore.make_zero(x::Complex{FT})::Complex{FT} where {FT <: AbstractFloat} +@inline function EnzymeCore.make_zero(x::Complex{FT})::Complex{FT} where {FT<:AbstractFloat} return Base.zero(x) end -@inline function EnzymeCore.make_zero(x::Array{FT, N})::Array{FT, N} where {FT <: AbstractFloat, N} +@inline function EnzymeCore.make_zero( + x::Array{FT,N}, +)::Array{FT,N} where {FT<:AbstractFloat,N} return Base.zero(x) end -@inline function EnzymeCore.make_zero(x::Array{Complex{FT}, N})::Array{Complex{FT}, N} where {FT <: AbstractFloat, N} +@inline function EnzymeCore.make_zero( + x::Array{Complex{FT},N}, +)::Array{Complex{FT},N} where {FT<:AbstractFloat,N} return Base.zero(x) end -@inline function EnzymeCore.make_zero(::Type{Array{FT, N}}, seen::IdDict, prev::Array{FT, N}, ::Val{copy_if_inactive}=Val(false))::Array{FT, N} where {copy_if_inactive, FT<:AbstractFloat, N} +@inline function EnzymeCore.make_zero( + ::Type{Array{FT,N}}, + seen::IdDict, + prev::Array{FT,N}, + ::Val{copy_if_inactive} = Val(false), +)::Array{FT,N} where {copy_if_inactive,FT<:AbstractFloat,N} if haskey(seen, prev) return seen[prev] end @@ -1461,7 +1823,12 @@ end seen[prev] = newa return newa end -@inline function EnzymeCore.make_zero(::Type{Array{Complex{FT}, N}}, seen::IdDict, prev::Array{Complex{FT}, N}, ::Val{copy_if_inactive}=Val(false))::Array{Complex{FT}, N} where {copy_if_inactive, FT<:AbstractFloat, N} +@inline function EnzymeCore.make_zero( + ::Type{Array{Complex{FT},N}}, + seen::IdDict, + prev::Array{Complex{FT},N}, + ::Val{copy_if_inactive} = Val(false), +)::Array{Complex{FT},N} where {copy_if_inactive,FT<:AbstractFloat,N} if haskey(seen, prev) return seen[prev] end @@ -1470,15 +1837,30 @@ end return newa end -@inline function EnzymeCore.make_zero(::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive, RT<:AbstractFloat} +@inline function EnzymeCore.make_zero( + ::Type{RT}, + seen::IdDict, + prev::RT, + 
::Val{copy_if_inactive} = Val(false), +)::RT where {copy_if_inactive,RT<:AbstractFloat} return RT(0) end -@inline function EnzymeCore.make_zero(::Type{Complex{RT}}, seen::IdDict, prev::Complex{RT}, ::Val{copy_if_inactive}=Val(false))::Complex{RT} where {copy_if_inactive, RT<:AbstractFloat} +@inline function EnzymeCore.make_zero( + ::Type{Complex{RT}}, + seen::IdDict, + prev::Complex{RT}, + ::Val{copy_if_inactive} = Val(false), +)::Complex{RT} where {copy_if_inactive,RT<:AbstractFloat} return RT(0) end -@inline function EnzymeCore.make_zero(::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive, RT<:Array} +@inline function EnzymeCore.make_zero( + ::Type{RT}, + seen::IdDict, + prev::RT, + ::Val{copy_if_inactive} = Val(false), +)::RT where {copy_if_inactive,RT<:Array} if haskey(seen, prev) return seen[prev] end @@ -1491,35 +1873,58 @@ end if isassigned(prev, I) pv = prev[I] innerty = Core.Typeof(pv) - @inbounds newa[I] = EnzymeCore.make_zero(innerty, seen, pv, Val(copy_if_inactive)) + @inbounds newa[I] = + EnzymeCore.make_zero(innerty, seen, pv, Val(copy_if_inactive)) end end return newa end -@inline function EnzymeCore.make_zero(::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive, RT<:Tuple} +@inline function EnzymeCore.make_zero( + ::Type{RT}, + seen::IdDict, + prev::RT, + ::Val{copy_if_inactive} = Val(false), +)::RT where {copy_if_inactive,RT<:Tuple} return ntuple(length(prev)) do i Base.@_inline_meta EnzymeCore.make_zero(RT.parameters[i], seen, prev[i], Val(copy_if_inactive)) end end -@inline function EnzymeCore.make_zero(::Type{NamedTuple{A,RT}}, seen::IdDict, prev::NamedTuple{A,RT}, ::Val{copy_if_inactive}=Val(false))::NamedTuple{A,RT} where {copy_if_inactive, A,RT} +@inline function EnzymeCore.make_zero( + ::Type{NamedTuple{A,RT}}, + seen::IdDict, + prev::NamedTuple{A,RT}, + ::Val{copy_if_inactive} = Val(false), +)::NamedTuple{A,RT} where {copy_if_inactive,A,RT} return 
NamedTuple{A,RT}(EnzymeCore.make_zero(RT, seen, RT(prev), Val(copy_if_inactive))) end -@inline function EnzymeCore.make_zero(::Type{Core.Box}, seen::IdDict, prev::Core.Box, ::Val{copy_if_inactive}=Val(false)) where {copy_if_inactive} +@inline function EnzymeCore.make_zero( + ::Type{Core.Box}, + seen::IdDict, + prev::Core.Box, + ::Val{copy_if_inactive} = Val(false), +) where {copy_if_inactive} if haskey(seen, prev) return seen[prev] end prev2 = prev.contents res = Core.Box() seen[prev] = res - res.contents = Base.Ref(EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive))) + res.contents = Base.Ref( + EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive)), + ) return res end -@inline function EnzymeCore.make_zero(::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive, RT} +@inline function EnzymeCore.make_zero( + ::Type{RT}, + seen::IdDict, + prev::RT, + ::Val{copy_if_inactive} = Val(false), +)::RT where {copy_if_inactive,RT} if guaranteed_const_nongen(RT, nothing) return copy_if_inactive ? 
Base.deepcopy_internal(prev, seen) : prev end @@ -1529,11 +1934,11 @@ end @assert !Base.isabstracttype(RT) @assert Base.isconcretetype(RT) nf = fieldcount(RT) - + if ismutable(prev) y = ccall(:jl_new_struct_uninit, Any, (Any,), RT) seen[prev] = y - for i in 1:nf + for i = 1:nf if isdefined(prev, i) xi = getfield(prev, i) T = Core.Typeof(xi) @@ -1543,13 +1948,13 @@ end end return y end - + if nf == 0 return prev end flds = Vector{Any}(undef, nf) - for i in 1:nf + for i = 1:nf if isdefined(prev, i) xi = getfield(prev, i) xi = EnzymeCore.make_zero(Core.Typeof(xi), seen, xi, Val(copy_if_inactive)) @@ -1564,32 +1969,33 @@ end return y end -function make_zero_immutable!(prev::T, seen::S)::T where {T <: AbstractFloat, S} +function make_zero_immutable!(prev::T, seen::S)::T where {T<:AbstractFloat,S} zero(T) end -function make_zero_immutable!(prev::Complex{T}, seen::S)::Complex{T} where {T <: AbstractFloat, S} +function make_zero_immutable!( + prev::Complex{T}, + seen::S, +)::Complex{T} where {T<:AbstractFloat,S} zero(T) end -function make_zero_immutable!(prev::T, seen::S)::T where {T <: Tuple, S} +function make_zero_immutable!(prev::T, seen::S)::T where {T<:Tuple,S} ntuple(Val(length(T.parameters))) do i Base.@_inline_meta make_zero_immutable!(prev[i], seen) end end -function make_zero_immutable!(prev::NamedTuple{a, b}, seen::S)::NamedTuple{a, b} where {a,b, S} - NamedTuple{a, b}( - ntuple(Val(length(T.parameters))) do i +function make_zero_immutable!(prev::NamedTuple{a,b}, seen::S)::NamedTuple{a,b} where {a,b,S} + NamedTuple{a,b}(ntuple(Val(length(T.parameters))) do i Base.@_inline_meta make_zero_immutable!(prev[a[i]], seen) - end - ) + end) end -function make_zero_immutable!(prev::T, seen::S)::T where {T, S} +function make_zero_immutable!(prev::T, seen::S)::T where {T,S} if guaranteed_const_nongen(T, nothing) return prev end @@ -1601,11 +2007,11 @@ function make_zero_immutable!(prev::T, seen::S)::T where {T, S} nf = fieldcount(RT) flds = Vector{Any}(undef, nf) - for i in 
1:nf + for i = 1:nf if isdefined(prev, i) xi = getfield(prev, i) ST = Core.Typeof(xi) - flds[i] = if active_reg_inner(ST, (), nothing, #=justActive=#Val(true)) == ActiveState + flds[i] = if active_reg_inner(ST, (), nothing, Val(true)) == ActiveState #=justActive=# make_zero_immutable!(xi, seen) else EnzymeCore.make_zero!(xi, seen) @@ -1619,47 +2025,65 @@ function make_zero_immutable!(prev::T, seen::S)::T where {T, S} ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf)::T end -@inline function EnzymeCore.make_zero!(prev::Base.RefValue{T}, seen::ST)::Nothing where {T <: AbstractFloat, ST} +@inline function EnzymeCore.make_zero!( + prev::Base.RefValue{T}, + seen::ST, +)::Nothing where {T<:AbstractFloat,ST} T[] = zero(T) nothing end -@inline function EnzymeCore.make_zero!(prev::Base.RefValue{Complex{T}}, seen::ST)::Nothing where {T <: AbstractFloat, ST} +@inline function EnzymeCore.make_zero!( + prev::Base.RefValue{Complex{T}}, + seen::ST, +)::Nothing where {T<:AbstractFloat,ST} T[] = zero(Complex{T}) nothing end -@inline function EnzymeCore.make_zero!(prev::Array{T, N}, seen::ST)::Nothing where {T <: AbstractFloat, N, ST} +@inline function EnzymeCore.make_zero!( + prev::Array{T,N}, + seen::ST, +)::Nothing where {T<:AbstractFloat,N,ST} fill!(prev, zero(T)) nothing end -@inline function EnzymeCore.make_zero!(prev::Array{Complex{T}, N}, seen::ST)::Nothing where {T <: AbstractFloat, N, ST} +@inline function EnzymeCore.make_zero!( + prev::Array{Complex{T},N}, + seen::ST, +)::Nothing where {T<:AbstractFloat,N,ST} fill!(prev, zero(Complex{T})) nothing end -@inline function EnzymeCore.make_zero!(prev::Base.RefValue{T})::Nothing where {T <: AbstractFloat} +@inline function EnzymeCore.make_zero!( + prev::Base.RefValue{T}, +)::Nothing where {T<:AbstractFloat} EnzymeCore.make_zero!(prev, nothing) nothing end -@inline function EnzymeCore.make_zero!(prev::Base.RefValue{Complex{T}})::Nothing where {T <: AbstractFloat} +@inline function EnzymeCore.make_zero!( + 
prev::Base.RefValue{Complex{T}}, +)::Nothing where {T<:AbstractFloat} EnzymeCore.make_zero!(prev, nothing) nothing end -@inline function EnzymeCore.make_zero!(prev::Array{T, N})::Nothing where {T <: AbstractFloat, N} +@inline function EnzymeCore.make_zero!(prev::Array{T,N})::Nothing where {T<:AbstractFloat,N} EnzymeCore.make_zero!(prev, nothing) nothing end -@inline function EnzymeCore.make_zero!(prev::Array{Complex{T}, N})::Nothing where {T <: AbstractFloat, N} +@inline function EnzymeCore.make_zero!( + prev::Array{Complex{T},N}, +)::Nothing where {T<:AbstractFloat,N} EnzymeCore.make_zero!(prev, nothing) nothing end -@inline function EnzymeCore.make_zero!(prev::Array{T, N}, seen::ST)::Nothing where {T, N, ST} +@inline function EnzymeCore.make_zero!(prev::Array{T,N}, seen::ST)::Nothing where {T,N,ST} if guaranteed_const_nongen(T, nothing) return end @@ -1672,7 +2096,7 @@ end if isassigned(prev, I) pv = prev[I] SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState + if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# @inbounds prev[I] = make_zero_immutable!(pv, seen) nothing else @@ -1684,7 +2108,10 @@ end nothing end -@inline function EnzymeCore.make_zero!(prev::Base.RefValue{T}, seen::ST)::Nothing where {T, ST} +@inline function EnzymeCore.make_zero!( + prev::Base.RefValue{T}, + seen::ST, +)::Nothing where {T,ST} if guaranteed_const_nongen(T, nothing) return end @@ -1695,7 +2122,7 @@ end pv = prev[] SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState + if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# prev[] = make_zero_immutable!(pv, seen) nothing else @@ -1716,7 +2143,7 @@ end end push!(seen, prev) SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState + if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# prev.contents = 
EnzymeCore.make_zero_immutable!(pv, seen) nothing else @@ -1726,7 +2153,10 @@ end nothing end -@inline function EnzymeCore.make_zero!(prev::T, seen::S=Base.IdSet{Any}())::Nothing where {T, S} +@inline function EnzymeCore.make_zero!( + prev::T, + seen::S = Base.IdSet{Any}(), +)::Nothing where {T,S} if guaranteed_const_nongen(T, nothing) return end @@ -1736,7 +2166,7 @@ end @assert !Base.isabstracttype(T) @assert Base.isconcretetype(T) nf = fieldcount(T) - + if nf == 0 return @@ -1744,14 +2174,14 @@ end push!(seen, prev) - for i in 1:nf + for i = 1:nf if isdefined(prev, i) xi = getfield(prev, i) SBT = Core.Typeof(xi) if guaranteed_const_nongen(SBT, nothing) continue end - if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState + if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# setfield!(prev, i, make_zero_immutable!(xi, seen)) nothing else @@ -1763,7 +2193,7 @@ end return end -function emit_error(B::LLVM.IRBuilder, orig, string, errty=EnzymeRuntimeException) +function emit_error(B::LLVM.IRBuilder, orig, string, errty = EnzymeRuntimeException) curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -1777,17 +2207,22 @@ function emit_error(B::LLVM.IRBuilder, orig, string, errty=EnzymeRuntimeExceptio vt = LLVM.VoidType() ptr = convert(LLVMType, Ptr{Cvoid}) - exc, _ = get_function!(mod, "gpu_report_exception", LLVM.FunctionType(vt, [ptr])) + exc, _ = + get_function!(mod, "gpu_report_exception", LLVM.FunctionType(vt, [ptr])) string = ptrtoint!(B, string, ptr) call!(B, LLVM.function_type(exc), exc, [string]) - framefn, ft = get_function!(mod, "gpu_report_exception_frame", LLVM.FunctionType(vt, [LLVM.Int32Type(), ptr, ptr, LLVM.Int32Type()])) + framefn, ft = get_function!( + mod, + "gpu_report_exception_frame", + LLVM.FunctionType(vt, [LLVM.Int32Type(), ptr, ptr, LLVM.Int32Type()]), + ) if orig !== nothing bt = GPUCompiler.backtrace(orig) - for (i,frame) in enumerate(bt) + for (i, frame) in 
enumerate(bt) idx = ConstantInt(parameters(ft)[1], i) func = globalstring_ptr!(B, String(frame.func), "di_func") func = ptrtoint!(B, func, ptr) @@ -1797,27 +2232,42 @@ function emit_error(B::LLVM.IRBuilder, orig, string, errty=EnzymeRuntimeExceptio call!(B, ft, framefn, [idx, func, file, line]) end end - - sigfn, sigft = get_function!(mod, "gpu_signal_exception", LLVM.FunctionType(vt, LLVM.LLVMType[])) - call!(B, sigft, sigfn) - trap_ft = LLVM.FunctionType(LLVM.VoidType()) - trap = if haskey(functions(mod), "llvm.trap") - functions(mod)["llvm.trap"] - else - LLVM.Function(mod, "llvm.trap", trap_ft) - end - call!(B, trap_ft, trap) + + sigfn, sigft = get_function!( + mod, + "gpu_signal_exception", + LLVM.FunctionType(vt, LLVM.LLVMType[]), + ) + call!(B, sigft, sigfn) + trap_ft = LLVM.FunctionType(LLVM.VoidType()) + trap = if haskey(functions(mod), "llvm.trap") + functions(mod)["llvm.trap"] + else + LLVM.Function(mod, "llvm.trap", trap_ft) + end + call!(B, trap_ft, trap) else err = emit_allocobj!(B, errty) err2 = bitcast!(B, err, LLVM.PointerType(LLVM.PointerType(LLVM.Int8Type()), 10)) store!(B, string, err2) - emit_jl_throw!(B, addrspacecast!(B, err, LLVM.PointerType(LLVM.StructType(LLVMType[]), 12))) + emit_jl_throw!( + B, + addrspacecast!(B, err, LLVM.PointerType(LLVM.StructType(LLVMType[]), 12)), + ) end # 2. 
Call error function and insert unreachable - LLVM.API.LLVMAddCallSiteAttribute(ct, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), EnumAttribute("noreturn")) + LLVM.API.LLVMAddCallSiteAttribute( + ct, + reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), + EnumAttribute("noreturn"), + ) if EnzymeMutabilityException != errty - LLVM.API.LLVMAddCallSiteAttribute(ct, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), StringAttribute("enzyme_error")) + LLVM.API.LLVMAddCallSiteAttribute( + ct, + reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), + StringAttribute("enzyme_error"), + ) end return ct end @@ -1839,15 +2289,21 @@ function prepare_llvm(mod, job, meta) continue end llvmfn = functions(mod)[k_name] - + RT = Core.Compiler.typeinf_ext_toplevel(interp, mi).rettype _, _, returnRoots = get_return_info(RT) returnRoots = returnRoots !== nothing attributes = function_attributes(llvmfn) - push!(attributes, StringAttribute("enzymejl_mi", string(convert(UInt, pointer_from_objref(mi))))) - push!(attributes, StringAttribute("enzymejl_rt", string(convert(UInt, unsafe_to_pointer(RT))))) + push!( + attributes, + StringAttribute("enzymejl_mi", string(convert(UInt, pointer_from_objref(mi)))), + ) + push!( + attributes, + StringAttribute("enzymejl_rt", string(convert(UInt, unsafe_to_pointer(RT)))), + ) if returnRoots attr = StringAttribute("enzymejl_returnRoots", "") push!(parameter_attributes(llvmfn, 2), attr) @@ -1860,10 +2316,15 @@ function prepare_llvm(mod, job, meta) end end -function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, funcspec::Core.MethodInstance, world) +function nested_codegen!( + mode::API.CDerivativeMode, + mod::LLVM.Module, + funcspec::Core.MethodInstance, + world, +) # TODO: Put a cache here index on `mod` and f->tt - + # 3) Use the MI to create the correct augmented fwd/reverse # TODO: # - GPU support @@ -1871,16 +2332,23 @@ 
function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, funcspec:: target = DefaultCompilerTarget() params = PrimalCompilerParams(mode) - job = CompilerJob(funcspec, CompilerConfig(target, params; kernel=false), world) + job = CompilerJob(funcspec, CompilerConfig(target, params; kernel = false), world) # TODO parent_job = nothing - otherMod, meta = GPUCompiler.codegen(:llvm, job; optimize=false, cleanup=false, validate=false, parent_job=parent_job) + otherMod, meta = GPUCompiler.codegen( + :llvm, + job; + optimize = false, + cleanup = false, + validate = false, + parent_job = parent_job, + ) prepare_llvm(otherMod, job, meta) entry = name(meta.entry) - + for f in functions(otherMod) permit_inlining!(f) end @@ -1907,7 +2375,7 @@ function removed_ret_parms(F::LLVM.Function) parmrem = nothing retRemove = false for a in collect(function_attributes(F)) - if isa(a, StringAttribute) + if isa(a, StringAttribute) if kind(a) == "enzyme_parmremove" parmrem = a end @@ -1928,8 +2396,8 @@ end abstract type CompilationException <: Base.Exception end struct NoDerivativeException <: CompilationException msg::String - ir::Union{Nothing, String} - bt::Union{Nothing, Vector{StackTraces.StackFrame}} + ir::Union{Nothing,String} + bt::Union{Nothing,Vector{StackTraces.StackFrame}} end function Base.showerror(io::IO, ece::NoDerivativeException) @@ -1948,8 +2416,8 @@ end struct IllegalTypeAnalysisException <: CompilationException msg::String sval::String - ir::Union{Nothing, String} - bt::Union{Nothing, Vector{StackTraces.StackFrame}} + ir::Union{Nothing,String} + bt::Union{Nothing,Vector{StackTraces.StackFrame}} end function Base.showerror(io::IO, ece::IllegalTypeAnalysisException) @@ -1962,7 +2430,7 @@ function Base.showerror(io::IO, ece::IllegalTypeAnalysisException) write(io, ece.sval) print(io, '\n', ece.msg, '\n') if ece.bt !== nothing - print(io,"\nCaused by:") + print(io, "\nCaused by:") Base.show_backtrace(io, ece.bt) println(io) end @@ -1970,8 +2438,8 @@ end struct 
IllegalFirstPointerException <: CompilationException msg::String - ir::Union{Nothing, String} - bt::Union{Nothing, Vector{StackTraces.StackFrame}} + ir::Union{Nothing,String} + bt::Union{Nothing,Vector{StackTraces.StackFrame}} end function Base.showerror(io::IO, ece::IllegalFirstPointerException) @@ -1989,8 +2457,8 @@ end struct EnzymeInternalError <: CompilationException msg::String - ir::Union{Nothing, String} - bt::Union{Nothing, Vector{StackTraces.StackFrame}} + ir::Union{Nothing,String} + bt::Union{Nothing,Vector{StackTraces.StackFrame}} end function Base.showerror(io::IO, ece::EnzymeInternalError) @@ -2006,63 +2474,76 @@ function Base.showerror(io::IO, ece::EnzymeInternalError) end end -parent_scope(val::LLVM.Function, depth=0) = depth==0 ? LLVM.parent(val) : val -parent_scope(val::LLVM.Module, depth=0) = val -parent_scope(val::LLVM.Value, depth=0) = parent_scope(LLVM.parent(val), depth+1) -parent_scope(val::LLVM.Argument, depth=0) = parent_scope(LLVM.Function(LLVM.API.LLVMGetParamParent(val)), depth+1) +parent_scope(val::LLVM.Function, depth = 0) = depth == 0 ? 
LLVM.parent(val) : val +parent_scope(val::LLVM.Module, depth = 0) = val +parent_scope(val::LLVM.Value, depth = 0) = parent_scope(LLVM.parent(val), depth + 1) +parent_scope(val::LLVM.Argument, depth = 0) = + parent_scope(LLVM.Function(LLVM.API.LLVMGetParamParent(val)), depth + 1) -const CheckNan = Ref(false) -function julia_sanitize(orig::LLVM.API.LLVMValueRef, val::LLVM.API.LLVMValueRef, B::LLVM.API.LLVMBuilderRef, mask::LLVM.API.LLVMValueRef)::LLVM.API.LLVMValueRef - orig = LLVM.Value(orig) - val = LLVM.Value(val) - B = LLVM.IRBuilder(B) - if CheckNan[] - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - mod = LLVM.parent(fn) - ty = LLVM.value_type(val) - vt = LLVM.VoidType() - FT = LLVM.FunctionType(vt, [ty, LLVM.PointerType(LLVM.Int8Type())]) +const CheckNan = Ref(false) +function julia_sanitize( + orig::LLVM.API.LLVMValueRef, + val::LLVM.API.LLVMValueRef, + B::LLVM.API.LLVMBuilderRef, + mask::LLVM.API.LLVMValueRef, +)::LLVM.API.LLVMValueRef + orig = LLVM.Value(orig) + val = LLVM.Value(val) + B = LLVM.IRBuilder(B) + if CheckNan[] + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + mod = LLVM.parent(fn) + ty = LLVM.value_type(val) + vt = LLVM.VoidType() + FT = LLVM.FunctionType(vt, [ty, LLVM.PointerType(LLVM.Int8Type())]) - stringv = "Enzyme: Found nan while computing derivative of "*string(orig) - if orig !== nothing && isa(orig, LLVM.Instruction) - bt = GPUCompiler.backtrace(orig) - function printBT(io) - print(io,"\nCaused by:") - Base.show_backtrace(io, bt) + stringv = "Enzyme: Found nan while computing derivative of " * string(orig) + if orig !== nothing && isa(orig, LLVM.Instruction) + bt = GPUCompiler.backtrace(orig) + function printBT(io) + print(io, "\nCaused by:") + Base.show_backtrace(io, bt) + end + stringv *= sprint(io -> Base.show_backtrace(io, bt)) end - stringv*=sprint(io->Base.show_backtrace(io, bt)) - end - fn, _ = get_function!(mod, "julia.sanitize."*string(ty), FT) - if isempty(blocks(fn)) - let builder = IRBuilder() - entry = 
BasicBlock(fn, "entry") - good = BasicBlock(fn, "good") - bad = BasicBlock(fn, "bad") - position!(builder, entry) - inp, sval = collect(parameters(fn)) - cmp = fcmp!(builder, LLVM.API.LLVMRealUNO, inp, inp) + fn, _ = get_function!(mod, "julia.sanitize." * string(ty), FT) + if isempty(blocks(fn)) + let builder = IRBuilder() + entry = BasicBlock(fn, "entry") + good = BasicBlock(fn, "good") + bad = BasicBlock(fn, "bad") + position!(builder, entry) + inp, sval = collect(parameters(fn)) + cmp = fcmp!(builder, LLVM.API.LLVMRealUNO, inp, inp) - br!(builder, cmp, bad, good) + br!(builder, cmp, bad, good) - position!(builder, good) - ret!(builder) + position!(builder, good) + ret!(builder) - position!(builder, bad) + position!(builder, bad) - emit_error(builder, nothing, sval, EnzymeNoDerivativeError) - unreachable!(builder) - dispose(builder) + emit_error(builder, nothing, sval, EnzymeNoDerivativeError) + unreachable!(builder) + dispose(builder) + end end + # val = + call!(B, FT, fn, LLVM.Value[val, globalstring_ptr!(B, stringv)]) end - # val = - call!(B, FT, fn, LLVM.Value[val, globalstring_ptr!(B, stringv)]) - end - return val.ref + return val.ref end -function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.ErrorType, data::Ptr{Cvoid}, data2::LLVM.API.LLVMValueRef, B::LLVM.API.LLVMBuilderRef)::LLVM.API.LLVMValueRef +function julia_error( + cstr::Cstring, + val::LLVM.API.LLVMValueRef, + errtype::API.ErrorType, + data::Ptr{Cvoid}, + data2::LLVM.API.LLVMValueRef, + B::LLVM.API.LLVMBuilderRef, +)::LLVM.API.LLVMValueRef msg = Base.unsafe_string(cstr) bt = nothing ir = nothing @@ -2098,12 +2579,15 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err end if errtype == API.ET_NoDerivative - if occursin("No create nofree of empty function", msg) || occursin("No forward mode derivative found for", msg) || occursin("No augmented forward pass", msg) || occursin("No reverse pass found", msg) + if occursin("No create nofree of empty 
function", msg) || + occursin("No forward mode derivative found for", msg) || + occursin("No augmented forward pass", msg) || + occursin("No reverse pass found", msg) ir = nothing end if B != C_NULL B = IRBuilder(B) - msg2 = sprint() do io + msg2 = sprint() do io if ir !== nothing print(io, "Current scope: \n") print(io, ir) @@ -2124,7 +2608,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err msgN = sprint() do io::IO if isa(val, LLVM.Argument) fn = parent_scope(val) - ir = string(LLVM.name(fn))*string(function_type(fn)) + ir = string(LLVM.name(fn)) * string(function_type(fn)) print(io, "Current scope: \n") print(io, ir) end @@ -2137,7 +2621,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err end print(io, '\n', msg, '\n') if bt !== nothing - print(io,"\nCaused by:") + print(io, "\nCaused by:") Base.show_backtrace(io, bt) println(io) end @@ -2149,9 +2633,12 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err ip = API.EnzymeTypeAnalyzerToString(data) sval = Base.unsafe_string(ip) API.EnzymeStringFree(ip) - + if isa(val, LLVM.Instruction) - mi, rt = enzyme_custom_extract_mi(LLVM.parent(LLVM.parent(val))::LLVM.Function, #=error=#false) + mi, rt = enzyme_custom_extract_mi( + LLVM.parent(LLVM.parent(val))::LLVM.Function, + false, + ) #=error=# if mi !== nothing msg *= "\n" * string(mi) * "\n" end @@ -2160,7 +2647,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err elseif errtype == API.ET_NoType @assert B != C_NULL B = IRBuilder(B) - + data = API.EnzymeTypeAnalyzerRef(data) ip = API.EnzymeTypeAnalyzerToString(data) sval = Base.unsafe_string(ip) @@ -2177,12 +2664,12 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err end print(io, '\n', msg, '\n') if bt !== nothing - print(io,"\nCaused by:") + print(io, "\nCaused by:") Base.show_backtrace(io, bt) println(io) end pscope = parent_scope(val) - mi, rt = 
enzyme_custom_extract_mi(pscope, #=error=#false) + mi, rt = enzyme_custom_extract_mi(pscope, false) #=error=# if mi !== nothing println(io, "within ", mi) end @@ -2227,18 +2714,20 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err badval = nothing gutils = GradientUtils(API.EnzymeGradientUtilsRef(data)) # Ignore mismatched activity if phi/store of ghost - seen = Dict{LLVM.Value, LLVM.Value}() + seen = Dict{LLVM.Value,LLVM.Value}() illegal = false - created = LLVM.Instruction[] + created = LLVM.Instruction[] world = enzyme_extract_world(LLVM.parent(position(IRBuilder(B)))) - width = get_width(gutils) + width = get_width(gutils) function make_batched(cur, B) if width == 1 return cur else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(cur)))) - for idx in 1:width - shadowres = insert_value!(B, shadowres, cur, idx-1) + shadowres = UndefValue( + LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(cur))), + ) + for idx = 1:width + shadowres = insert_value!(B, shadowres, cur, idx - 1) if isa(shadowres, LLVM.Instruction) push!(created, shadowres) end @@ -2254,8 +2743,8 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err if cur in keys(seen) return seen[cur] end - - legal, TT = abs_typeof(cur, true) + + legal, TT, byref = abs_typeof(cur, true) if legal if guaranteed_const_nongen(TT, world) return make_batched(ncur, prevbb) @@ -2264,16 +2753,21 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err legal2, obj = absint(cur) # Only do so for the immediate operand/etc to a phi, since otherwise we will make multiple - if legal2 && active_reg_inner(TT, (), world) == ActiveState && isa(cur, LLVM.ConstantExpr) && cur == data2 + if legal2 && + active_reg_inner(TT, (), world) == ActiveState && + isa(cur, LLVM.ConstantExpr) && + cur == data2 if width == 1 res = emit_allocobj!(prevbb, Base.RefValue{TT}) push!(created, res) return res else - shadowres = 
UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(cur)))) - for idx in 1:width + shadowres = UndefValue( + LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(cur))), + ) + for idx = 1:width res = emit_allocobj!(prevbb, Base.RefValue{TT}) - shadowres = insert_value!(prevbb, shadowres, res, idx-1) + shadowres = insert_value!(prevbb, shadowres, res, idx - 1) push!(created, shadowres) end return shadowres @@ -2281,15 +2775,15 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err end badval = if legal2 - string(obj)*" of type"*" "*string(TT) + string(obj) * " of type" * " " * string(TT) else - "Unknown object of type"*" "*string(TT) + "Unknown object of type" * " " * string(TT) end illegalVal = cur illegal = true return make_batched(ncur, prevbb) end - + if isa(cur, LLVM.PointerNull) return make_batched(ncur, prevbb) end @@ -2297,9 +2791,9 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err return make_batched(ncur, prevbb) end @static if LLVM.version() >= v"12" - if isa(cur, LLVM.PoisonValue) - return make_batched(ncur, prevbb) - end + if isa(cur, LLVM.PoisonValue) + return make_batched(ncur, prevbb) + end end if isa(cur, LLVM.ConstantAggregateZero) return make_batched(ncur, prevbb) @@ -2313,10 +2807,10 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err end end if isa(cur, LLVM.ConstantFP) - return make_batched(ConstantFP(value_type(cur), 0), prevbb) + return make_batched(ConstantFP(value_type(cur), 0), prevbb) end if isa(cur, LLVM.ConstantDataSequential) - cvals = LLVM.Value[] + cvals = LLVM.Value[] changed = false for v in collect(cur) tmp = make_replacement(v, prevbb) @@ -2340,20 +2834,23 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err return cur2 end if isa(cur, LLVM.ConstantInt) - if LLVM.width(value_type(cur)) <= sizeof(Int)*8 + if LLVM.width(value_type(cur)) <= sizeof(Int) * 8 return make_batched(ncur, prevbb) 
end - if LLVM.width(value_type(cur)) == sizeof(Int)*8 && abs(convert(Int, cur)) < 10000 + if LLVM.width(value_type(cur)) == sizeof(Int) * 8 && + abs(convert(Int, cur)) < 10000 return make_batched(ncur, prevbb) end # if storing a constant int as a non-pointer, presume it is not a GC'd var and is safe # for activity state to mix - if isa(val, LLVM.StoreInst) operands(val)[1] == cur && !isa(value_type(operands(val)[1]), LLVM.PointerType) + if isa(val, LLVM.StoreInst) + operands(val)[1] == cur && + !isa(value_type(operands(val)[1]), LLVM.PointerType) return make_batched(ncur, prevbb) end end - - if isa(cur, LLVM.SelectInst) + + if isa(cur, LLVM.SelectInst) lhs = make_replacement(operands(cur)[2], prevbb) if illegal return ncur @@ -2366,22 +2863,37 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err return make_batched(ncur, prevbb) end if width == 1 - nv = select!(prevbb, new_from_original(gutils, operands(cur)[1]), lhs, rhs) + nv = select!( + prevbb, + new_from_original(gutils, operands(cur)[1]), + lhs, + rhs, + ) push!(created, nv) seen[cur] = nv return nv else shadowres = LLVM.UndefValue(value_type(lhs)) - for idx in 1:width - shadowres = insert_value!(prevbb, shadowres, select!(prevbb, new_from_original(gutils, operands(cur)[1]), extract_value!(prevbb, lhs, idx-1), extract_value!(prevbb, rhs, idx-1)), idx-1) + for idx = 1:width + shadowres = insert_value!( + prevbb, + shadowres, + select!( + prevbb, + new_from_original(gutils, operands(cur)[1]), + extract_value!(prevbb, lhs, idx - 1), + extract_value!(prevbb, rhs, idx - 1), + ), + idx - 1, + ) if isa(shadowres, LLVM.Instruction) push!(created, shadowres) end end return shadowres end - end - + end + if isa(cur, LLVM.InsertValueInst) lhs = make_replacement(operands(cur)[1], prevbb) if illegal @@ -2396,7 +2908,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err end inds = LLVM.API.LLVMGetIndices(cur.ref) ninds = LLVM.API.LLVMGetNumIndices(cur.ref) - 
jinds = Cuint[unsafe_load(inds, i) for i in 1:ninds] + jinds = Cuint[unsafe_load(inds, i) for i = 1:ninds] if width == 1 nv = API.EnzymeInsertValue(prevbb, lhs, rhs, jinds) push!(created, nv) @@ -2404,10 +2916,15 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err return nv else shadowres = lhs - for idx in 1:width + for idx = 1:width jindsv = copy(jinds) - pushfirst!(jindsv, idx-1) - shadowres = API.EnzymeInsertValue(prevbb, shadowres, extract_value!(prevbb, rhs, idx-1), jindsv) + pushfirst!(jindsv, idx - 1) + shadowres = API.EnzymeInsertValue( + prevbb, + shadowres, + extract_value!(prevbb, rhs, idx - 1), + jindsv, + ) if isa(shadowres, LLVM.Instruction) push!(created, shadowres) end @@ -2415,15 +2932,15 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err return shadowres end end - + if isa(cur, LLVM.PHIInst) Bphi = IRBuilder() position!(Bphi, ncur) shadowty = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(cur))) - phi2 = phi!(Bphi, shadowty, "tempphi"*LLVM.name(cur)) + phi2 = phi!(Bphi, shadowty, "tempphi" * LLVM.name(cur)) seen[cur] = phi2 changed = false - recsize = length(created)+1 + recsize = length(created) + 1 for (v, bb) in LLVM.incoming(cur) B2 = IRBuilder() position!(B2, new_from_original(gutils, last(instructions(bb)))) @@ -2442,15 +2959,15 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err LLVM.API.LLVMInstructionEraseFromParent(phi2) seen[cur] = ncur plen = length(created) - for i in recsize:plen + for i = recsize:plen u = created[i] replace_uses!(u, LLVM.UndefValue(value_type(u))) end - for i in recsize:plen + for i = recsize:plen u = created[i] LLVM.API.LLVMInstructionEraseFromParent(u) end - for i in recsize:plen + for i = recsize:plen pop!(created) end return illegal ? 
ncur : make_batched(ncur, prevbb) @@ -2477,7 +2994,10 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err LLVM.API.LLVMInstructionEraseFromParent(u) end if LLVM.API.LLVMIsAReturnInst(val) != C_NULL - mi, rt = enzyme_custom_extract_mi(LLVM.parent(LLVM.parent(val))::LLVM.Function, #=error=#false) + mi, rt = enzyme_custom_extract_mi( + LLVM.parent(LLVM.parent(val))::LLVM.Function, + false, + ) #=error=# if mi !== nothing && isghostty(rt) return C_NULL end @@ -2486,7 +3006,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err print(io, msg) println(io) if badval !== nothing - println(io, " value="*badval) + println(io, " value=" * badval) else ttval = val if isa(ttval, LLVM.StoreInst) @@ -2499,7 +3019,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err API.EnzymeStringFree(st) end if illegalVal !== nothing - println(io, " llvalue="*string(illegalVal)) + println(io, " llvalue=" * string(illegalVal)) end if bt !== nothing Base.show_backtrace(io, bt) @@ -2512,9 +3032,9 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err B = IRBuilder(B) msg5 = sprint() do io::IO print(io, "Enzyme internal error\n") - print(io, msg, '\n') + print(io, msg, '\n') if bt !== nothing - print(io,"\nCaused by:") + print(io, "\nCaused by:") Base.show_backtrace(io, bt) println(io) end @@ -2535,7 +3055,7 @@ function any_jltypes(Type::LLVM.PointerType) end any_jltypes(Type::LLVM.StructType) = any(any_jltypes, LLVM.elements(Type)) -any_jltypes(Type::Union{LLVM.VectorType, LLVM.ArrayType}) = any_jltypes(eltype(Type)) +any_jltypes(Type::Union{LLVM.VectorType,LLVM.ArrayType}) = any_jltypes(eltype(Type)) any_jltypes(::LLVM.IntegerType) = false any_jltypes(::LLVM.FloatingPointType) = false any_jltypes(::LLVM.VoidType) = false @@ -2543,12 +3063,13 @@ any_jltypes(::LLVM.VoidType) = false @inline any_jltypes(::Type{Nothing}) = false @inline any_jltypes(::Type{T}) where 
{T<:AbstractFloat} = false @inline any_jltypes(::Type{T}) where {T<:Integer} = false -@inline any_jltypes(::Type{Complex{T}}) where T = any_jltypes(T) +@inline any_jltypes(::Type{Complex{T}}) where {T} = any_jltypes(T) @inline any_jltypes(::Type{Tuple{}}) = false -@inline any_jltypes(::Type{NTuple{Size, T}}) where {Size, T} = any_jltypes(T) -@inline any_jltypes(::Type{Core.LLVMPtr{T, Addr}}) where {T, Addr} = 10 <= Addr <= 12 +@inline any_jltypes(::Type{NTuple{Size,T}}) where {Size,T} = any_jltypes(T) +@inline any_jltypes(::Type{Core.LLVMPtr{T,Addr}}) where {T,Addr} = 10 <= Addr <= 12 @inline any_jltypes(::Type{Any}) = true -@inline any_jltypes(::Type{NamedTuple{A,B}}) where {A,B} = any(any_jltypes(b) for b in B.parameters) +@inline any_jltypes(::Type{NamedTuple{A,B}}) where {A,B} = + any(any_jltypes(b) for b in B.parameters) @inline any_jltypes(::Type{T}) where {T<:Tuple} = any(any_jltypes(b) for b in T.parameters) nfields(Type::LLVM.StructType) = length(LLVM.elements(Type)) @@ -2559,9 +3080,9 @@ nfields(Type::LLVM.PointerType) = 1 mutable struct EnzymeTapeToLoad{T} data::T end -Base.eltype(::EnzymeTapeToLoad{T}) where T = T +Base.eltype(::EnzymeTapeToLoad{T}) where {T} = T -const TapeTypes = Dict{String, DataType}() +const TapeTypes = Dict{String,DataType}() base_type(T::UnionAll) = base_type(T.body) base_type(T::DataType) = T @@ -2570,8 +3091,10 @@ const WideIntWidths = [256, 512, 1024, 2048] let for n ∈ WideIntWidths - let T = Symbol(:UInt,n) - eval(quote primitive type $T <: Unsigned $n end end) + let T = Symbol(:UInt, n) + eval(quote + primitive type $T <: Unsigned $n end + end) end end end @@ -2583,8 +3106,8 @@ function to_tape_type(Type::LLVM.API.LLVMTypeRef)::Tuple{DataType,Bool} nelems = LLVM.API.LLVMCountStructElementTypes(Type) containsAny = false syms = Symbol[] - for i in 1:nelems - e = LLVM.API.LLVMStructGetTypeAtIndex(Type, i-1) + for i = 1:nelems + e = LLVM.API.LLVMStructGetTypeAtIndex(Type, i - 1) T, sub = to_tape_type(e) containsAny |= sub 
push!(tys, T) @@ -2593,7 +3116,7 @@ function to_tape_type(Type::LLVM.API.LLVMTypeRef)::Tuple{DataType,Bool} Tup = Tuple{tys...} if containsAny res = (syms...,) - return NamedTuple{res, Tup}, false + return NamedTuple{res,Tup}, false else return Tup, false end @@ -2606,9 +3129,9 @@ function to_tape_type(Type::LLVM.API.LLVMTypeRef)::Tuple{DataType,Bool} e = LLVM.API.LLVMGetElementType(Type) tkind2 = LLVM.API.LLVMGetTypeKind(e) if tkind2 == LLVM.API.LLVMFunctionTypeKind - return Core.LLVMPtr{Cvoid, Int(addrspace)}, false + return Core.LLVMPtr{Cvoid,Int(addrspace)}, false else - return Core.LLVMPtr{to_tape_type(e)[1], Int(addrspace)}, false + return Core.LLVMPtr{to_tape_type(e)[1],Int(addrspace)}, false end end end @@ -2616,9 +3139,9 @@ function to_tape_type(Type::LLVM.API.LLVMTypeRef)::Tuple{DataType,Bool} e = LLVM.API.LLVMGetElementType(Type) T, sub = to_tape_type(e) len = Int(LLVM.API.LLVMGetArrayLength(Type)) - Tup = NTuple{len, T} + Tup = NTuple{len,T} if sub - return NamedTuple{ntuple(Core.Symbol, Val(len)), Tup}, false + return NamedTuple{ntuple(Core.Symbol, Val(len)),Tup}, false else return Tup, false end @@ -2627,9 +3150,9 @@ function to_tape_type(Type::LLVM.API.LLVMTypeRef)::Tuple{DataType,Bool} e = LLVM.API.LLVMGetElementType(Type) T, sub = to_tape_type(e) len = Int(LLVM.API.LLVMGetVectorSize(Type)) - Tup = NTuple{len, T} + Tup = NTuple{len,T} if sub - return NamedTuple{ntuple(Core.Symbol, Val(len)), Tup}, false + return NamedTuple{ntuple(Core.Symbol, Val(len)),Tup}, false else return Tup, false end @@ -2637,7 +3160,7 @@ function to_tape_type(Type::LLVM.API.LLVMTypeRef)::Tuple{DataType,Bool} if tkind == LLVM.API.LLVMIntegerTypeKind N = LLVM.API.LLVMGetIntTypeWidth(Type) if N == 1 - return Bool, false + return Bool, false elseif N == 8 return UInt8, false elseif N == 16 @@ -2683,10 +3206,12 @@ function tape_type(LLVMType::LLVM.LLVMType) return TT end -from_tape_type(::Type{T}) where T<:AbstractFloat = convert(LLVMType, T) -from_tape_type(::Type{T}) where 
T<:Integer = convert(LLVMType, T) -from_tape_type(::Type{NTuple{Size, T}}) where {Size, T} = LLVM.ArrayType(from_tape_type(T), Size) -from_tape_type(::Type{Core.LLVMPtr{T, Addr}}) where {T, Addr} = LLVM.PointerType(from_tape_type(UInt8), Addr) +from_tape_type(::Type{T}) where {T<:AbstractFloat} = convert(LLVMType, T) +from_tape_type(::Type{T}) where {T<:Integer} = convert(LLVMType, T) +from_tape_type(::Type{NTuple{Size,T}}) where {Size,T} = + LLVM.ArrayType(from_tape_type(T), Size) +from_tape_type(::Type{Core.LLVMPtr{T,Addr}}) where {T,Addr} = + LLVM.PointerType(from_tape_type(UInt8), Addr) # from_tape_type(::Type{Core.LLVMPtr{T, Addr}}, ctx) where {T, Addr} = LLVM.PointerType(from_tape_type(T, ctx), Addr) from_tape_type(::Type{Any}) = LLVM.PointerType(LLVM.StructType(LLVM.LLVMType[]), Tracked) function from_tape_type(::Type{NamedTuple{A,B}}) where {A,B} @@ -2702,10 +3227,12 @@ function from_tape_type(::Type{B}) where {B<:Tuple} end # See get_current_task_from_pgcstack (used from 1.7+) -current_task_offset() = -(unsafe_load(cglobal(:jl_task_gcstack_offset, Cint)) ÷ sizeof(Ptr{Cvoid})) +current_task_offset() = + -(unsafe_load(cglobal(:jl_task_gcstack_offset, Cint)) ÷ sizeof(Ptr{Cvoid})) # See get_current_ptls_from_task (used from 1.7+) -current_ptls_offset() = unsafe_load(cglobal(:jl_task_ptls_offset, Cint)) ÷ sizeof(Ptr{Cvoid}) +current_ptls_offset() = + unsafe_load(cglobal(:jl_task_ptls_offset, Cint)) ÷ sizeof(Ptr{Cvoid}) function store_nonjl_types!(B, startval, p) T_jlvalue = LLVM.StructType(LLVMType[]) @@ -2714,7 +3241,7 @@ function store_nonjl_types!(B, startval, p) if p != nothing push!(vals, p) end - todo = Tuple{Tuple, LLVM.Value}[((), startval)] + todo = Tuple{Tuple,LLVM.Value}[((), startval)] while length(todo) != 0 path, cur = popfirst!(todo) ty = value_type(cur) @@ -2725,9 +3252,9 @@ function store_nonjl_types!(B, startval, p) end if isa(ty, LLVM.ArrayType) if any_jltypes(ty) - for i=1:length(ty) - ev = extract_value!(B, cur, i-1) - push!(todo, 
((path..., i-1), ev)) + for i = 1:length(ty) + ev = extract_value!(B, cur, i - 1) + push!(todo, ((path..., i - 1), ev)) end continue end @@ -2735,8 +3262,8 @@ function store_nonjl_types!(B, startval, p) if isa(ty, LLVM.StructType) if any_jltypes(ty) for (i, t) in enumerate(LLVM.elements(ty)) - ev = extract_value!(B, cur, i-1) - push!(todo, ((path..., i-1), ev)) + ev = extract_value!(B, cur, i - 1) + push!(todo, ((path..., i - 1), ev)) end continue end @@ -2751,7 +3278,7 @@ function store_nonjl_types!(B, startval, p) return end -function get_julia_inner_types(B, p, startvals...; added=LLVM.API.LLVMValueRef[]) +function get_julia_inner_types(B, p, startvals...; added = LLVM.API.LLVMValueRef[]) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) vals = LLVM.Value[] @@ -2765,7 +3292,12 @@ function get_julia_inner_types(B, p, startvals...; added=LLVM.API.LLVMValueRef[] if isa(ty, LLVM.PointerType) if any_jltypes(ty) if addrspace(ty) != Tracked - cur = addrspacecast!(B, cur, LLVM.PointerType(eltype(ty), Tracked), LLVM.name(cur)*".innertracked") + cur = addrspacecast!( + B, + cur, + LLVM.PointerType(eltype(ty), Tracked), + LLVM.name(cur) * ".innertracked", + ) if isa(cur, LLVM.Instruction) push!(added, cur.ref) end @@ -2782,8 +3314,8 @@ function get_julia_inner_types(B, p, startvals...; added=LLVM.API.LLVMValueRef[] end if isa(ty, LLVM.ArrayType) if any_jltypes(ty) - for i=1:length(ty) - ev = extract_value!(B, cur, i-1) + for i = 1:length(ty) + ev = extract_value!(B, cur, i - 1) if isa(ev, LLVM.Instruction) push!(added, ev.ref) end @@ -2795,7 +3327,7 @@ function get_julia_inner_types(B, p, startvals...; added=LLVM.API.LLVMValueRef[] if isa(ty, LLVM.StructType) for (i, t) in enumerate(LLVM.elements(ty)) if any_jltypes(t) - ev = extract_value!(B, cur, i-1) + ev = extract_value!(B, cur, i - 1) if isa(ev, LLVM.Instruction) push!(added, ev.ref) end @@ -2822,14 +3354,20 @@ function get_julia_inner_types(B, p, startvals...; 
added=LLVM.API.LLVMValueRef[] return vals end -function julia_post_cache_store(SI::LLVM.API.LLVMValueRef, B::LLVM.API.LLVMBuilderRef, R2)::Ptr{LLVM.API.LLVMValueRef} +function julia_post_cache_store( + SI::LLVM.API.LLVMValueRef, + B::LLVM.API.LLVMBuilderRef, + R2, +)::Ptr{LLVM.API.LLVMValueRef} B = LLVM.IRBuilder(B) SI = LLVM.Instruction(SI) v = operands(SI)[1] p = operands(SI)[2] added = LLVM.API.LLVMValueRef[] while true - if isa(p, LLVM.GetElementPtrInst) || isa(p, LLVM.BitCastInst) || isa(p, LLVM.AddrSpaceCastInst) + if isa(p, LLVM.GetElementPtrInst) || + isa(p, LLVM.BitCastInst) || + isa(p, LLVM.AddrSpaceCastInst) p = operands(p)[1] continue end @@ -2845,14 +3383,17 @@ function julia_post_cache_store(SI::LLVM.API.LLVMValueRef, B::LLVM.API.LLVMBuild end p = pn - vals = get_julia_inner_types(B, p, v, added=added) + vals = get_julia_inner_types(B, p, v, added = added) r = emit_writebarrier!(B, vals) @assert isa(r, LLVM.Instruction) push!(added, r.ref) end if R2 != C_NULL unsafe_store!(R2, length(added)) - ptr = Base.unsafe_convert(Ptr{LLVM.API.LLVMValueRef}, Libc.malloc(sizeof(LLVM.API.LLVMValueRef)*length(added))) + ptr = Base.unsafe_convert( + Ptr{LLVM.API.LLVMValueRef}, + Libc.malloc(sizeof(LLVM.API.LLVMValueRef) * length(added)), + ) for (i, v) in enumerate(added) @assert isa(LLVM.Value(v), LLVM.Instruction) unsafe_store!(ptr, v, i) @@ -2868,7 +3409,11 @@ function julia_default_tape_type(C::LLVM.API.LLVMContextRef) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) return T_prjlvalue.ref end -function julia_undef_value_for_type(mod::LLVM.API.LLVMModuleRef, Ty::LLVM.API.LLVMTypeRef, forceZero::UInt8)::LLVM.API.LLVMValueRef +function julia_undef_value_for_type( + mod::LLVM.API.LLVMModuleRef, + Ty::LLVM.API.LLVMTypeRef, + forceZero::UInt8, +)::LLVM.API.LLVMValueRef ty = LLVM.LLVMType(Ty) if !any_jltypes(ty) if forceZero != 0 @@ -2889,7 +3434,7 @@ function julia_undef_value_for_type(mod::LLVM.API.LLVMModuleRef, Ty::LLVM.API.LL end if isa(ty, LLVM.ArrayType) st = 
LLVM.Value(julia_undef_value_for_type(mod, eltype(ty).ref, forceZero)) - return ConstantArray(eltype(ty), [st for i in 1:length(ty)]).ref + return ConstantArray(eltype(ty), [st for i = 1:length(ty)]).ref end if isa(ty, LLVM.StructType) vals = LLVM.Constant[] @@ -2905,11 +3450,13 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie V = LLVM.CallInst(V) gutils = GradientUtils(gutils) mode = get_mode(gutils) - if mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient || mode == API.DEM_ReverseModeCombined + if mode == API.DEM_ReverseModePrimal || + mode == API.DEM_ReverseModeGradient || + mode == API.DEM_ReverseModeCombined fn = LLVM.parent(LLVM.parent(V)) world = enzyme_extract_world(fn) - has, Ty = abs_typeof(V) - @assert has + has, Ty, byref = abs_typeof(V) + @assert has rt = active_reg_inner(Ty, (), world) if rt == ActiveState || rt == MixedState B = LLVM.IRBuilder() @@ -2920,7 +3467,14 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie nothing end -function julia_allocator(B::LLVM.API.LLVMBuilderRef, LLVMType::LLVM.API.LLVMTypeRef, Count::LLVM.API.LLVMValueRef, AlignedSize::LLVM.API.LLVMValueRef, IsDefault::UInt8, ZI) +function julia_allocator( + B::LLVM.API.LLVMBuilderRef, + LLVMType::LLVM.API.LLVMTypeRef, + Count::LLVM.API.LLVMValueRef, + AlignedSize::LLVM.API.LLVMValueRef, + IsDefault::UInt8, + ZI, +) B = LLVM.IRBuilder(B) Count = LLVM.Value(Count) AlignedSize = LLVM.Value(AlignedSize) @@ -2972,7 +3526,11 @@ function zero_single_allocation(builder, jlType, LLVMType, nobj, zeroAll, idx) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) T_prjlvalue_UT = LLVM.PointerType(T_jlvalue) - todo = Tuple{Vector{LLVM.Value},LLVM.LLVMType,DataType}[(LLVM.Value[idx], LLVMType, jlType)] + todo = Tuple{Vector{LLVM.Value},LLVM.LLVMType,DataType}[( + LLVM.Value[idx], + LLVMType, + jlType, + )] while length(todo) != 0 path, ty, jlty = popfirst!(todo) @@ -2981,7 +3539,11 @@ function 
zero_single_allocation(builder, jlType, LLVMType, nobj, zeroAll, idx) loc = gep!(builder, LLVMType, nobj, path) mod = LLVM.parent(LLVM.parent(Base.position(builder))) fill_val = unsafe_nothing_to_llvm(mod) - loc = bitcast!(builder, loc, LLVM.PointerType(T_prjlvalue, addrspace(value_type(loc)))) + loc = bitcast!( + builder, + loc, + LLVM.PointerType(T_prjlvalue, addrspace(value_type(loc))), + ) store!(builder, fill_val, loc) elseif zeroAll loc = gep!(builder, LLVMType, nobj, path) @@ -2996,36 +3558,36 @@ function zero_single_allocation(builder, jlType, LLVMType, nobj, zeroAll, idx) end continue end - if isa(ty, LLVM.ArrayType) - for i=1:length(ty) + if isa(ty, LLVM.ArrayType) + for i = 1:length(ty) npath = copy(path) - push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i-1)) + push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i - 1)) push!(todo, (npath, eltype(ty), eltype(jlty))) end continue end - if isa(ty, LLVM.VectorType) - for i=1:size(ty) + if isa(ty, LLVM.VectorType) + for i = 1:size(ty) npath = copy(path) - push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i-1)) + push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i - 1)) push!(todo, (npath, eltype(ty), eltype(jlty))) end continue end if isa(ty, LLVM.StructType) i = 1 - for ii in 1:fieldcount(jlty) + for ii = 1:fieldcount(jlty) jlet = fieldtype(jlty, ii) if isghostty(jlet) || Core.Compiler.isconstType(jlet) continue end t = LLVM.elements(ty)[i] npath = copy(path) - push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i-1)) + push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i - 1)) push!(todo, (npath, t, jlet)) - i+=1 + i += 1 end - @assert i == Int(length(LLVM.elements(ty)))+1 + @assert i == Int(length(LLVM.elements(ty))) + 1 continue end end @@ -3034,7 +3596,15 @@ function zero_single_allocation(builder, jlType, LLVMType, nobj, zeroAll, idx) end -function zero_allocation(B::LLVM.IRBuilder, jlType, LLVMType, obj, AlignedSize, Size, zeroAll::Bool)::LLVM.API.LLVMValueRef +function zero_allocation( + B::LLVM.IRBuilder, + 
jlType, + LLVMType, + obj, + AlignedSize, + Size, + zeroAll::Bool, +)::LLVM.API.LLVMValueRef func = LLVM.parent(position(B)) mod = LLVM.parent(func) T_int8 = LLVM.Int8Type() @@ -3043,7 +3613,11 @@ function zero_allocation(B::LLVM.IRBuilder, jlType, LLVMType, obj, AlignedSize, T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) T_prjlvalue_UT = LLVM.PointerType(T_jlvalue) - wrapper_f = LLVM.Function(mod, "zeroType", LLVM.FunctionType(LLVM.VoidType(), [value_type(obj), T_int8, value_type(Size)])) + wrapper_f = LLVM.Function( + mod, + "zeroType", + LLVM.FunctionType(LLVM.VoidType(), [value_type(obj), T_int8, value_type(Size)]), + ) push!(function_attributes(wrapper_f), StringAttribute("enzyme_math", "enzyme_zerotype")) push!(function_attributes(wrapper_f), StringAttribute("enzyme_inactive")) push!(function_attributes(wrapper_f), StringAttribute("enzyme_no_escaping_allocation")) @@ -3064,24 +3638,46 @@ function zero_allocation(B::LLVM.IRBuilder, jlType, LLVMType, obj, AlignedSize, exit = BasicBlock(wrapper_f, "exit") position!(builder, entry) nobj, _, nsize = collect(parameters(wrapper_f)) - nobj = pointercast!(builder, nobj, LLVM.PointerType(LLVMType, addrspace(value_type(nobj)))) + nobj = pointercast!( + builder, + nobj, + LLVM.PointerType(LLVMType, addrspace(value_type(nobj))), + ) LLVM.br!(builder, loop) position!(builder, loop) idx = LLVM.phi!(builder, value_type(Size)) inc = add!(builder, idx, LLVM.ConstantInt(value_type(Size), 1)) - append!(LLVM.incoming(idx), [(LLVM.ConstantInt(value_type(Size), 0), entry), (inc, loop)]) + append!( + LLVM.incoming(idx), + [(LLVM.ConstantInt(value_type(Size), 0), entry), (inc, loop)], + ) zero_single_allocation(builder, jlType, LLVMType, nobj, zeroAll, idx) - br!(builder, icmp!(builder, LLVM.API.LLVMIntEQ, inc, LLVM.Value(LLVM.API.LLVMBuildExactUDiv(builder, nsize, AlignedSize, ""))), exit, loop) + br!( + builder, + icmp!( + builder, + LLVM.API.LLVMIntEQ, + inc, + LLVM.Value(LLVM.API.LLVMBuildExactUDiv(builder, nsize, 
AlignedSize, "")), + ), + exit, + loop, + ) position!(builder, exit) ret!(builder) dispose(builder) end - return call!(B, LLVM.function_type(wrapper_f), wrapper_f, [obj, LLVM.ConstantInt(T_int8, 0), Size]).ref + return call!( + B, + LLVM.function_type(wrapper_f), + wrapper_f, + [obj, LLVM.ConstantInt(T_int8, 0), Size], + ).ref end function julia_allocator(B, LLVMType, Count, AlignedSize, IsDefault, ZI) @@ -3102,7 +3698,8 @@ function julia_allocator(B, LLVMType, Count, AlignedSize, IsDefault, ZI) TT = tape_type(LLVMType) if esizeof(TT) != convert(Int, AlignedSize) - GPUCompiler.@safe_error "Enzyme aligned size and Julia size disagree" AlignedSize=convert(Int, AlignedSize) esizeof(TT) fieldtypes(TT) + GPUCompiler.@safe_error "Enzyme aligned size and Julia size disagree" AlignedSize = + convert(Int, AlignedSize) esizeof(TT) fieldtypes(TT) emit_error(B, nothing, "Enzyme: Tape allocation failed.") # TODO: Pick appropriate orig return LLVM.API.LLVMValueRef(LLVM.UndefValue(LLVMType).ref) end @@ -3110,9 +3707,11 @@ function julia_allocator(B, LLVMType, Count, AlignedSize, IsDefault, ZI) if Count isa LLVM.ConstantInt N = convert(Int, Count) - ETT = N == 1 ? TT : NTuple{N, TT} - if sizeof(ETT) != N*convert(Int, AlignedSize) - GPUCompiler.@safe_error "Size of Enzyme tape is incorrect. Please report this issue" ETT sizeof(ETT) TargetSize = N*convert(Int, AlignedSize) LLVMType + ETT = N == 1 ? TT : NTuple{N,TT} + if sizeof(ETT) != N * convert(Int, AlignedSize) + GPUCompiler.@safe_error "Size of Enzyme tape is incorrect. 
Please report this issue" ETT sizeof( + ETT, + ) TargetSize = N * convert(Int, AlignedSize) LLVMType emit_error(B, nothing, "Enzyme: Tape allocation failed.") # TODO: Pick appropriate orig return LLVM.API.LLVMValueRef(LLVM.UndefValue(LLVMType).ref) @@ -3137,7 +3736,8 @@ function julia_allocator(B, LLVMType, Count, AlignedSize, IsDefault, ZI) @static if VERSION >= v"1.10.5" needs_dynamic_size_workaround = false else - needs_dynamic_size_workaround = !isa(Size, LLVM.ConstantInt) || convert(Int, Size) != 1 + needs_dynamic_size_workaround = + !isa(Size, LLVM.ConstantInt) || convert(Int, Size) != 1 end T_size_t = convert(LLVM.LLVMType, Int) @@ -3150,12 +3750,16 @@ function julia_allocator(B, LLVMType, Count, AlignedSize, IsDefault, ZI) obj = emit_allocobj!(B, tag, allocSize, needs_dynamic_size_workaround) if ZI != C_NULL - unsafe_store!(ZI, zero_allocation(B, TT, LLVMType, obj, AlignedSize, Size, #=ZeroAll=#false)) + unsafe_store!( + ZI, + zero_allocation(B, TT, LLVMType, obj, AlignedSize, Size, false), + ) #=ZeroAll=# end AS = Tracked else ptr8 = LLVM.PointerType(LLVM.IntType(8)) - mallocF, fty = get_function!(mod, "malloc", LLVM.FunctionType(ptr8, [value_type(Count)])) + mallocF, fty = + get_function!(mod, "malloc", LLVM.FunctionType(ptr8, [value_type(Count)])) obj = call!(B, fty, mallocF, [Size]) # if ZI != C_NULL @@ -3166,13 +3770,29 @@ function julia_allocator(B, LLVMType, Count, AlignedSize, IsDefault, ZI) AS = 0 end - LLVM.API.LLVMAddCallSiteAttribute(obj, LLVM.API.LLVMAttributeReturnIndex, EnumAttribute("noalias")) - LLVM.API.LLVMAddCallSiteAttribute(obj, LLVM.API.LLVMAttributeReturnIndex, EnumAttribute("nonnull")) + LLVM.API.LLVMAddCallSiteAttribute( + obj, + LLVM.API.LLVMAttributeReturnIndex, + EnumAttribute("noalias"), + ) + LLVM.API.LLVMAddCallSiteAttribute( + obj, + LLVM.API.LLVMAttributeReturnIndex, + EnumAttribute("nonnull"), + ) if isa(Count, LLVM.ConstantInt) val = convert(UInt, AlignedSize) val *= convert(UInt, Count) - 
LLVM.API.LLVMAddCallSiteAttribute(obj, LLVM.API.LLVMAttributeReturnIndex, EnumAttribute("dereferenceable", val)) - LLVM.API.LLVMAddCallSiteAttribute(obj, LLVM.API.LLVMAttributeReturnIndex, EnumAttribute("dereferenceable_or_null", val)) + LLVM.API.LLVMAddCallSiteAttribute( + obj, + LLVM.API.LLVMAttributeReturnIndex, + EnumAttribute("dereferenceable", val), + ) + LLVM.API.LLVMAddCallSiteAttribute( + obj, + LLVM.API.LLVMAttributeReturnIndex, + EnumAttribute("dereferenceable_or_null", val), + ) end mem = pointercast!(B, obj, LLVM.PointerType(LLVMType, AS)) @@ -3195,7 +3815,11 @@ function julia_deallocator(B::LLVM.IRBuilder, Obj::LLVM.Value) ptr8 = LLVM.PointerType(LLVM.IntType(8)) freeF, fty = get_function!(mod, "free", LLVM.FunctionType(T_void, [ptr8])) callf = call!(B, fty, freeF, [pointercast!(B, Obj, ptr8)]) - LLVM.API.LLVMAddCallSiteAttribute(callf, LLVM.API.LLVMAttributeIndex(1), EnumAttribute("nonnull")) + LLVM.API.LLVMAddCallSiteAttribute( + callf, + LLVM.API.LLVMAttributeIndex(1), + EnumAttribute("nonnull"), + ) end return LLVM.API.LLVMValueRef(callf.ref) end @@ -3208,10 +3832,14 @@ function emit_inacterror(B, V, orig) mod = LLVM.parent(fn) bt = GPUCompiler.backtrace(orig) - bts = sprint(io->Base.show_backtrace(io, bt)) - fmt = globalstring_ptr!(B, "%s:\nBacktrace\n"*bts) + bts = sprint(io -> Base.show_backtrace(io, bt)) + fmt = globalstring_ptr!(B, "%s:\nBacktrace\n" * bts) - funcT = LLVM.FunctionType(LLVM.VoidType(), LLVMType[LLVM.PointerType(LLVM.Int8Type())], vararg=true) + funcT = LLVM.FunctionType( + LLVM.VoidType(), + LLVMType[LLVM.PointerType(LLVM.Int8Type())], + vararg = true, + ) func, _ = get_function!(mod, "jl_errorf", funcT, [EnumAttribute("noreturn")]) call!(B, funcT, func, LLVM.Value[fmt, LLVM.Value(V)]) @@ -3224,31 +3852,22 @@ include("rules/llvmrules.jl") for (k, v) in ( ("enz_runtime_newtask_fwd", Enzyme.Compiler.runtime_newtask_fwd), ("enz_runtime_newtask_augfwd", Enzyme.Compiler.runtime_newtask_augfwd), - ("enz_runtime_generic_fwd", 
Enzyme.Compiler.runtime_generic_fwd), ("enz_runtime_generic_augfwd", Enzyme.Compiler.runtime_generic_augfwd), ("enz_runtime_generic_rev", Enzyme.Compiler.runtime_generic_rev), - ("enz_runtime_iterate_fwd", Enzyme.Compiler.runtime_iterate_fwd), ("enz_runtime_iterate_augfwd", Enzyme.Compiler.runtime_iterate_augfwd), ("enz_runtime_iterate_rev", Enzyme.Compiler.runtime_iterate_rev), - ("enz_runtime_newstruct_augfwd", Enzyme.Compiler.runtime_newstruct_augfwd), ("enz_runtime_newstruct_rev", Enzyme.Compiler.runtime_newstruct_rev), - ("enz_runtime_tuple_augfwd", Enzyme.Compiler.runtime_tuple_augfwd), ("enz_runtime_tuple_rev", Enzyme.Compiler.runtime_tuple_rev), - - ("enz_runtime_jl_getfield_aug", Enzyme.Compiler.rt_jl_getfield_aug), ("enz_runtime_jl_getfield_rev", Enzyme.Compiler.rt_jl_getfield_rev), - ("enz_runtime_idx_jl_getfield_aug", Enzyme.Compiler.idx_jl_getfield_aug), ("enz_runtime_idx_jl_getfield_rev", Enzyme.Compiler.idx_jl_getfield_aug), - ("enz_runtime_jl_setfield_aug", Enzyme.Compiler.rt_jl_setfield_aug), ("enz_runtime_jl_setfield_rev", Enzyme.Compiler.rt_jl_setfield_rev), - ("enz_runtime_error_if_differentiable", Enzyme.Compiler.error_if_differentiable), ("enz_runtime_error_if_active", Enzyme.Compiler.error_if_active), ) @@ -3258,30 +3877,103 @@ end function __init__() API.memmove_warning!(false) API.typeWarning!(false) - API.EnzymeSetHandler(@cfunction(julia_error, LLVM.API.LLVMValueRef, (Cstring, LLVM.API.LLVMValueRef, API.ErrorType, Ptr{Cvoid}, LLVM.API.LLVMValueRef, LLVM.API.LLVMBuilderRef))) - API.EnzymeSetSanitizeDerivatives(@cfunction(julia_sanitize, LLVM.API.LLVMValueRef, (LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef))); - API.EnzymeSetRuntimeInactiveError(@cfunction(emit_inacterror, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef))) - API.EnzymeSetDefaultTapeType(@cfunction( - julia_default_tape_type, LLVM.API.LLVMTypeRef, (LLVM.API.LLVMContextRef,))) - 
API.EnzymeSetCustomAllocator(@cfunction( - julia_allocator, LLVM.API.LLVMValueRef, - (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMTypeRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef, UInt8, Ptr{LLVM.API.LLVMValueRef}))) - API.EnzymeSetCustomDeallocator(@cfunction( - julia_deallocator, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef))) - API.EnzymeSetPostCacheStore(@cfunction( - julia_post_cache_store, Ptr{LLVM.API.LLVMValueRef}, - (LLVM.API.LLVMValueRef, LLVM.API.LLVMBuilderRef, Ptr{UInt64}))) - - API.EnzymeSetCustomZero(@cfunction( - zero_allocation, Cvoid, - (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMTypeRef, LLVM.API.LLVMValueRef, UInt8))) - API.EnzymeSetFixupReturn(@cfunction( - fixup_return, LLVM.API.LLVMValueRef, - (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef))) - API.EnzymeSetUndefinedValueForType(@cfunction( - julia_undef_value_for_type, LLVM.API.LLVMValueRef, (LLVM.API.LLVMModuleRef, LLVM.API.LLVMTypeRef,UInt8))) - API.EnzymeSetShadowAllocRewrite(@cfunction( - shadow_alloc_rewrite, Cvoid, (LLVM.API.LLVMValueRef,API.EnzymeGradientUtilsRef))) + API.EnzymeSetHandler( + @cfunction( + julia_error, + LLVM.API.LLVMValueRef, + ( + Cstring, + LLVM.API.LLVMValueRef, + API.ErrorType, + Ptr{Cvoid}, + LLVM.API.LLVMValueRef, + LLVM.API.LLVMBuilderRef, + ) + ) + ) + API.EnzymeSetSanitizeDerivatives( + @cfunction( + julia_sanitize, + LLVM.API.LLVMValueRef, + ( + LLVM.API.LLVMValueRef, + LLVM.API.LLVMValueRef, + LLVM.API.LLVMBuilderRef, + LLVM.API.LLVMValueRef, + ) + ) + ) + API.EnzymeSetRuntimeInactiveError( + @cfunction( + emit_inacterror, + Cvoid, + (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef) + ) + ) + API.EnzymeSetDefaultTapeType( + @cfunction( + julia_default_tape_type, + LLVM.API.LLVMTypeRef, + (LLVM.API.LLVMContextRef,) + ) + ) + API.EnzymeSetCustomAllocator( + @cfunction( + julia_allocator, + LLVM.API.LLVMValueRef, + ( + LLVM.API.LLVMBuilderRef, + LLVM.API.LLVMTypeRef, + LLVM.API.LLVMValueRef, + LLVM.API.LLVMValueRef, + 
UInt8, + Ptr{LLVM.API.LLVMValueRef}, + ) + ) + ) + API.EnzymeSetCustomDeallocator( + @cfunction( + julia_deallocator, + LLVM.API.LLVMValueRef, + (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef) + ) + ) + API.EnzymeSetPostCacheStore( + @cfunction( + julia_post_cache_store, + Ptr{LLVM.API.LLVMValueRef}, + (LLVM.API.LLVMValueRef, LLVM.API.LLVMBuilderRef, Ptr{UInt64}) + ) + ) + + API.EnzymeSetCustomZero( + @cfunction( + zero_allocation, + Cvoid, + (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMTypeRef, LLVM.API.LLVMValueRef, UInt8) + ) + ) + API.EnzymeSetFixupReturn( + @cfunction( + fixup_return, + LLVM.API.LLVMValueRef, + (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef) + ) + ) + API.EnzymeSetUndefinedValueForType( + @cfunction( + julia_undef_value_for_type, + LLVM.API.LLVMValueRef, + (LLVM.API.LLVMModuleRef, LLVM.API.LLVMTypeRef, UInt8) + ) + ) + API.EnzymeSetShadowAllocRewrite( + @cfunction( + shadow_alloc_rewrite, + Cvoid, + (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef) + ) + ) register_alloc_rules() register_llvm_rules() @@ -3291,8 +3983,7 @@ function __init__() end # Define EnzymeTarget -Base.@kwdef struct EnzymeTarget <: AbstractCompilerTarget -end +Base.@kwdef struct EnzymeTarget <: AbstractCompilerTarget end GPUCompiler.llvm_triple(::EnzymeTarget) = LLVM.triple(JIT.get_jit()) GPUCompiler.llvm_datalayout(::EnzymeTarget) = LLVM.datalayout(JIT.get_jit()) @@ -3301,20 +3992,19 @@ function GPUCompiler.llvm_machine(::EnzymeTarget) return JIT.get_tm() end -module Runtime -end +module Runtime end abstract type AbstractEnzymeCompilerParams <: AbstractCompilerParams end struct EnzymeCompilerParams <: AbstractEnzymeCompilerParams TT::Type{<:Tuple} mode::API.CDerivativeMode width::Int - rt::Type{<:Annotation{T} where T} + rt::Type{<:Annotation{T} where {T}} run_enzyme::Bool abiwrap::Bool # Whether, in split mode, acessible primal argument data is modified # between the call and the split - modifiedBetween::NTuple{N, Bool} where N + modifiedBetween::NTuple{N,Bool} where {N} # 
Whether to also return the primal returnPrimal::Bool # Whether to (in aug fwd) += by one @@ -3335,7 +4025,8 @@ struct PrimalCompilerParams <: AbstractEnzymeCompilerParams mode::API.CDerivativeMode end -DefaultCompilerTarget(;kwargs...) = GPUCompiler.NativeCompilerTarget(;jlruntime=true, kwargs...) +DefaultCompilerTarget(; kwargs...) = + GPUCompiler.NativeCompilerTarget(; jlruntime = true, kwargs...) ## job @@ -3352,41 +4043,55 @@ GPUCompiler.runtime_slug(job::CompilerJob{EnzymeTarget}) = "enzyme" # provide a specific interpreter to use. if VERSION >= v"1.11.0-DEV.1552" -struct EnzymeCacheToken - target_type::Type - always_inline - method_table::Core.MethodTable - param_type::Type - is_fwd::Bool -end - -GPUCompiler.ci_cache_token(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = - EnzymeCacheToken( - typeof(job.config.target), job.config.always_inline, GPUCompiler.method_table(job), - typeof(job.config.params), job.config.params.mode == API.DEM_ForwardMode, - ) + struct EnzymeCacheToken + target_type::Type + always_inline::Any + method_table::Core.MethodTable + param_type::Type + is_fwd::Bool + end -GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = - Interpreter.EnzymeInterpreter(GPUCompiler.ci_cache_token(job), GPUCompiler.method_table(job), job.world, job.config.params.mode) + GPUCompiler.ci_cache_token(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = + EnzymeCacheToken( + typeof(job.config.target), + job.config.always_inline, + GPUCompiler.method_table(job), + typeof(job.config.params), + job.config.params.mode == API.DEM_ForwardMode, + ) + + GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = + Interpreter.EnzymeInterpreter( + GPUCompiler.ci_cache_token(job), + GPUCompiler.method_table(job), + job.world, + job.config.params.mode, + ) else -# the codeinstance cache to use -- should only be used for the constructor -# Note that the only way the interpreter modifies codegen is 
either not inlining a fwd mode -# rule or not inlining a rev mode rule. Otherwise, all caches can be re-used. -const GLOBAL_FWD_CACHE = GPUCompiler.CodeCache() -const GLOBAL_REV_CACHE = GPUCompiler.CodeCache() -function enzyme_ci_cache(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) - return if job.config.params.mode == API.DEM_ForwardMode - GLOBAL_FWD_CACHE - else - GLOBAL_REV_CACHE + # the codeinstance cache to use -- should only be used for the constructor + # Note that the only way the interpreter modifies codegen is either not inlining a fwd mode + # rule or not inlining a rev mode rule. Otherwise, all caches can be re-used. + const GLOBAL_FWD_CACHE = GPUCompiler.CodeCache() + const GLOBAL_REV_CACHE = GPUCompiler.CodeCache() + function enzyme_ci_cache(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) + return if job.config.params.mode == API.DEM_ForwardMode + GLOBAL_FWD_CACHE + else + GLOBAL_REV_CACHE + end end -end -GPUCompiler.ci_cache(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = enzyme_ci_cache(job) + GPUCompiler.ci_cache(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = + enzyme_ci_cache(job) -GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = - Interpreter.EnzymeInterpreter(enzyme_ci_cache(job), GPUCompiler.method_table(job), job.world, job.config.params.mode) + GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = + Interpreter.EnzymeInterpreter( + enzyme_ci_cache(job), + GPUCompiler.method_table(job), + job.world, + job.config.params.mode, + ) end include("compiler/passes.jl") @@ -3399,55 +4104,75 @@ import .Interpreter: isKWCallSignature """ Create the methodinstance pair, and lookup the primal return type. """ -@inline function fspec(@nospecialize(F), @nospecialize(TT), world::Union{Integer, Nothing}=nothing) +@inline function fspec( + @nospecialize(F), + @nospecialize(TT), + world::Union{Integer,Nothing} = nothing, +) # primal function. 
Inferred here to get return type _tt = (TT.parameters...,) primal_tt = Tuple{map(eltype, _tt)...} primal = if world isa Nothing - GPUCompiler.methodinstance(F, primal_tt) + GPUCompiler.methodinstance(F, primal_tt) else - GPUCompiler.methodinstance(F, primal_tt, world) + GPUCompiler.methodinstance(F, primal_tt, world) end return primal end -@generated function primal_return_type(::ReverseMode, ::Val{world}, ::Type{FT}, ::Type{TT}) where {world, FT, TT} +@generated function primal_return_type( + ::ReverseMode, + ::Val{world}, + ::Type{FT}, + ::Type{TT}, +) where {world,FT,TT} mode = Enzyme.API.DEM_ReverseModeCombined CT = @static if VERSION >= v"1.11.0-DEV.1552" EnzymeCacheToken( - typeof(DefaultCompilerTarget()), #=job.config.always_inline=#false, GPUCompiler.GLOBAL_METHOD_TABLE, - EnzymeCompilerParams, false, + typeof(DefaultCompilerTarget()), + false, + GPUCompiler.GLOBAL_METHOD_TABLE, #=job.config.always_inline=# + EnzymeCompilerParams, + false, ) else Enzyme.Compiler.GLOBAL_REV_CACHE end interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode) - res = Core.Compiler._return_type(interp, Tuple{FT, TT.parameters...}) + res = Core.Compiler._return_type(interp, Tuple{FT,TT.parameters...}) return quote Base.@_inline_meta $res end end -@generated function primal_return_type(::ForwardMode, ::Val{world}, ::Type{FT}, ::Type{TT}) where {world, FT, TT} +@generated function primal_return_type( + ::ForwardMode, + ::Val{world}, + ::Type{FT}, + ::Type{TT}, +) where {world,FT,TT} mode = Enzyme.API.DEM_ForwardMode CT = @static if VERSION >= v"1.11.0-DEV.1552" EnzymeCacheToken( - typeof(DefaultCompilerTarget()), #=always_inline=#false, GPUCompiler.GLOBAL_METHOD_TABLE, - EnzymeCompilerParams, false, + typeof(DefaultCompilerTarget()), + false, + GPUCompiler.GLOBAL_METHOD_TABLE, #=always_inline=# + EnzymeCompilerParams, + false, ) else Enzyme.Compiler.GLOBAL_FWD_CACHE end interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode) - res = 
Core.Compiler._return_type(interp, Tuple{FT, TT.parameters...}) + res = Core.Compiler._return_type(interp, Tuple{FT,TT.parameters...}) return quote Base.@_inline_meta $res @@ -3467,7 +4192,7 @@ function annotate!(mod, mode) for f in fns API.EnzymeAttributeKnownFunctions(f.ref) end - + for gname in inactiveglobs globs = LLVM.globals(mod) if haskey(globs, gname) @@ -3496,8 +4221,22 @@ function annotate!(mod, mode) if operands(c)[1] != fn continue end - LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive) - LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc) + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + inactive, + ) + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + no_escaping_alloc, + ) end end end @@ -3521,7 +4260,14 @@ function annotate!(mod, mode) if operands(c)[1] != fn continue end - LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), LLVM.EnumAttribute("nofree", 0)) + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + LLVM.EnumAttribute("nofree", 0), + ) end end end @@ -3533,7 +4279,8 @@ function annotate!(mod, mode) end end - for fname in ("julia.typeof", "jl_object_id_", "jl_object_id", "ijl_object_id_", "ijl_object_id") + for fname in + ("julia.typeof", "jl_object_id_", "jl_object_id", "ijl_object_id_", "ijl_object_id") if haskey(fns, fname) fn = fns[fname] if LLVM.version().major <= 15 @@ -3551,14 +4298,25 @@ function annotate!(mod, mode) end end - for fname in ("jl_excstack_state","ijl_excstack_state", "ijl_field_index", "jl_field_index") + for fname in + ("jl_excstack_state", 
"ijl_excstack_state", "ijl_field_index", "jl_field_index") if haskey(fns, fname) fn = fns[fname] if LLVM.version().major <= 15 push!(function_attributes(fn), LLVM.EnumAttribute("readonly")) push!(function_attributes(fn), LLVM.StringAttribute("inaccessiblememonly")) else - push!(function_attributes(fn), EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_Ref << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data)) + push!( + function_attributes(fn), + EnumAttribute( + "memory", + MemoryEffect( + (MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_Ref << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ), + ) end end end @@ -3570,7 +4328,14 @@ function annotate!(mod, mode) end end - for fname in ("jl_f_getfield","ijl_f_getfield","jl_get_nth_field_checked","ijl_get_nth_field_checked", "jl_f__svec_ref", "ijl_f__svec_ref") + for fname in ( + "jl_f_getfield", + "ijl_f_getfield", + "jl_get_nth_field_checked", + "ijl_get_nth_field_checked", + "jl_f__svec_ref", + "ijl_f__svec_ref", + ) if haskey(fns, fname) fn = fns[fname] push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0)) @@ -3592,9 +4357,23 @@ function annotate!(mod, mode) attr = if LLVM.version().major <= 15 LLVM.EnumAttribute("readonly") else - EnumAttribute("memory", MemoryEffect((MRI_Ref << getLocationPos(ArgMem)) | (MRI_NoModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data) + EnumAttribute( + "memory", + MemoryEffect( + (MRI_Ref << getLocationPos(ArgMem)) | + (MRI_NoModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ) end - LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), attr) + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + attr, + ) end end end @@ -3619,30 +4398,52 
@@ function annotate!(mod, mode) end end - for fname in ("julia.get_pgcstack", "julia.ptls_states", "jl_get_ptls_states", "julia.safepoint", "ijl_throw", "julia.pointer_from_objref", - "ijl_array_grow_end", "jl_array_grow_end", "ijl_array_del_end", "jl_array_del_end", - "ijl_array_grow_beg", "jl_array_grow_beg", "ijl_array_del_beg", "jl_array_del_beg", - "ijl_array_grow_at", "jl_array_grow_at", - "ijl_array_del_at", "jl_array_del_at", - "ijl_pop_handler", "jl_pop_handler", - "ijl_push_handler", "jl_push_handler", - "ijl_module_name", "jl_module_name", - "ijl_restore_excstack", "jl_restore_excstack", - "julia.except_enter", - "ijl_get_nth_field_checked", "jl_get_nth_field_checked", - "jl_egal__unboxed", - "ijl_reshape_array", "jl_reshape_array", - "ijl_eqtable_get", "jl_eqtable_get", - "jl_gc_run_pending_finalizers", - "ijl_try_substrtod", "jl_try_substrtod", - ) + for fname in ( + "julia.get_pgcstack", + "julia.ptls_states", + "jl_get_ptls_states", + "julia.safepoint", + "ijl_throw", + "julia.pointer_from_objref", + "ijl_array_grow_end", + "jl_array_grow_end", + "ijl_array_del_end", + "jl_array_del_end", + "ijl_array_grow_beg", + "jl_array_grow_beg", + "ijl_array_del_beg", + "jl_array_del_beg", + "ijl_array_grow_at", + "jl_array_grow_at", + "ijl_array_del_at", + "jl_array_del_at", + "ijl_pop_handler", + "jl_pop_handler", + "ijl_push_handler", + "jl_push_handler", + "ijl_module_name", + "jl_module_name", + "ijl_restore_excstack", + "jl_restore_excstack", + "julia.except_enter", + "ijl_get_nth_field_checked", + "jl_get_nth_field_checked", + "jl_egal__unboxed", + "ijl_reshape_array", + "jl_reshape_array", + "ijl_eqtable_get", + "jl_eqtable_get", + "jl_gc_run_pending_finalizers", + "ijl_try_substrtod", + "jl_try_substrtod", + ) if haskey(fns, fname) fn = fns[fname] push!(function_attributes(fn), no_escaping_alloc) end end - + for fname in ("julia.pointer_from_objref",) if haskey(fns, fname) @@ -3655,14 +4456,35 @@ function annotate!(mod, mode) end end - for boxfn in 
("julia.gc_alloc_obj", "jl_gc_alloc_typed", "ijl_gc_alloc_typed", - "jl_box_float32", "jl_box_float64", "jl_box_int32", "jl_box_int64", - "ijl_box_float32", "ijl_box_float64", "ijl_box_int32", "ijl_box_int64", - "jl_alloc_array_1d", "jl_alloc_array_2d", "jl_alloc_array_3d", - "ijl_alloc_array_1d", "ijl_alloc_array_2d", "ijl_alloc_array_3d", - "jl_array_copy", "ijl_array_copy", "jl_idtable_rehash", "ijl_idtable_rehash", - "jl_f_tuple", "ijl_f_tuple", "jl_new_structv", "ijl_new_structv", - "ijl_new_array", "jl_new_array") + for boxfn in ( + "julia.gc_alloc_obj", + "jl_gc_alloc_typed", + "ijl_gc_alloc_typed", + "jl_box_float32", + "jl_box_float64", + "jl_box_int32", + "jl_box_int64", + "ijl_box_float32", + "ijl_box_float64", + "ijl_box_int32", + "ijl_box_int64", + "jl_alloc_array_1d", + "jl_alloc_array_2d", + "jl_alloc_array_3d", + "ijl_alloc_array_1d", + "ijl_alloc_array_2d", + "ijl_alloc_array_3d", + "jl_array_copy", + "ijl_array_copy", + "jl_idtable_rehash", + "ijl_idtable_rehash", + "jl_f_tuple", + "ijl_f_tuple", + "jl_new_structv", + "ijl_new_structv", + "ijl_new_array", + "jl_new_array", + ) if haskey(fns, boxfn) fn = fns[boxfn] push!(return_attributes(fn), LLVM.EnumAttribute("noalias", 0)) @@ -3670,9 +4492,23 @@ function annotate!(mod, mode) accattr = if LLVM.version().major <= 15 LLVM.EnumAttribute("inaccessiblememonly") else - EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_ModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data) + EnumAttribute( + "memory", + MemoryEffect( + (MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_ModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ) end - if !(boxfn in ("jl_array_copy", "ijl_array_copy", "jl_idtable_rehash", "ijl_idtable_rehash")) + if !( + boxfn in ( + "jl_array_copy", + "ijl_array_copy", + "jl_idtable_rehash", + "ijl_idtable_rehash", + ) + ) push!(function_attributes(fn), accattr) end for u in 
LLVM.uses(fn) @@ -3682,9 +4518,27 @@ function annotate!(mod, mode) end cf = LLVM.called_operand(c) if cf == fn - LLVM.API.LLVMAddCallSiteAttribute(c, LLVM.API.LLVMAttributeReturnIndex, LLVM.EnumAttribute("noalias", 0)) - if !(boxfn in ("jl_array_copy", "ijl_array_copy", "jl_idtable_rehash", "ijl_idtable_rehash")) - LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), accattr) + LLVM.API.LLVMAddCallSiteAttribute( + c, + LLVM.API.LLVMAttributeReturnIndex, + LLVM.EnumAttribute("noalias", 0), + ) + if !( + boxfn in ( + "jl_array_copy", + "ijl_array_copy", + "jl_idtable_rehash", + "ijl_idtable_rehash", + ) + ) + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + accattr, + ) end end if !isa(cf, LLVM.Function) @@ -3696,15 +4550,47 @@ function annotate!(mod, mode) if operands(c)[1] != fn continue end - LLVM.API.LLVMAddCallSiteAttribute(c, LLVM.API.LLVMAttributeReturnIndex, LLVM.EnumAttribute("noalias", 0)) - LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc) - if !(boxfn in ("jl_array_copy", "ijl_array_copy", "jl_idtable_rehash", "ijl_idtable_rehash")) + LLVM.API.LLVMAddCallSiteAttribute( + c, + LLVM.API.LLVMAttributeReturnIndex, + LLVM.EnumAttribute("noalias", 0), + ) + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + no_escaping_alloc, + ) + if !( + boxfn in ( + "jl_array_copy", + "ijl_array_copy", + "jl_idtable_rehash", + "ijl_idtable_rehash", + ) + ) attr = if LLVM.version().major <= 15 LLVM.EnumAttribute("inaccessiblememonly") else - EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_ModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data) + EnumAttribute( + "memory", + MemoryEffect( + 
(MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_ModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ) end - LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), attr) + LLVM.API.LLVMAddCallSiteAttribute( + c, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + attr, + ) end end end @@ -3716,7 +4602,17 @@ function annotate!(mod, mode) if LLVM.version().major <= 15 push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly")) else - push!(function_attributes(fn), EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_ModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data)) + push!( + function_attributes(fn), + EnumAttribute( + "memory", + MemoryEffect( + (MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_ModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ), + ) end end end @@ -3730,7 +4626,17 @@ function annotate!(mod, mode) push!(function_attributes(fn), LLVM.EnumAttribute("readonly")) push!(function_attributes(fn), LLVM.EnumAttribute("argmemonly")) else - push!(function_attributes(fn), EnumAttribute("memory", MemoryEffect((MRI_Ref << getLocationPos(ArgMem)) | (MRI_NoModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data)) + push!( + function_attributes(fn), + EnumAttribute( + "memory", + MemoryEffect( + (MRI_Ref << getLocationPos(ArgMem)) | + (MRI_NoModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ), + ) end end end @@ -3745,7 +4651,17 @@ function annotate!(mod, mode) if LLVM.version().major <= 15 push!(function_attributes(fn), LLVM.EnumAttribute("argmemonly")) else - push!(function_attributes(fn), EnumAttribute("memory", MemoryEffect((MRI_ModRef << getLocationPos(ArgMem)) | (MRI_NoModRef << 
getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data)) + push!( + function_attributes(fn), + EnumAttribute( + "memory", + MemoryEffect( + (MRI_ModRef << getLocationPos(ArgMem)) | + (MRI_NoModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ), + ) end end end @@ -3757,7 +4673,17 @@ function annotate!(mod, mode) push!(function_attributes(fn), LLVM.EnumAttribute("readonly")) push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly")) else - push!(function_attributes(fn), EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_Ref << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data)) + push!( + function_attributes(fn), + EnumAttribute( + "memory", + MemoryEffect( + (MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_Ref << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ), + ) end end end @@ -3774,7 +4700,7 @@ function enzyme_extract_world(fn::LLVM.Function)::UInt throw(AssertionError("Enzyme: could not find world in $(string(fn))")) end -function enzyme_custom_extract_mi(orig::LLVM.Instruction, error=true) +function enzyme_custom_extract_mi(orig::LLVM.Instruction, error = true) operand = LLVM.called_operand(orig) if isa(operand, LLVM.Function) return enzyme_custom_extract_mi(operand::LLVM.Function, error) @@ -3784,7 +4710,7 @@ function enzyme_custom_extract_mi(orig::LLVM.Instruction, error=true) return nothing, nothing end -function enzyme_custom_extract_mi(orig::LLVM.Function, error=true) +function enzyme_custom_extract_mi(orig::LLVM.Function, error = true) mi = nothing RT = nothing for fattr in collect(function_attributes(orig)) @@ -3805,7 +4731,7 @@ function enzyme_custom_extract_mi(orig::LLVM.Function, error=true) return mi, RT end -function enzyme_extract_parm_type(fn::LLVM.Function, idx::Int, error=true) +function enzyme_extract_parm_type(fn::LLVM.Function, idx::Int, error = 
true) ty = nothing byref = nothing for fattr in collect(parameter_attributes(fn, idx)) @@ -3820,7 +4746,9 @@ function enzyme_extract_parm_type(fn::LLVM.Function, idx::Int, error=true) end end if error && (byref === nothing || ty === nothing) - GPUCompiler.@safe_error "Enzyme: Custom handler, could not find parm type at index", idx, fn + GPUCompiler.@safe_error "Enzyme: Custom handler, could not find parm type at index", + idx, + fn end return ty, byref end @@ -3828,18 +4756,39 @@ end include("rules/typerules.jl") include("rules/activityrules.jl") -@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: Const = API.DFT_CONSTANT -@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: Active = API.DFT_OUT_DIFF -@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: Duplicated = API.DFT_DUP_ARG -@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: BatchDuplicated = API.DFT_DUP_ARG -@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: BatchDuplicatedFunc = API.DFT_DUP_ARG -@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: DuplicatedNoNeed = API.DFT_DUP_NONEED -@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: BatchDuplicatedNoNeed = API.DFT_DUP_NONEED +@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:Const} = API.DFT_CONSTANT +@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:Active} = + API.DFT_OUT_DIFF +@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:Duplicated} = + API.DFT_DUP_ARG +@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:BatchDuplicated} = + API.DFT_DUP_ARG +@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:BatchDuplicatedFunc} = + API.DFT_DUP_ARG +@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:DuplicatedNoNeed} = + API.DFT_DUP_NONEED +@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:BatchDuplicatedNoNeed} = + API.DFT_DUP_NONEED 
const DumpPreEnzyme = Ref(false) const DumpPostWrap = Ref(false) -function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wrap, modifiedBetween, returnPrimal, expectedTapeType, loweredArgs, boxedArgs) +function enzyme!( + job, + mod, + primalf, + TT, + mode, + width, + parallel, + actualRetType, + wrap, + modifiedBetween, + returnPrimal, + expectedTapeType, + loweredArgs, + boxedArgs, +) if DumpPreEnzyme[] API.EnzymeDumpModuleRef(mod.ref) end @@ -3853,17 +4802,24 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr ctx = context(mod) dl = string(LLVM.datalayout(mod)) - tt = [TT.parameters[2:end]...,] + tt = [TT.parameters[2:end]...] - args_activity = API.CDIFFE_TYPE[] - uncacheable_args = Bool[] - args_typeInfo = TypeTree[] + args_activity = API.CDIFFE_TYPE[] + uncacheable_args = Bool[] + args_typeInfo = TypeTree[] args_known_values = API.IntList[] @assert length(modifiedBetween) == length(TT.parameters) - swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(primalf, i)))) for i in 1:length(collect(parameters(primalf)))) + swiftself = any( + any( + map( + k -> kind(k) == kind(EnumAttribute("swiftself")), + collect(parameter_attributes(primalf, i)), + ), + ) for i = 1:length(collect(parameters(primalf))) + ) if swiftself push!(args_activity, API.DFT_CONSTANT) push!(args_typeInfo, TypeTree()) @@ -3876,7 +4832,11 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr source_typ = eltype(T) if isghostty(source_typ) || Core.Compiler.isconstType(source_typ) if !(T <: Const) - error("Type of ghost or constant type "*string(T)*" is marked as differentiable.") + error( + "Type of ghost or constant type " * + string(T) * + " is marked as differentiable.", + ) end continue end @@ -3890,9 +4850,13 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr else push!(args_activity, API.DFT_OUT_DIFF) end - elseif T <: Duplicated 
|| T<: BatchDuplicated || T<: BatchDuplicatedFunc || T <: MixedDuplicated || T <: BatchMixedDuplicated + elseif T <: Duplicated || + T <: BatchDuplicated || + T <: BatchDuplicatedFunc || + T <: MixedDuplicated || + T <: BatchMixedDuplicated push!(args_activity, API.DFT_DUP_ARG) - elseif T <: DuplicatedNoNeed || T<: BatchDuplicatedNoNeed + elseif T <: DuplicatedNoNeed || T <: BatchDuplicatedNoNeed push!(args_activity, API.DFT_DUP_NONEED) else error("illegal annotation type $T") @@ -3922,36 +4886,105 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr convert(API.CDIFFE_TYPE, rt) end - rules = Dict{String, API.CustomRuleType}( - "jl_array_copy" => @cfunction(inout_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "ijl_array_copy" => @cfunction(inout_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "julia.pointer_from_objref" => @cfunction(inout_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "jl_inactive_inout" => @cfunction(inout_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "jl_excstack_state" => @cfunction(int_return_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "ijl_excstack_state" => @cfunction(int_return_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "julia.except_enter" => @cfunction(int_return_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), + rules = Dict{String,API.CustomRuleType}( + "jl_array_copy" => @cfunction( + inout_rule, + UInt8, + ( + Cint, + API.CTypeTreeRef, + Ptr{API.CTypeTreeRef}, + Ptr{API.IntList}, + Csize_t, + 
LLVM.API.LLVMValueRef, + ) + ), + "ijl_array_copy" => @cfunction( + inout_rule, + UInt8, + ( + Cint, + API.CTypeTreeRef, + Ptr{API.CTypeTreeRef}, + Ptr{API.IntList}, + Csize_t, + LLVM.API.LLVMValueRef, + ) + ), + "julia.pointer_from_objref" => @cfunction( + inout_rule, + UInt8, + ( + Cint, + API.CTypeTreeRef, + Ptr{API.CTypeTreeRef}, + Ptr{API.IntList}, + Csize_t, + LLVM.API.LLVMValueRef, + ) + ), + "jl_inactive_inout" => @cfunction( + inout_rule, + UInt8, + ( + Cint, + API.CTypeTreeRef, + Ptr{API.CTypeTreeRef}, + Ptr{API.IntList}, + Csize_t, + LLVM.API.LLVMValueRef, + ) + ), + "jl_excstack_state" => @cfunction( + int_return_rule, + UInt8, + ( + Cint, + API.CTypeTreeRef, + Ptr{API.CTypeTreeRef}, + Ptr{API.IntList}, + Csize_t, + LLVM.API.LLVMValueRef, + ) + ), + "ijl_excstack_state" => @cfunction( + int_return_rule, + UInt8, + ( + Cint, + API.CTypeTreeRef, + Ptr{API.CTypeTreeRef}, + Ptr{API.IntList}, + Csize_t, + LLVM.API.LLVMValueRef, + ) + ), + "julia.except_enter" => @cfunction( + int_return_rule, + UInt8, + ( + Cint, + API.CTypeTreeRef, + Ptr{API.CTypeTreeRef}, + Ptr{API.IntList}, + Csize_t, + LLVM.API.LLVMValueRef, + ) + ), ) logic = Logic() TA = TypeAnalysis(logic, rules) - retT = (!isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)) ? - Ptr{actualRetType} : actualRetType - retTT = (!isa(actualRetType, Union) && actualRetType <: Tuple && in(Any, actualRetType.parameters)) ? TypeTree() : typetree(retT, ctx, dl, seen) + retT = + (!isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)) ? + Ptr{actualRetType} : actualRetType + retTT = + ( + !isa(actualRetType, Union) && + actualRetType <: Tuple && + in(Any, actualRetType.parameters) + ) ? 
TypeTree() : typetree(retT, ctx, dl, seen) typeInfo = FnTypeInfo(retTT, args_typeInfo, args_known_values) @@ -3959,15 +4992,33 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr if mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient returnUsed = !(isghostty(actualRetType) || Core.Compiler.isconstType(actualRetType)) - shadowReturnUsed = returnUsed && (retType == API.DFT_DUP_ARG || retType == API.DFT_DUP_NONEED || rt <: MixedDuplicated || rt <: BatchMixedDuplicated) + shadowReturnUsed = + returnUsed && ( + retType == API.DFT_DUP_ARG || + retType == API.DFT_DUP_NONEED || + rt <: MixedDuplicated || + rt <: BatchMixedDuplicated + ) returnUsed &= returnPrimal augmented = API.EnzymeCreateAugmentedPrimal( - logic, primalf, retType, args_activity, TA, #=returnUsed=# returnUsed, - #=shadowReturnUsed=#shadowReturnUsed, - typeInfo, uncacheable_args, #=forceAnonymousTape=# false, runtimeActivity, width, #=atomicAdd=# parallel) + logic, + primalf, + retType, + args_activity, + TA, + returnUsed, #=returnUsed=# + shadowReturnUsed, #=shadowReturnUsed=# + typeInfo, + uncacheable_args, + false, + runtimeActivity, + width, + parallel, + ) #=atomicAdd=# # 2. 
get new_primalf and tape - augmented_primalf = LLVM.Function(API.EnzymeExtractFunctionFromAugmentation(augmented)) + augmented_primalf = + LLVM.Function(API.EnzymeExtractFunctionFromAugmentation(augmented)) tape = API.EnzymeExtractTapeTypeFromAugmentation(augmented) utape = API.EnzymeExtractUnderlyingTapeTypeFromAugmentation(augmented) if utape != C_NULL @@ -3983,55 +5034,145 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr end if wrap - augmented_primalf = create_abi_wrapper(augmented_primalf, TT, rt, actualRetType, API.DEM_ReverseModePrimal, augmented, width, returnPrimal, shadow_init, world, interp) + augmented_primalf = create_abi_wrapper( + augmented_primalf, + TT, + rt, + actualRetType, + API.DEM_ReverseModePrimal, + augmented, + width, + returnPrimal, + shadow_init, + world, + interp, + ) end # TODOs: # 1. Handle mutable or !pointerfree arguments by introducing caching # + specifically by setting uncacheable_args[i] = true - adjointf = LLVM.Function(API.EnzymeCreatePrimalAndGradient( - logic, primalf, retType, args_activity, TA, - #=returnValue=#false, #=dretUsed=#false, #=mode=#API.DEM_ReverseModeGradient, runtimeActivity, width, - #=additionalArg=#tape, #=forceAnonymousTape=#false, typeInfo, - uncacheable_args, augmented, #=atomicAdd=# parallel)) + adjointf = LLVM.Function( + API.EnzymeCreatePrimalAndGradient( + logic, + primalf, + retType, + args_activity, + TA, + false, + false, + API.DEM_ReverseModeGradient, + runtimeActivity, + width, #=mode=# + tape, + false, + typeInfo, #=forceAnonymousTape=# + uncacheable_args, + augmented, + parallel, + ), + ) #=atomicAdd=# if wrap - adjointf = create_abi_wrapper(adjointf, TT, rt, actualRetType, API.DEM_ReverseModeGradient, augmented, width, #=returnPrimal=#false, shadow_init, world, interp) + adjointf = create_abi_wrapper( + adjointf, + TT, + rt, + actualRetType, + API.DEM_ReverseModeGradient, + augmented, + width, + false, + shadow_init, + world, + interp, + ) #=returnPrimal=# 
end elseif mode == API.DEM_ReverseModeCombined returnUsed = !isghostty(actualRetType) returnUsed &= returnPrimal - adjointf = LLVM.Function(API.EnzymeCreatePrimalAndGradient( - logic, primalf, retType, args_activity, TA, - #=returnValue=#returnUsed, #=dretUsed=#false, #=mode=#API.DEM_ReverseModeCombined, runtimeActivity, width, - #=additionalArg=#C_NULL, #=forceAnonymousTape=#false, typeInfo, - uncacheable_args, #=augmented=#C_NULL, #=atomicAdd=# parallel)) + adjointf = LLVM.Function( + API.EnzymeCreatePrimalAndGradient( + logic, + primalf, + retType, + args_activity, + TA, + returnUsed, + false, + API.DEM_ReverseModeCombined, + runtimeActivity, + width, #=mode=# + C_NULL, + false, + typeInfo, #=forceAnonymousTape=# + uncacheable_args, + C_NULL, + parallel, + ), + ) #=atomicAdd=# augmented_primalf = nothing if wrap - adjointf = create_abi_wrapper(adjointf, TT, rt, actualRetType, API.DEM_ReverseModeCombined, nothing, width, returnPrimal, shadow_init, world, interp) + adjointf = create_abi_wrapper( + adjointf, + TT, + rt, + actualRetType, + API.DEM_ReverseModeCombined, + nothing, + width, + returnPrimal, + shadow_init, + world, + interp, + ) end elseif mode == API.DEM_ForwardMode returnUsed = !(isghostty(actualRetType) || Core.Compiler.isconstType(actualRetType)) returnUsed &= returnPrimal - adjointf = LLVM.Function(API.EnzymeCreateForwardDiff( - logic, primalf, retType, args_activity, TA, - #=returnValue=#returnUsed, #=mode=#API.DEM_ForwardMode, runtimeActivity, width, - #=additionalArg=#C_NULL, typeInfo, - uncacheable_args)) + adjointf = LLVM.Function( + API.EnzymeCreateForwardDiff( + logic, + primalf, + retType, + args_activity, + TA, + returnUsed, + API.DEM_ForwardMode, + runtimeActivity, + width, #=mode=# + C_NULL, + typeInfo, #=additionalArg=# + uncacheable_args, + ), + ) augmented_primalf = nothing if wrap - pf = adjointf - adjointf = create_abi_wrapper(adjointf, TT, rt, actualRetType, API.DEM_ForwardMode, nothing, width, returnPrimal, shadow_init, world, 
interp) + pf = adjointf + adjointf = create_abi_wrapper( + adjointf, + TT, + rt, + actualRetType, + API.DEM_ForwardMode, + nothing, + width, + returnPrimal, + shadow_init, + world, + interp, + ) end else @assert "Unhandled derivative mode", mode end if DumpPostWrap[] API.EnzymeDumpModuleRef(mod.ref) - end + end API.EnzymeLogicErasePreprocessedFunctions(logic) adjointfname = adjointf == nothing ? nothing : LLVM.name(adjointf) - augmented_primalfname = augmented_primalf == nothing ? nothing : LLVM.name(augmented_primalf) + augmented_primalfname = + augmented_primalf == nothing ? nothing : LLVM.name(augmented_primalf) for f in collect(functions(mod)) API.EnzymeFixupBatchedJuliaCallingConvention(f) end @@ -4041,7 +5182,8 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr end fix_decayaddr!(mod) adjointf = adjointf == nothing ? nothing : functions(mod)[adjointfname] - augmented_primalf = augmented_primalf == nothing ? nothing : functions(mod)[augmented_primalfname] + augmented_primalf = + augmented_primalf == nothing ? 
nothing : functions(mod)[augmented_primalfname] return adjointf, augmented_primalf, TapeType end @@ -4061,18 +5203,39 @@ function set_subprogram!(f::LLVM.Function, sp) end end -function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, Mode::API.CDerivativeMode, augmented, width, returnPrimal, shadow_init, world, interp) +function create_abi_wrapper( + enzymefn::LLVM.Function, + TT, + rettype, + actualRetType, + Mode::API.CDerivativeMode, + augmented, + width, + returnPrimal, + shadow_init, + world, + interp, +) is_adjoint = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModeCombined - is_split = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModePrimal + is_split = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModePrimal needs_tape = Mode == API.DEM_ReverseModeGradient mod = LLVM.parent(enzymefn) ctx = LLVM.context(mod) push!(function_attributes(enzymefn), EnumAttribute("alwaysinline", 0)) - hasNoInline = any(map(k->kind(k)==kind(EnumAttribute("noinline")), collect(function_attributes(enzymefn)))) + hasNoInline = any( + map( + k -> kind(k) == kind(EnumAttribute("noinline")), + collect(function_attributes(enzymefn)), + ), + ) if hasNoInline - LLVM.API.LLVMRemoveEnumAttributeAtIndex(enzymefn, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), kind(EnumAttribute("noinline"))) + LLVM.API.LLVMRemoveEnumAttributeAtIndex( + enzymefn, + reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), + kind(EnumAttribute("noinline")), + ) end T_void = convert(LLVMType, Nothing) ptr8 = LLVM.PointerType(LLVM.IntType(8)) @@ -4082,7 +5245,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, # Create Enzyme calling convention T_wrapperargs = LLVMType[] # Arguments of the wrapper - sret_types = Type[] # Julia types of all returned variables + sret_types = Type[] # Julia types of all returned variables pactualRetType = actualRetType 
sret_union = is_sret_union(actualRetType) @@ -4122,7 +5285,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if width == 1 push!(ActiveRetTypes, source_typ) else - push!(ActiveRetTypes, NTuple{width, source_typ}) + push!(ActiveRetTypes, NTuple{width,source_typ}) end end elseif T <: Duplicated || T <: DuplicatedNoNeed @@ -4148,7 +5311,10 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if is_adjoint NT = Tuple{ActiveRetTypes...} - if any(any_jltypes(convert(LLVM.LLVMType, b; allow_boxed=true)) for b in ActiveRetTypes) + if any( + any_jltypes(convert(LLVM.LLVMType, b; allow_boxed = true)) for + b in ActiveRetTypes + ) NT = AnonymousStruct(NT) end push!(sret_types, NT) @@ -4156,26 +5322,40 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, # API.DFT_OUT_DIFF if is_adjoint - if rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated + if rettype <: Active || + rettype <: MixedDuplicated || + rettype <: BatchMixedDuplicated @assert !sret_union if allocatedinline(actualRetType) != allocatedinline(literal_rt) msg = sprint() do io println(io, string(enzymefn)) - println(io, "Base.allocatedinline(actualRetType) != Base.allocatedinline(literal_rt): actualRetType = $(actualRetType), literal_rt = $(literal_rt), rettype = $(rettype), sret_union=$(sret_union), pactualRetType=$(pactualRetType)") + println( + io, + "Base.allocatedinline(actualRetType) != Base.allocatedinline(literal_rt): actualRetType = $(actualRetType), literal_rt = $(literal_rt), rettype = $(rettype), sret_union=$(sret_union), pactualRetType=$(pactualRetType)", + ) end throw(AssertionError(msg)) end - if rettype <: Active + if rettype <: Active if !allocatedinline(actualRetType) - throw(AssertionError("Base.allocatedinline(actualRetType) returns false: actualRetType = $(actualRetType), rettype = $(rettype)")) + throw( + AssertionError( + "Base.allocatedinline(actualRetType) returns 
false: actualRetType = $(actualRetType), rettype = $(rettype)", + ), + ) end end - dretTy = LLVM.LLVMType(API.EnzymeGetShadowType(width, convert(LLVMType, actualRetType; allow_boxed=!(rettype <: Active)))) + dretTy = LLVM.LLVMType( + API.EnzymeGetShadowType( + width, + convert(LLVMType, actualRetType; allow_boxed = !(rettype <: Active)), + ), + ) push!(T_wrapperargs, dretTy) end end - data = Array{Int64}(undef, 3) + data = Array{Int64}(undef, 3) existed = Array{UInt8}(undef, 3) if Mode == API.DEM_ReverseModePrimal API.EnzymeExtractReturnInfo(augmented, data, existed) @@ -4208,17 +5388,24 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end # shadow return if existed[3] != 0 - if rettype <: Duplicated || rettype <: DuplicatedNoNeed || rettype <: BatchDuplicated || rettype <: BatchDuplicatedNoNeed || rettype <: BatchDuplicatedFunc + if rettype <: Duplicated || + rettype <: DuplicatedNoNeed || + rettype <: BatchDuplicated || + rettype <: BatchDuplicatedNoNeed || + rettype <: BatchDuplicatedFunc if width == 1 push!(sret_types, literal_rt) else - push!(sret_types, AnonymousStruct(NTuple{width, literal_rt})) + push!(sret_types, AnonymousStruct(NTuple{width,literal_rt})) end elseif rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated if width == 1 push!(sret_types, Base.RefValue{literal_rt}) else - push!(sret_types, AnonymousStruct(NTuple{width, Base.RefValue{literal_rt}})) + push!( + sret_types, + AnonymousStruct(NTuple{width,Base.RefValue{literal_rt}}), + ) end end else @@ -4236,7 +5423,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if width == 1 push!(sret_types, literal_rt) else - push!(sret_types, AnonymousStruct(NTuple{width, literal_rt})) + push!(sret_types, AnonymousStruct(NTuple{width,literal_rt})) end end if returnPrimal @@ -4244,11 +5431,14 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end end - combinedReturn = if 
any(any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types) - AnonymousStruct(Tuple{sret_types...}) - else - Tuple{sret_types...} - end + combinedReturn = + if any( + any_jltypes(convert(LLVM.LLVMType, T; allow_boxed = true)) for T in sret_types + ) + AnonymousStruct(Tuple{sret_types...}) + else + Tuple{sret_types...} + end uses_sret = is_sret(combinedReturn) @@ -4268,14 +5458,14 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, returnRoots = false root_ty = nothing if uses_sret - returnRoots = deserves_rooting(jltype) - if returnRoots - tracked = CountTrackedPointers(jltype) + returnRoots = deserves_rooting(jltype) + if returnRoots + tracked = CountTrackedPointers(jltype) root_ty = LLVM.ArrayType(T_prjlvalue, tracked.count) pushfirst!(T_wrapperargs, LLVM.PointerType(root_ty)) pushfirst!(T_wrapperargs, LLVM.PointerType(jltype)) - end + end end if needs_tape @@ -4286,7 +5476,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end if tape != C_NULL tape = LLVM.LLVMType(tape) - jltape = convert(LLVM.LLVMType, tape_type(tape); allow_boxed=true) + jltape = convert(LLVM.LLVMType, tape_type(tape); allow_boxed = true) push!(T_wrapperargs, jltape) else needs_tape = false @@ -4295,7 +5485,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, T_ret = returnRoots ? 
T_void : jltype FT = LLVM.FunctionType(T_ret, T_wrapperargs) - llvm_f = LLVM.Function(mod, safe_name(LLVM.name(enzymefn)*"wrap"), FT) + llvm_f = LLVM.Function(mod, safe_name(LLVM.name(enzymefn) * "wrap"), FT) API.EnzymeCloneFunctionDISubprogramInto(llvm_f, enzymefn) dl = datalayout(mod) @@ -4316,7 +5506,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if returnRoots sret = params[i] - i+= 1 + i += 1 attr = if LLVM.version().major >= 12 TypeAttribute("sret", jltype) @@ -4332,7 +5522,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, rootRet = nothing if returnRoots rootRet = params[i] - i+=1 + i += 1 end activeNum = 0 @@ -4348,7 +5538,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, llty = value_type(params[i]) - convty = convert(LLVMType, T′; allow_boxed=true) + convty = convert(LLVMType, T′; allow_boxed = true) if (T <: MixedDuplicated || T <: BatchMixedDuplicated) && !isboxed # && (isa(llty, LLVM.ArrayType) || isa(llty, LLVM.StructType)) al0 = al = emit_allocobj!(builder, Base.RefValue{T′}, "mixedparameter") @@ -4368,7 +5558,10 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if isboxed if is_split msg = sprint() do io - println(io, "Unimplemented: Had active input arg needing a box in split mode") + println( + io, + "Unimplemented: Had active input arg needing a box in split mode", + ) println(io, T, " at index ", i) println(io, TT) end @@ -4376,13 +5569,28 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end @assert !is_split # TODO replace with better enzyme_zero - ptr = gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), activeNum)]) + ptr = gep!( + builder, + jltype, + sret, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), activeNum), + ], + ) cst = pointercast!(builder, ptr, ptr8) push!(realparms, 
ptr) - LLVM.memset!(builder, cst, LLVM.ConstantInt(LLVM.IntType(8), 0), - LLVM.ConstantInt(LLVM.IntType(64), LLVM.storage_size(dl, Base.eltype(LLVM.value_type(ptr)) )), - #=align=#0 ) + LLVM.memset!( + builder, + cst, + LLVM.ConstantInt(LLVM.IntType(8), 0), + LLVM.ConstantInt( + LLVM.IntType(64), + LLVM.storage_size(dl, Base.eltype(LLVM.value_type(ptr))), + ), + 0, + ) #=align=# end activeNum += 1 elseif T <: Duplicated || T <: DuplicatedNoNeed @@ -4392,9 +5600,13 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, parmsi = params[i] if T <: BatchMixedDuplicated - if GPUCompiler.deserves_argbox(NTuple{width, Base.RefValue{T′}}) + if GPUCompiler.deserves_argbox(NTuple{width,Base.RefValue{T′}}) njlvalue = LLVM.ArrayType(Int(width), T_prjlvalue) - parmsi = bitcast!(builder, parmsi, LLVM.PointerType(njlvalue, addrspace(value_type(parmsi)))) + parmsi = bitcast!( + builder, + parmsi, + LLVM.PointerType(njlvalue, addrspace(value_type(parmsi))), + ) parmsi = load!(builder, njlvalue, parmsi) end end @@ -4404,23 +5616,24 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, resty = isboxed ? llty : LLVM.PointerType(llty, Derived) ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, resty))) - for idx in 1:width - pv = (width == 1) ? parmsi : extract_value!(builder, parmsi, idx-1) - pv = bitcast!(builder, pv, LLVM.PointerType(llty, addrspace(value_type(pv)))) + for idx = 1:width + pv = (width == 1) ? parmsi : extract_value!(builder, parmsi, idx - 1) + pv = + bitcast!(builder, pv, LLVM.PointerType(llty, addrspace(value_type(pv)))) pv = addrspacecast!(builder, pv, LLVM.PointerType(llty, Derived)) if isboxed pv = load!(builder, llty, pv, "mixedboxload") end - ival = (width == 1 ) ? pv : insert_value!(builder, ival, pv, idx-1) + ival = (width == 1) ? 
pv : insert_value!(builder, ival, pv, idx - 1) end push!(realparms, ival) i += 1 elseif T <: BatchDuplicated || T <: BatchDuplicatedNoNeed - isboxed = GPUCompiler.deserves_argbox(NTuple{width, T′}) + isboxed = GPUCompiler.deserves_argbox(NTuple{width,T′}) val = params[i] if isboxed - val = load!(builder, val) + val = load!(builder, val) end i += 1 push!(realparms, val) @@ -4430,7 +5643,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, llvmf = nested_codegen!(Mode, mod, funcspec, world) push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) Func_RT = Core.Compiler.typeinf_ext_toplevel(interp, funcspec).rettype - @assert Func_RT == NTuple{width, T′} + @assert Func_RT == NTuple{width,T′} _, psret, _ = get_return_info(Func_RT) args = LLVM.Value[] if psret !== nothing @@ -4439,7 +5652,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end res = LLVM.call!(builder, LLVM.function_type(llvmf), llvmf, args) if get_subprogram(llvmf) !== nothing - metadata(res)[LLVM.MD_dbg] = DILocation( 0, 0, get_subprogram(llvm_f) ) + metadata(res)[LLVM.MD_dbg] = DILocation(0, 0, get_subprogram(llvm_f)) end if psret !== nothing res = load!(builder, convert(LLVMType, Func_RT), psret) @@ -4450,7 +5663,8 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end end - if is_adjoint && (rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated) + if is_adjoint && + (rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated) push!(realparms, params[i]) i += 1 end @@ -4466,7 +5680,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, val = call!(builder, LLVM.function_type(enzymefn), enzymefn, realparms) if get_subprogram(llvm_f) !== nothing - metadata(val)[LLVM.MD_dbg] = DILocation( 0, 0, get_subprogram(llvm_f) ) + metadata(val)[LLVM.MD_dbg] = DILocation(0, 0, get_subprogram(llvm_f)) end @inline function 
fixup_abi(index, value) @@ -4486,11 +5700,13 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, # if in split mode and the return is a union marked duplicated, upgrade floating point like shadow returns into ref{ty} since otherwise use of the value will create problems. # 3 is index of shadow - if existed[3] != 0 && sret_union && active_reg_inner(pactualRetType, (), world, #=justActive=#Val(true), #=UnionSret=#Val(true)) == ActiveState + if existed[3] != 0 && + sret_union && + active_reg_inner(pactualRetType, (), world, Val(true), Val(true)) == ActiveState #=UnionSret=# rewrite_union_returns_as_ref(enzymefn, data[3], world, width) end returnNum = 0 - for i in 1:3 + for i = 1:3 if existed[i] != 0 eval = val if data[i] != -1 @@ -4498,31 +5714,56 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end if i == 3 if rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated - ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, T_prjlvalue))) - for idx in 1:width - pv = (width == 1) ? eval : extract_value!(builder, eval, idx-1) - al0 = al = emit_allocobj!(builder, Base.RefValue{eltype(rettype)}, "batchmixedret") + ival = UndefValue( + LLVM.LLVMType(API.EnzymeGetShadowType(width, T_prjlvalue)), + ) + for idx = 1:width + pv = + (width == 1) ? eval : extract_value!(builder, eval, idx - 1) + al0 = + al = emit_allocobj!( + builder, + Base.RefValue{eltype(rettype)}, + "batchmixedret", + ) llty = value_type(pv) - al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al)))) + al = bitcast!( + builder, + al, + LLVM.PointerType(llty, addrspace(value_type(al))), + ) store!(builder, pv, al) - emit_writebarrier!(builder, get_julia_inner_types(builder, al0, pv)) - ival = (width == 1 ) ? al0 : insert_value!(builder, ival, al0, idx-1) + emit_writebarrier!( + builder, + get_julia_inner_types(builder, al0, pv), + ) + ival = + (width == 1) ? 
al0 : + insert_value!(builder, ival, al0, idx - 1) end eval = ival end end eval = fixup_abi(i, eval) - ptr = inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), returnNum)]) + ptr = inbounds_gep!( + builder, + jltype, + sret, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), returnNum), + ], + ) ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval))) si = store!(builder, eval, ptr) - returnNum+=1 + returnNum += 1 if i == 3 && shadow_init shadows = LLVM.Value[] if width == 1 push!(shadows, eval) else - for i in 1:width - push!(shadows, extract_value!(builder, eval, i-1)) + for i = 1:width + push!(shadows, extract_value!(builder, eval, i - 1)) end end @@ -4531,7 +5772,8 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, for shadowv in shadows c = call!(builder, LLVM.function_type(cf), cf, [shadowv]) if get_subprogram(llvm_f) !== nothing - metadata(c)[LLVM.MD_dbg] = DILocation( 0, 0, get_subprogram(llvm_f) ) + metadata(c)[LLVM.MD_dbg] = + DILocation(0, 0, get_subprogram(llvm_f)) end end end @@ -4541,14 +5783,24 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if i == 2 ty = actualRetType end - @assert !(isghostty(combinedReturn) || Core.Compiler.isconstType(combinedReturn) ) + @assert !( + isghostty(combinedReturn) || Core.Compiler.isconstType(combinedReturn) + ) @assert Core.Compiler.isconstType(ty) eval = makeInstanceOf(builder, ty) eval = fixup_abi(i, eval) - ptr = inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), returnNum)]) + ptr = inbounds_gep!( + builder, + jltype, + sret, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), returnNum), + ], + ) ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval))) si = store!(builder, eval, ptr) - returnNum+=1 + returnNum += 1 end end @assert returnNum 
== numLLVMReturns @@ -4571,16 +5823,27 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, count_Sret += 1 end end - for returnNum in 0:(count_Sret-1) - eval = fixup_abi(returnNum+1, if count_llvm_Sret == 0 - makeInstanceOf(builder, sret_types[returnNum+1]) - elseif count_llvm_Sret == 1 - val - else - @assert count_llvm_Sret > 1 - extract_value!(builder, val, 1-returnNum) - end) - ptr = inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), returnNum)]) + for returnNum = 0:(count_Sret-1) + eval = fixup_abi( + returnNum + 1, + if count_llvm_Sret == 0 + makeInstanceOf(builder, sret_types[returnNum+1]) + elseif count_llvm_Sret == 1 + val + else + @assert count_llvm_Sret > 1 + extract_value!(builder, val, 1 - returnNum) + end, + ) + ptr = inbounds_gep!( + builder, + jltype, + sret, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), returnNum), + ], + ) ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval))) si = store!(builder, eval, ptr) end @@ -4591,13 +5854,31 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if Mode == API.DEM_ReverseModeCombined if returnPrimal if !isghostty(literal_rt) - eval = fixup_abi(returnNum+1, if !isghostty(actualRetType) - extract_value!(builder, val, returnNum) - else - makeInstanceOf(builder, sret_types[returnNum+1]) - end) - store!(builder, eval, inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), length(elements(jltype))-1 )])) - returnNum+=1 + eval = fixup_abi( + returnNum + 1, + if !isghostty(actualRetType) + extract_value!(builder, val, returnNum) + else + makeInstanceOf(builder, sret_types[returnNum+1]) + end, + ) + store!( + builder, + eval, + inbounds_gep!( + builder, + jltype, + sret, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt( + LLVM.IntType(32), + length(elements(jltype)) - 1, + ), 
+ ], + ), + ) + returnNum += 1 end end end @@ -4607,10 +5888,23 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, isboxed = GPUCompiler.deserves_argbox(T′) if !isboxed eval = extract_value!(builder, val, returnNum) - store!(builder, eval, inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0), LLVM.ConstantInt(LLVM.IntType(32), activeNum)])) - returnNum+=1 + store!( + builder, + eval, + inbounds_gep!( + builder, + jltype, + sret, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), 0), + LLVM.ConstantInt(LLVM.IntType(32), activeNum), + ], + ), + ) + returnNum += 1 end - activeNum+=1 + activeNum += 1 end end @assert (returnNum - activeNum) + (activeNum != 0 ? 1 : 0) == numLLVMReturns @@ -4618,21 +5912,32 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if returnRoots count = 0 - todo = Tuple{Vector{LLVM.Value},LLVM.LLVMType}[([LLVM.ConstantInt(LLVM.IntType(64), 0)], jltype)] + todo = Tuple{Vector{LLVM.Value},LLVM.LLVMType}[( + [LLVM.ConstantInt(LLVM.IntType(64), 0)], + jltype, + )] while length(todo) != 0 path, ty = popfirst!(todo) if isa(ty, LLVM.PointerType) - loc = inbounds_gep!(builder, root_ty, rootRet, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), count)]) - count+=1 + loc = inbounds_gep!( + builder, + root_ty, + rootRet, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), count), + ], + ) + count += 1 outloc = inbounds_gep!(builder, jltype, sret, path) store!(builder, load!(builder, ty, outloc), loc) continue end if isa(ty, LLVM.ArrayType) if any_jltypes(ty) - for i=1:length(ty) + for i = 1:length(ty) npath = copy(path) - push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i-1)) + push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i - 1)) push!(todo, (npath, eltype(ty))) end end @@ -4640,9 +5945,9 @@ function create_abi_wrapper(enzymefn::LLVM.Function, 
TT, rettype, actualRetType, end if isa(ty, LLVM.VectorType) if any_jltypes(ty) - for i=1:size(ty) + for i = 1:size(ty) npath = copy(path) - push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i-1)) + push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i - 1)) push!(todo, (npath, eltype(ty))) end end @@ -4652,7 +5957,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, for (i, t) in enumerate(LLVM.elements(ty)) if any_jltypes(t) npath = copy(path) - push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i-1)) + push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i - 1)) push!(todo, (npath, t)) end end @@ -4704,69 +6009,103 @@ function fixup_metadata!(f::LLVM.Function) end end -struct RemovedParam -end +struct RemovedParam end # Modified from GPUCompiler classify_arguments -function classify_arguments(source_sig::Type, codegen_ft::LLVM.FunctionType, has_sret::Bool, has_returnroots::Bool, has_swiftself::Bool, parmsRemoved::Vector{UInt64}) +function classify_arguments( + source_sig::Type, + codegen_ft::LLVM.FunctionType, + has_sret::Bool, + has_returnroots::Bool, + has_swiftself::Bool, + parmsRemoved::Vector{UInt64}, +) codegen_types = parameters(codegen_ft) args = [] codegen_i = 1 orig_i = 1 if has_sret - if !in(orig_i-1, parmsRemoved) + if !in(orig_i - 1, parmsRemoved) codegen_i += 1 end orig_i += 1 end if has_returnroots - if !in(orig_i-1, parmsRemoved) + if !in(orig_i - 1, parmsRemoved) codegen_i += 1 end orig_i += 1 end if has_swiftself - if !in(orig_i-1, parmsRemoved) + if !in(orig_i - 1, parmsRemoved) codegen_i += 1 end orig_i += 1 end for (source_i, source_typ) in enumerate(source_sig.parameters) if isghostty(source_typ) || Core.Compiler.isconstType(source_typ) - push!(args, (cc=GPUCompiler.GHOST, typ=source_typ, arg_i=source_i)) + push!(args, (cc = GPUCompiler.GHOST, typ = source_typ, arg_i = source_i)) continue end - if in(orig_i-1, parmsRemoved) - push!(args, (cc=RemovedParam, typ=source_typ)) + if in(orig_i - 1, parmsRemoved) + push!(args, (cc 
= RemovedParam, typ = source_typ)) orig_i += 1 continue end codegen_typ = codegen_types[codegen_i] if codegen_typ isa LLVM.PointerType - llvm_source_typ = convert(LLVMType, source_typ; allow_boxed=true) + llvm_source_typ = convert(LLVMType, source_typ; allow_boxed = true) # pointers are used for multiple kinds of arguments # - literal pointer values if source_typ <: Ptr || source_typ <: Core.LLVMPtr @assert llvm_source_typ == codegen_typ - push!(args, (cc=GPUCompiler.BITS_VALUE, typ=source_typ, arg_i=source_i, - codegen=(typ=codegen_typ, i=codegen_i))) - # - boxed values - # XXX: use `deserves_retbox` instead? + push!( + args, + ( + cc = GPUCompiler.BITS_VALUE, + typ = source_typ, + arg_i = source_i, + codegen = (typ = codegen_typ, i = codegen_i), + ), + ) + # - boxed values + # XXX: use `deserves_retbox` instead? elseif llvm_source_typ isa LLVM.PointerType @assert llvm_source_typ == codegen_typ - push!(args, (cc=GPUCompiler.MUT_REF, typ=source_typ, arg_i=source_i, - codegen=(typ=codegen_typ, i=codegen_i))) - # - references to aggregates + push!( + args, + ( + cc = GPUCompiler.MUT_REF, + typ = source_typ, + arg_i = source_i, + codegen = (typ = codegen_typ, i = codegen_i), + ), + ) + # - references to aggregates else @assert llvm_source_typ != codegen_typ - push!(args, (cc=GPUCompiler.BITS_REF, typ=source_typ, arg_i=source_i, - codegen=(typ=codegen_typ, i=codegen_i))) + push!( + args, + ( + cc = GPUCompiler.BITS_REF, + typ = source_typ, + arg_i = source_i, + codegen = (typ = codegen_typ, i = codegen_i), + ), + ) end else - push!(args, (cc=GPUCompiler.BITS_VALUE, typ=source_typ, arg_i=source_i, - codegen=(typ=codegen_typ, i=codegen_i))) + push!( + args, + ( + cc = GPUCompiler.BITS_VALUE, + typ = source_typ, + arg_i = source_i, + codegen = (typ = codegen_typ, i = codegen_i), + ), + ) end codegen_i += 1 @@ -4778,9 +6117,9 @@ end function isSpecialPtr(Ty) if !isa(Ty, LLVM.PointerType) - return false - end - AS = LLVM.addrspace(Ty) + return false + end + AS = 
LLVM.addrspace(Ty) return 10 <= AS && AS <= 13 end @@ -4791,14 +6130,14 @@ mutable struct CountTrackedPointers end function CountTrackedPointers(T) - res = CountTrackedPointers(0, true, false) + res = CountTrackedPointers(0, true, false) if isa(T, LLVM.PointerType) if isSpecialPtr(T) res.count += 1 if LLVM.addrspace(T) != Tracked res.derived = true - end + end end elseif isa(T, LLVM.StructType) for ElT in elements(T) @@ -4807,43 +6146,43 @@ function CountTrackedPointers(T) res.all &= sub.all res.derived |= sub.derived end - elseif isa(T, LLVM.ArrayType) - sub = CountTrackedPointers(eltype(T)) - res.count += sub.count - res.all &= sub.all - res.derived |= sub.derived - res.count *= length(T) - elseif isa(T, LLVM.VectorType) - sub = CountTrackedPointers(eltype(T)) - res.count += sub.count - res.all &= sub.all - res.derived |= sub.derived - res.count *= size(T) + elseif isa(T, LLVM.ArrayType) + sub = CountTrackedPointers(eltype(T)) + res.count += sub.count + res.all &= sub.all + res.derived |= sub.derived + res.count *= length(T) + elseif isa(T, LLVM.VectorType) + sub = CountTrackedPointers(eltype(T)) + res.count += sub.count + res.all &= sub.all + res.derived |= sub.derived + res.count *= size(T) end if res.count == 0 res.all = false - end - return res + end + return res end # must deserve sret function deserves_rooting(T) - tracked = CountTrackedPointers(T) - @assert !tracked.derived - if tracked.count != 0 && !tracked.all - return true # tracked.count; - end - return false + tracked = CountTrackedPointers(T) + @assert !tracked.derived + if tracked.count != 0 && !tracked.all + return true # tracked.count; + end + return false end # https://github.com/JuliaLang/julia/blob/64378db18b512677fc6d3b012e6d1f02077af191/src/cgutils.cpp#L823 # returns if all unboxed -function for_each_uniontype_small(f, ty, counter=Ref(0)) +function for_each_uniontype_small(f, ty, counter = Ref(0)) if counter[] > 127 return false end if ty isa Union - allunbox = for_each_uniontype_small(f, 
ty.a, counter) + allunbox = for_each_uniontype_small(f, ty.a, counter) allunbox &= for_each_uniontype_small(f, ty.b, counter) return allunbox end @@ -4860,8 +6199,8 @@ end function union_alloca_type(UT) nbytes = 0 function inner(jlrettype) - if !(Base.issingletontype(jlrettype) &&isa(jlrettype, DataType)) - nbytes = max(nbytes, sizeof(jlrettype)) + if !(Base.issingletontype(jlrettype) && isa(jlrettype, DataType)) + nbytes = max(nbytes, sizeof(jlrettype)) end end for_each_uniontype_small(inner, UT) @@ -4873,7 +6212,9 @@ function is_sret(jlrettype) if jlrettype === Union{} # jlrettype == (jl_value_t*)jl_bottom_type return false - elseif Base.isstructtype(jlrettype) && Base.issingletontype(jlrettype) &&isa(jlrettype, DataType) + elseif Base.isstructtype(jlrettype) && + Base.issingletontype(jlrettype) && + isa(jlrettype, DataType) # jl_is_structtype(jlrettype) && jl_is_datatype_singleton((jl_datatype_t*)jlrettype) return false elseif jlrettype isa Union # jl_is_uniontype(jlrettype) @@ -4883,7 +6224,7 @@ function is_sret(jlrettype) end return false elseif !GPUCompiler.deserves_retbox(jlrettype) - rt = convert(LLVMType, jlrettype ) + rt = convert(LLVMType, jlrettype) if !isa(rt, LLVM.VoidType) && GPUCompiler.deserves_sret(jlrettype, rt) return true end @@ -4894,7 +6235,9 @@ function is_sret_union(jlrettype) if jlrettype === Union{} # jlrettype == (jl_value_t*)jl_bottom_type return false - elseif Base.isstructtype(jlrettype) && Base.issingletontype(jlrettype) &&isa(jlrettype, DataType) + elseif Base.isstructtype(jlrettype) && + Base.issingletontype(jlrettype) && + isa(jlrettype, DataType) # jl_is_structtype(jlrettype) && jl_is_datatype_singleton((jl_datatype_t*)jlrettype) return false elseif jlrettype isa Union # jl_is_uniontype(jlrettype) @@ -4907,19 +6250,23 @@ function is_sret_union(jlrettype) end # https://github.com/JuliaLang/julia/blob/0a696a3842750fcedca8832bc0aabe9096c7658f/src/codegen.cpp#L6812 -function get_return_info(jlrettype)::Tuple{Union{Nothing, Type}, 
Union{Nothing, Type}, Union{Nothing, Type}} +function get_return_info( + jlrettype, +)::Tuple{Union{Nothing,Type},Union{Nothing,Type},Union{Nothing,Type}} sret = nothing returnRoots = nothing rt = nothing if jlrettype === Union{} rt = Nothing - elseif Base.isstructtype(jlrettype) && Base.issingletontype(jlrettype) &&isa(jlrettype, DataType) + elseif Base.isstructtype(jlrettype) && + Base.issingletontype(jlrettype) && + isa(jlrettype, DataType) rt = Nothing elseif jlrettype isa Union nbytes = 0 allunbox = for_each_uniontype_small(jlrettype) do jlrettype if !(Base.issingletontype(jlrettype) && isa(jlrettype, DataType)) - nbytes = max(nbytes, sizeof(jlrettype)) + nbytes = max(nbytes, sizeof(jlrettype)) end end if nbytes != 0 @@ -4934,7 +6281,7 @@ function get_return_info(jlrettype)::Tuple{Union{Nothing, Type}, Union{Nothing, elseif jlrettype <: Tuple && in(Any, jlrettype.parameters) rt = Any elseif !GPUCompiler.deserves_retbox(jlrettype) - lRT = convert(LLVMType, jlrettype ) + lRT = convert(LLVMType, jlrettype) if !isa(lRT, LLVM.VoidType) && GPUCompiler.deserves_sret(jlrettype, lRT) sret = Ptr{jlrettype} tracked = CountTrackedPointers(lRT) @@ -4954,7 +6301,15 @@ function get_return_info(jlrettype)::Tuple{Union{Nothing, Type}, Union{Nothing, end # Modified from GPUCompiler/src/irgen.jl:365 lower_byval -function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function, actualRetType::Type, RetActivity, TT, run_enzyme) +function lower_convention( + functy::Type, + mod::LLVM.Module, + entry_f::LLVM.Function, + actualRetType::Type, + RetActivity, + TT, + run_enzyme, +) entry_ft = LLVM.function_type(entry_f) RT = LLVM.return_type(entry_ft) @@ -4977,9 +6332,17 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function # TODO removed implications retRemoved, parmsRemoved = removed_ret_parms(entry_f) - swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(entry_f, i)))) for i in 
1:length(collect(parameters(entry_f)))) + swiftself = any( + any( + map( + k -> kind(k) == kind(EnumAttribute("swiftself")), + collect(parameter_attributes(entry_f, i)), + ), + ) for i = 1:length(collect(parameters(entry_f))) + ) @assert !swiftself "Swiftself attribute coming from differentiable context is not supported" - prargs = classify_arguments(functy, entry_ft, sret, returnRoots, swiftself, parmsRemoved) + prargs = + classify_arguments(functy, entry_ft, sret, returnRoots, swiftself, parmsRemoved) args = copy(prargs) filter!(args) do arg Base.@_inline_meta @@ -5011,7 +6374,12 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function push!(wrapper_types, typ) push!(wrapper_attrs, LLVM.Attribute[]) elseif arg.cc != GPUCompiler.BITS_REF - if TT != nothing && (TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated) && run_enzyme + if TT != nothing && + ( + TT.parameters[arg.arg_i] <: MixedDuplicated || + TT.parameters[arg.arg_i] <: BatchMixedDuplicated + ) && + run_enzyme push!(boxedArgs, arg.arg_i) push!(raisedArgs, arg.arg_i) push!(wrapper_types, LLVM.PointerType(typ, Derived)) @@ -5022,7 +6390,12 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function end else # bits ref, and not boxed - if TT != nothing && (TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated) && run_enzyme + if TT != nothing && + ( + TT.parameters[arg.arg_i] <: MixedDuplicated || + TT.parameters[arg.arg_i] <: BatchMixedDuplicated + ) && + run_enzyme push!(boxedArgs, arg.arg_i) push!(wrapper_types, typ) push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("noalias")]) @@ -5048,10 +6421,24 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function set_subprogram!(wrapper_f, sfn) end - hasReturnsTwice = any(map(k->kind(k)==kind(EnumAttribute("returns_twice")), collect(function_attributes(entry_f)))) - hasNoInline = 
any(map(k->kind(k)==kind(EnumAttribute("noinline")), collect(function_attributes(entry_f)))) + hasReturnsTwice = any( + map( + k -> kind(k) == kind(EnumAttribute("returns_twice")), + collect(function_attributes(entry_f)), + ), + ) + hasNoInline = any( + map( + k -> kind(k) == kind(EnumAttribute("noinline")), + collect(function_attributes(entry_f)), + ), + ) if hasNoInline - LLVM.API.LLVMRemoveEnumAttributeAtIndex(entry_f, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), kind(EnumAttribute("noinline"))) + LLVM.API.LLVMRemoveEnumAttributeAtIndex( + entry_f, + reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), + kind(EnumAttribute("noinline")), + ) end push!(function_attributes(wrapper_f), EnumAttribute("returns_twice")) push!(function_attributes(entry_f), EnumAttribute("returns_twice")) @@ -5089,10 +6476,18 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function push!(nops, load!(builder, convert(LLVMType, arg.typ), parm)) elseif arg.arg_i in raisedArgs obj = emit_allocobj!(builder, arg.typ, "raisedArg") - bc = bitcast!(builder, obj, LLVM.PointerType(value_type(parm), addrspace(value_type(obj)))) + bc = bitcast!( + builder, + obj, + LLVM.PointerType(value_type(parm), addrspace(value_type(obj))), + ) store!(builder, parm, bc) emit_writebarrier!(builder, get_julia_inner_types(builder, obj, parm)) - addr = addrspacecast!(builder, bc, LLVM.PointerType(value_type(parm), Derived)) + addr = addrspacecast!( + builder, + bc, + LLVM.PointerType(value_type(parm), Derived), + ) push!(nops, addr) else push!(nops, parm) @@ -5133,17 +6528,25 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function dl = string(LLVM.datalayout(LLVM.parent(entry_f))) if sret if !in(0, parmsRemoved) - sretPtr = alloca!(builder, eltype(value_type(parameters(entry_f)[1])), "innersret") + sretPtr = alloca!( + builder, + eltype(value_type(parameters(entry_f)[1])), + "innersret", + ) ctx = 
LLVM.context(entry_f) if RetActivity <: Const metadata(sretPtr)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) end - metadata(sretPtr)["enzyme_type"] = to_md(typetree(Ptr{actualRetType}, ctx, - dl, seen), ctx) + metadata(sretPtr)["enzyme_type"] = + to_md(typetree(Ptr{actualRetType}, ctx, dl, seen), ctx) push!(wrapper_args, sretPtr) end if returnRoots && !in(1, parmsRemoved) - retRootPtr = alloca!(builder, eltype(value_type(parameters(entry_f)[1+sret])), "innerreturnroots") + retRootPtr = alloca!( + builder, + eltype(value_type(parameters(entry_f)[1+sret])), + "innerreturnroots", + ) # retRootPtr = alloca!(builder, parameters(wrapper_f)[1]) push!(wrapper_args, retRootPtr) end @@ -5160,48 +6563,95 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function # copy the argument value to a stack slot, and reference it. ty = value_type(parm) if !isa(ty, LLVM.PointerType) - throw(AssertionError("ty is not a LLVM.PointerType: entry_f = $(entry_f), args = $(args), parm = $(parm), ty = $(ty)")) + throw( + AssertionError( + "ty is not a LLVM.PointerType: entry_f = $(entry_f), args = $(args), parm = $(parm), ty = $(ty)", + ), + ) end - ptr = alloca!(builder, eltype(ty), LLVM.name(parm)*".innerparm") + ptr = alloca!(builder, eltype(ty), LLVM.name(parm) * ".innerparm") if TT !== nothing && TT.parameters[arg.arg_i] <: Const metadata(ptr)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) end ctx = LLVM.context(entry_f) - metadata(ptr)["enzyme_type"] = to_md(typetree(Ptr{arg.typ}, ctx, dl, seen), - ctx) + metadata(ptr)["enzyme_type"] = + to_md(typetree(Ptr{arg.typ}, ctx, dl, seen), ctx) if LLVM.addrspace(ty) != 0 ptr = addrspacecast!(builder, ptr, ty) end @assert eltype(ty) == value_type(wrapparm) store!(builder, wrapparm, ptr) push!(wrapper_args, ptr) - push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzyme_type", string(typetree(arg.typ, ctx, dl, seen)))) - push!(parameter_attributes(wrapper_f, 
arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ))))) - push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_VALUE)))) + push!( + parameter_attributes(wrapper_f, arg.codegen.i - sret - returnRoots), + StringAttribute( + "enzyme_type", + string(typetree(arg.typ, ctx, dl, seen)), + ), + ) + push!( + parameter_attributes(wrapper_f, arg.codegen.i - sret - returnRoots), + StringAttribute( + "enzymejl_parmtype", + string(convert(UInt, unsafe_to_pointer(arg.typ))), + ), + ) + push!( + parameter_attributes(wrapper_f, arg.codegen.i - sret - returnRoots), + StringAttribute( + "enzymejl_parmtype_ref", + string(UInt(GPUCompiler.BITS_VALUE)), + ), + ) elseif arg.arg_i in raisedArgs wrapparm = load!(builder, convert(LLVMType, arg.typ), wrapparm) ctx = LLVM.context(wrapparm) push!(wrapper_args, wrapparm) - push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzyme_type", string(typetree(Base.RefValue{arg.typ}, ctx, dl, seen)))) - push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ))))) - push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF)))) + push!( + parameter_attributes(wrapper_f, arg.codegen.i - sret - returnRoots), + StringAttribute( + "enzyme_type", + string(typetree(Base.RefValue{arg.typ}, ctx, dl, seen)), + ), + ) + push!( + parameter_attributes(wrapper_f, arg.codegen.i - sret - returnRoots), + StringAttribute( + "enzymejl_parmtype", + string(convert(UInt, unsafe_to_pointer(arg.typ))), + ), + ) + push!( + parameter_attributes(wrapper_f, arg.codegen.i - sret - returnRoots), + StringAttribute( + "enzymejl_parmtype_ref", + string(UInt(GPUCompiler.BITS_REF)), + ), + ) else push!(wrapper_args, 
wrapparm) for attr in collect(parameter_attributes(entry_f, arg.codegen.i)) - push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), attr) + push!( + parameter_attributes(wrapper_f, arg.codegen.i - sret - returnRoots), + attr, + ) end end end res = call!(builder, LLVM.function_type(entry_f), entry_f, wrapper_args) if get_subprogram(entry_f) !== nothing - metadata(res)[LLVM.MD_dbg] = DILocation( 0, 0, get_subprogram(entry_f) ) + metadata(res)[LLVM.MD_dbg] = DILocation(0, 0, get_subprogram(entry_f)) end callconv!(res, LLVM.callconv(entry_f)) if swiftself attr = EnumAttribute("swiftself") - LLVM.API.LLVMAddCallSiteAttribute(res, LLVM.API.LLVMAttributeIndex(1+sret+returnRoots), attr) + LLVM.API.LLVMAddCallSiteAttribute( + res, + LLVM.API.LLVMAttributeIndex(1 + sret + returnRoots), + attr, + ) end # Box union return, from https://github.com/JuliaLang/julia/blob/81813164963f38dcd779d65ecd222fad8d7ed437/src/cgutils.cpp#L3138 @@ -5229,9 +6679,28 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function nobj = if sretPtr !== nothing obj = emit_allocobj!(builder, jlrettype, "boxunion") llty = convert(LLVMType, jlrettype) - ld = load!(builder, llty, bitcast!(builder, sretPtr, LLVM.PointerType(llty, addrspace(value_type(sretPtr))))) - store!(builder, ld, bitcast!(builder, obj, LLVM.PointerType(llty, addrspace(value_type(obj))))) - emit_writebarrier!(builder, get_julia_inner_types(builder, obj, ld)) + ld = load!( + builder, + llty, + bitcast!( + builder, + sretPtr, + LLVM.PointerType(llty, addrspace(value_type(sretPtr))), + ), + ) + store!( + builder, + ld, + bitcast!( + builder, + obj, + LLVM.PointerType(llty, addrspace(value_type(obj))), + ), + ) + emit_writebarrier!( + builder, + get_julia_inner_types(builder, obj, ld), + ) # memcpy!(builder, bitcast!(builder, obj, LLVM.PointerType(T_int8, addrspace(value_type(obj)))), 0, bitcast!(builder, sretPtr, LLVM.PointerType(T_int8)), 0, LLVM.ConstantInt(T_int64, sizeof(jlrettype))) obj else 
@@ -5240,35 +6709,93 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function ret!(builder, obj) end - LLVM.API.LLVMAddCase(sw, LLVM.ConstantInt(value_type(scase), counter), BB) - counter+=1 + LLVM.API.LLVMAddCase( + sw, + LLVM.ConstantInt(value_type(scase), counter), + BB, + ) + counter += 1 return end for_each_uniontype_small(inner, actualRetType) position!(builder, def) ret!(builder, extract_value!(builder, res, 0)) - - push!(return_attributes(wrapper_f), StringAttribute("enzyme_type", string(typetree(actualRetType, ctx, dl, seen)))) - push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(actualRetType))))) - push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF)))) + + push!( + return_attributes(wrapper_f), + StringAttribute( + "enzyme_type", + string(typetree(actualRetType, ctx, dl, seen)), + ), + ) + push!( + return_attributes(wrapper_f), + StringAttribute( + "enzymejl_parmtype", + string(convert(UInt, unsafe_to_pointer(actualRetType))), + ), + ) + push!( + return_attributes(wrapper_f), + StringAttribute( + "enzymejl_parmtype_ref", + string(UInt(GPUCompiler.BITS_REF)), + ), + ) end elseif sret if sretPtr === nothing ret!(builder) else - push!(return_attributes(wrapper_f), StringAttribute("enzyme_type", string(typetree(actualRetType, ctx, dl, seen)))) - push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(actualRetType))))) - push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF)))) + push!( + return_attributes(wrapper_f), + StringAttribute( + "enzyme_type", + string(typetree(actualRetType, ctx, dl, seen)), + ), + ) + push!( + return_attributes(wrapper_f), + StringAttribute( + "enzymejl_parmtype", + string(convert(UInt, unsafe_to_pointer(actualRetType))), + ), + ) + push!( + return_attributes(wrapper_f), + 
StringAttribute( + "enzymejl_parmtype_ref", + string(UInt(GPUCompiler.BITS_REF)), + ), + ) ret!(builder, load!(builder, RT, sretPtr)) end elseif LLVM.return_type(entry_ft) == LLVM.VoidType() ret!(builder) else ctx = LLVM.context(wrapper_f) - push!(return_attributes(wrapper_f), StringAttribute("enzyme_type", string(typetree(actualRetType, ctx, dl, seen)))) - push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(actualRetType))))) - push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF)))) + push!( + return_attributes(wrapper_f), + StringAttribute( + "enzyme_type", + string(typetree(actualRetType, ctx, dl, seen)), + ), + ) + push!( + return_attributes(wrapper_f), + StringAttribute( + "enzymejl_parmtype", + string(convert(UInt, unsafe_to_pointer(actualRetType))), + ), + ) + push!( + return_attributes(wrapper_f), + StringAttribute( + "enzymejl_parmtype_ref", + string(UInt(GPUCompiler.BITS_REF)), + ), + ) ret!(builder, res) end dispose(builder) @@ -5279,12 +6806,18 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function linkage!(entry_f, LLVM.API.LLVMInternalLinkage) fixup_metadata!(entry_f) - + mi, rt = enzyme_custom_extract_mi(entry_f) attributes = function_attributes(wrapper_f) - push!(attributes, StringAttribute("enzymejl_mi", string(convert(UInt, pointer_from_objref(mi))))) - push!(attributes, StringAttribute("enzymejl_rt", string(convert(UInt, unsafe_to_pointer(rt))))) - + push!( + attributes, + StringAttribute("enzymejl_mi", string(convert(UInt, pointer_from_objref(mi)))), + ) + push!( + attributes, + StringAttribute("enzymejl_rt", string(convert(UInt, unsafe_to_pointer(rt)))), + ) + for prev in collect(function_attributes(entry_f)) if kind(prev) == kind(StringAttribute("enzyme_ta_norecur")) push!(attributes, prev) @@ -5313,11 +6846,15 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function end 
end if LLVM.version().major > 15 - if kind(prev) == kind(EnumAttribute("memory")) - old = MemoryEffect(value(attr)) - mem = MemoryEffect(( set_writing(getModRef(old, ArgMem)) << getLocationPos(ArgMem)) | (getModRef(old, InaccessibleMem) << getLocationPos(InaccessibleMem)) | (getModRef(old, Other) << getLocationPos(Other))) - push!(attributes, EnumAttribute("memory", mem.data)) - end + if kind(prev) == kind(EnumAttribute("memory")) + old = MemoryEffect(value(attr)) + mem = MemoryEffect( + (set_writing(getModRef(old, ArgMem)) << getLocationPos(ArgMem)) | + (getModRef(old, InaccessibleMem) << getLocationPos(InaccessibleMem)) | + (getModRef(old, Other) << getLocationPos(Other)), + ) + push!(attributes, EnumAttribute("memory", mem.data)) + end end if kind(prev) == kind(EnumAttribute("speculatable")) push!(attributes, prev) @@ -5336,26 +6873,45 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function if LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMReturnStatusAction) != 0 msg = sprint() do io println(io, string(mod)) - println(io, LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction)) + println( + io, + LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction), + ) println(io, string(wrapper_f)) - println(io, "parmsRemoved=", parmsRemoved, " retRemoved=", retRemoved, " prargs=", prargs) + println( + io, + "parmsRemoved=", + parmsRemoved, + " retRemoved=", + retRemoved, + " prargs=", + prargs, + ) println(io, "Broken function") end throw(LLVM.LLVMException(msg)) end - ModulePassManager() do pm + ModulePassManager() do pm always_inliner!(pm) LLVM.run!(pm, mod) end if !hasReturnsTwice - LLVM.API.LLVMRemoveEnumAttributeAtIndex(wrapper_f, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), kind(EnumAttribute("returns_twice"))) + LLVM.API.LLVMRemoveEnumAttributeAtIndex( + wrapper_f, + reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), + 
kind(EnumAttribute("returns_twice")), + ) end if hasNoInline - LLVM.API.LLVMRemoveEnumAttributeAtIndex(wrapper_f, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), kind(EnumAttribute("alwaysinline"))) + LLVM.API.LLVMRemoveEnumAttributeAtIndex( + wrapper_f, + reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), + kind(EnumAttribute("alwaysinline")), + ) push!(function_attributes(wrapper_f), EnumAttribute("noinline")) end - + # Fix phinodes used exclusively in extractvalue to be separate phi nodes phistofix = LLVM.PHIInst[] for bb in blocks(wrapper_f) @@ -5395,11 +6951,11 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function phis = LLVM.PHIInst[] for (i, t) in enumerate(LLVM.elements(st)) np = phi!(nb, t) - nvs = Tuple{LLVM.Value, LLVM.BasicBlock}[] - for (v, b) in LLVM.incoming(p) + nvs = Tuple{LLVM.Value,LLVM.BasicBlock}[] + for (v, b) in LLVM.incoming(p) prevbld = IRBuilder() position!(prevbld, terminator(b)) - push!(nvs, (extract_value!(prevbld, v, i-1), b)) + push!(nvs, (extract_value!(prevbld, v, i - 1), b)) end append!(LLVM.incoming(np), nvs) push!(phis, np) @@ -5429,7 +6985,9 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function if haskey(globals(mod), "llvm.used") eraseInst(mod, globals(mod)["llvm.used"]) for u in user.(collect(uses(entry_f))) - if isa(u, LLVM.GlobalVariable) && endswith(LLVM.name(u), "_slot") && startswith(LLVM.name(u), "julia") + if isa(u, LLVM.GlobalVariable) && + endswith(LLVM.name(u), "_slot") && + startswith(LLVM.name(u), "julia") eraseInst(mod, u) end end @@ -5438,7 +6996,10 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function if LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMReturnStatusAction) != 0 msg = sprint() do io println(io, string(mod)) - println(io, LVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction)) + println( + io, + LVM.API.LLVMVerifyFunction(wrapper_f, 
LLVM.API.LLVMPrintMessageAction), + ) println(io, string(wrapper_f)) println(io, "Broken function") end @@ -5449,7 +7010,7 @@ end using Random # returns arg, return -function no_type_setting(@nospecialize(specTypes); world=nothing) +function no_type_setting(@nospecialize(specTypes); world = nothing) # Even though the julia type here is ptr{int8}, the actual data can be something else if specTypes.parameters[1] == typeof(Random.XoshiroSimd.xoshiro_bulk_simd) return (true, false) @@ -5462,22 +7023,35 @@ end const DumpPreOpt = Ref(false) -function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; - libraries::Bool=true, deferred_codegen::Bool=true, optimize::Bool=true, toplevel::Bool=true, - strip::Bool=false, validate::Bool=true, only_entry::Bool=false, parent_job::Union{Nothing, CompilerJob} = nothing) - params = job.config.params +function GPUCompiler.codegen( + output::Symbol, + job::CompilerJob{<:EnzymeTarget}; + libraries::Bool = true, + deferred_codegen::Bool = true, + optimize::Bool = true, + toplevel::Bool = true, + strip::Bool = false, + validate::Bool = true, + only_entry::Bool = false, + parent_job::Union{Nothing,CompilerJob} = nothing, +) + params = job.config.params if params.run_enzyme @assert eltype(params.rt) != Union{} end expectedTapeType = params.expectedTapeType - mode = params.mode + mode = params.mode TT = params.TT width = params.width abiwrap = params.abiwrap - primal = job.source + primal = job.source modifiedBetween = params.modifiedBetween - if length(modifiedBetween) != length(TT.parameters) - throw(AssertionError("length(modifiedBetween) [aka $(length(modifiedBetween))] != length(TT.parameters) [aka $(length(TT.parameters))] at TT=$TT")) + if length(modifiedBetween) != length(TT.parameters) + throw( + AssertionError( + "length(modifiedBetween) [aka $(length(modifiedBetween))] != length(TT.parameters) [aka $(length(TT.parameters))] at TT=$TT", + ), + ) end returnPrimal = params.returnPrimal @@ -5487,21 +7061,40 @@ 
function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if parent_job === nothing primal_target = DefaultCompilerTarget() primal_params = PrimalCompilerParams(mode) - primal_job = CompilerJob(primal, CompilerConfig(primal_target, primal_params; kernel=false), job.world) + primal_job = CompilerJob( + primal, + CompilerConfig(primal_target, primal_params; kernel = false), + job.world, + ) else - config2 = CompilerConfig(parent_job.config.target, parent_job.config.params; kernel=false, parent_job.config.entry_abi, parent_job.config.name, parent_job.config.always_inline) + config2 = CompilerConfig( + parent_job.config.target, + parent_job.config.params; + kernel = false, + parent_job.config.entry_abi, + parent_job.config.name, + parent_job.config.always_inline, + ) primal_job = CompilerJob(primal, config2, job.world) # TODO EnzymeInterp params, etc end - mod, meta = GPUCompiler.codegen(:llvm, primal_job; optimize=false, toplevel=toplevel, cleanup=false, validate=false, parent_job=parent_job) + mod, meta = GPUCompiler.codegen( + :llvm, + primal_job; + optimize = false, + toplevel = toplevel, + cleanup = false, + validate = false, + parent_job = parent_job, + ) prepare_llvm(mod, primal_job, meta) for f in functions(mod) permit_inlining!(f) end LLVM.ModulePassManager() do pm - API.AddPreserveNVVMPass!(pm, #=Begin=#true) + API.AddPreserveNVVMPass!(pm, true) #=Begin=# LLVM.run!(pm, mod) end @@ -5510,34 +7103,63 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; disableFallback = String[] - ForwardModeDerivatives = ("nrm2","dot","gemm","gemv","axpy","copy","scal", "symm", "syrk", "potrf") - ReverseModeDerivatives = ("nrm2","dot","gemm","gemv","axpy","copy","scal", "symm", "trmv", "syrk", "trmm", "trsm", "potrf") + ForwardModeDerivatives = + ("nrm2", "dot", "gemm", "gemv", "axpy", "copy", "scal", "symm", "syrk", "potrf") + ReverseModeDerivatives = ( + "nrm2", + "dot", + "gemm", + "gemv", + "axpy", + "copy", + "scal", + 
"symm", + "trmv", + "syrk", + "trmm", + "trsm", + "potrf", + ) ForwardModeTypes = ("s", "d", "c", "z") ReverseModeTypes = ("s", "d") # Tablegen BLAS does not support forward mode yet if !(mode == API.DEM_ForwardMode && params.runtimeActivity) for ty in (mode == API.DEM_ForwardMode ? ForwardModeTypes : ReverseModeTypes) - for func in (mode == API.DEM_ForwardMode ? ForwardModeDerivatives : ReverseModeDerivatives) + for func in ( + mode == API.DEM_ForwardMode ? ForwardModeDerivatives : + ReverseModeDerivatives + ) for prefix in ("", "cblas_") for ending in ("", "_", "64_", "_64_") - push!(disableFallback, prefix*ty*func*ending) + push!(disableFallback, prefix * ty * func * ending) end end end end end found = String[] - if bitcode_replacement() && API.EnzymeBitcodeReplacement(mod, disableFallback, found) != 0 + if bitcode_replacement() && + API.EnzymeBitcodeReplacement(mod, disableFallback, found) != 0 ModulePassManager() do pm instruction_combining!(pm) LLVM.run!(pm, mod) end toremove = [] for f in functions(mod) - if !any(map(k->kind(k)==kind(EnumAttribute("alwaysinline")), collect(function_attributes(f)))) + if !any( + map( + k -> kind(k) == kind(EnumAttribute("alwaysinline")), + collect(function_attributes(f)), + ), + ) continue end - if !any(map(k->kind(k)==kind(EnumAttribute("returns_twice")), collect(function_attributes(f)))) + if !any( + map( + k -> kind(k) == kind(EnumAttribute("returns_twice")), + collect(function_attributes(f)), + ), + ) push!(function_attributes(f), EnumAttribute("returns_twice")) push!(toremove, name(f)) end @@ -5578,7 +7200,14 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; for fname in toremove if haskey(functions(mod), fname) f = functions(mod)[fname] - LLVM.API.LLVMRemoveEnumAttributeAtIndex(f, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), kind(EnumAttribute("returns_twice"))) + LLVM.API.LLVMRemoveEnumAttributeAtIndex( + f, + reinterpret( + LLVM.API.LLVMAttributeIndex, 
+ LLVM.API.LLVMAttributeFunctionIndex, + ), + kind(EnumAttribute("returns_twice")), + ) end end GPUCompiler.@safe_warn "Using fallback BLAS replacements for ($found), performance may be degraded" @@ -5587,16 +7216,16 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; LLVM.run!(pm, mod) end end - + for f in functions(mod) mi, RT = enzyme_custom_extract_mi(f, false) if mi === nothing continue end - llRT, sret, returnRoots = get_return_info(RT) + llRT, sret, returnRoots = get_return_info(RT) retRemoved, parmsRemoved = removed_ret_parms(f) - + dl = string(LLVM.datalayout(LLVM.parent(f))) expectLen = (sret !== nothing) + (returnRoots !== nothing) @@ -5604,11 +7233,18 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if isghostty(source_typ) || Core.Compiler.isconstType(source_typ) continue end - expectLen+=1 + expectLen += 1 end expectLen -= length(parmsRemoved) - swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(f, i)))) for i in 1:length(collect(parameters(f)))) + swiftself = any( + any( + map( + k -> kind(k) == kind(EnumAttribute("swiftself")), + collect(parameter_attributes(f, i)), + ), + ) for i = 1:length(collect(parameters(f))) + ) if swiftself expectLen += 1 @@ -5652,7 +7288,8 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; push!( parameter_attributes(f, arg.codegen.i), StringAttribute( - "enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ))) + "enzymejl_parmtype", + string(convert(UInt, unsafe_to_pointer(arg.typ))), ), ) push!( @@ -5705,7 +7342,8 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end end - if llRT !== nothing && LLVM.return_type(LLVM.function_type(f)) != LLVM.VoidType() + if llRT !== nothing && + LLVM.return_type(LLVM.function_type(f)) != LLVM.VoidType() @assert !retRemoved rest = typetree(llRT, ctx, dl) push!(return_attributes(f), 
StringAttribute("enzyme_type", string(rest))) @@ -5726,7 +7364,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; actualRetType = nothing lowerConvention = true customDerivativeNames = String[] - fnsToInject = Tuple{Symbol, Type}[] + fnsToInject = Tuple{Symbol,Type}[] for (mi, k) in meta.compiled k_name = GPUCompiler.safe_name(k.specfunc) has_custom_rule = false @@ -5735,12 +7373,14 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; caller = mi if mode == API.DEM_ForwardMode - has_custom_rule = EnzymeRules.has_frule_from_sig(specTypes; world, method_table, caller) + has_custom_rule = + EnzymeRules.has_frule_from_sig(specTypes; world, method_table, caller) if has_custom_rule @safe_debug "Found frule for" mi.specTypes end else - has_custom_rule = EnzymeRules.has_rrule_from_sig(specTypes; world, method_table, caller) + has_custom_rule = + EnzymeRules.has_rrule_from_sig(specTypes; world, method_table, caller) if has_custom_rule @safe_debug "Found rrule for" mi.specTypes end @@ -5754,7 +7394,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if llvmfn == primalf actualRetType = k.ci.rettype end - + if EnzymeRules.noalias_from_sig(mi.specTypes; world, method_table, caller) push!(return_attributes(llvmfn), EnumAttribute("noalias")) for u in LLVM.uses(llvmfn) @@ -5764,22 +7404,26 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end cf = LLVM.called_operand(c) if cf == llvmfn - LLVM.API.LLVMAddCallSiteAttribute(c, LLVM.API.LLVMAttributeReturnIndex, LLVM.EnumAttribute("noalias", 0)) + LLVM.API.LLVMAddCallSiteAttribute( + c, + LLVM.API.LLVMAttributeReturnIndex, + LLVM.EnumAttribute("noalias", 0), + ) end end end func = mi.specTypes.parameters[1] - + meth = mi.def name = meth.name - jlmod = meth.module + jlmod = meth.module - function handleCustom(llvmfn, name, attrs=[], setlink=true, noinl=true) + function handleCustom(llvmfn, name, attrs = [], 
setlink = true, noinl = true) attributes = function_attributes(llvmfn) custom[k_name] = linkage(llvmfn) if setlink - linkage!(llvmfn, LLVM.API.LLVMExternalLinkage) + linkage!(llvmfn, LLVM.API.LLVMExternalLinkage) end for a in attrs push!(attributes, a) @@ -5794,91 +7438,147 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; julia_activity_rule(llvmfn) if has_custom_rule - handleCustom(llvmfn, "enzyme_custom", [StringAttribute("enzyme_preserve_primal", "*")]) + handleCustom( + llvmfn, + "enzyme_custom", + [StringAttribute("enzyme_preserve_primal", "*")], + ) continue end sparam_vals = mi.specTypes.parameters[2:end] # mi.sparam_vals - if func == typeof(Base.eps) || func == typeof(Base.nextfloat) || func == typeof(Base.prevfloat) + if func == typeof(Base.eps) || + func == typeof(Base.nextfloat) || + func == typeof(Base.prevfloat) if LLVM.version().major <= 15 - handleCustom(llvmfn, "jl_inactive_inout", [StringAttribute("enzyme_inactive"), - EnumAttribute("readnone"), - EnumAttribute("speculatable"), - StringAttribute("enzyme_shouldrecompute") - ]) + handleCustom( + llvmfn, + "jl_inactive_inout", + [ + StringAttribute("enzyme_inactive"), + EnumAttribute("readnone"), + EnumAttribute("speculatable"), + StringAttribute("enzyme_shouldrecompute"), + ], + ) else - handleCustom(llvmfn, "jl_inactive_inout", [StringAttribute("enzyme_inactive"), - EnumAttribute("memory", NoEffects.data), - EnumAttribute("speculatable"), - StringAttribute("enzyme_shouldrecompute") - ]) + handleCustom( + llvmfn, + "jl_inactive_inout", + [ + StringAttribute("enzyme_inactive"), + EnumAttribute("memory", NoEffects.data), + EnumAttribute("speculatable"), + StringAttribute("enzyme_shouldrecompute"), + ], + ) end continue end if func == typeof(Base.to_tuple_type) if LLVM.version().major <= 15 - handleCustom(llvmfn, "jl_to_tuple_type", - [EnumAttribute("readonly"), + handleCustom( + llvmfn, + "jl_to_tuple_type", + [ + EnumAttribute("readonly"), 
EnumAttribute("inaccessiblememonly", 0), EnumAttribute("speculatable", 0), StringAttribute("enzyme_shouldrecompute"), StringAttribute("enzyme_inactive"), - ]) + ], + ) else - handleCustom(llvmfn, "jl_to_tuple_type", - [ - EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_Ref << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data), + handleCustom( + llvmfn, + "jl_to_tuple_type", + [ + EnumAttribute( + "memory", + MemoryEffect( + (MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_Ref << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ), EnumAttribute("inaccessiblememonly", 0), EnumAttribute("speculatable", 0), StringAttribute("enzyme_shouldrecompute"), StringAttribute("enzyme_inactive"), - ]) + ], + ) end continue end if func == typeof(Base.mightalias) if LLVM.version().major <= 15 - handleCustom(llvmfn, "jl_mightalias", - [EnumAttribute("readonly"), + handleCustom( + llvmfn, + "jl_mightalias", + [ + EnumAttribute("readonly"), StringAttribute("enzyme_shouldrecompute"), StringAttribute("enzyme_inactive"), StringAttribute("enzyme_no_escaping_allocation"), EnumAttribute("nofree"), StringAttribute("enzyme_ta_norecur"), - ], true, false) + ], + true, + false, + ) else - handleCustom(llvmfn, "jl_mightalias", - [ + handleCustom( + llvmfn, + "jl_mightalias", + [ EnumAttribute("memory", ReadOnlyEffects.data), StringAttribute("enzyme_shouldrecompute"), StringAttribute("enzyme_inactive"), StringAttribute("enzyme_no_escaping_allocation"), EnumAttribute("nofree"), StringAttribute("enzyme_ta_norecur"), - ], true, false) + ], + true, + false, + ) end continue end if func == typeof(Base.Threads.threadid) || func == typeof(Base.Threads.nthreads) name = (func == typeof(Base.Threads.threadid)) ? 
"jl_threadid" : "jl_nthreads" if LLVM.version().major <= 15 - handleCustom(llvmfn, name, - [EnumAttribute("readonly"), + handleCustom( + llvmfn, + name, + [ + EnumAttribute("readonly"), EnumAttribute("inaccessiblememonly"), EnumAttribute("speculatable"), StringAttribute("enzyme_shouldrecompute"), StringAttribute("enzyme_inactive"), - StringAttribute("enzyme_no_escaping_allocation") - ]) + StringAttribute("enzyme_no_escaping_allocation"), + ], + ) else - handleCustom(llvmfn, name, - [EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_Ref << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data), + handleCustom( + llvmfn, + name, + [ + EnumAttribute( + "memory", + MemoryEffect( + (MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_Ref << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ), EnumAttribute("speculatable"), StringAttribute("enzyme_shouldrecompute"), StringAttribute("enzyme_inactive"), - StringAttribute("enzyme_no_escaping_allocation") - ]) + StringAttribute("enzyme_no_escaping_allocation"), + ], + ) end continue end @@ -5889,45 +7589,143 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if func == typeof(Base.Checked.throw_overflowerr_binaryop) llvmfn = functions(mod)[k.specfunc] if LLVM.version().major <= 15 - handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("readonly"), StringAttribute("enzyme_ta_norecur")]) + handleCustom( + llvmfn, + "enz_noop", + [ + StringAttribute("enzyme_inactive"), + EnumAttribute("readonly"), + StringAttribute("enzyme_ta_norecur"), + ], + ) else - handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), - EnumAttribute("memory", ReadOnlyEffects.data), - StringAttribute("enzyme_ta_norecur")]) + handleCustom( + llvmfn, + "enz_noop", + [ + StringAttribute("enzyme_inactive"), + EnumAttribute("memory", ReadOnlyEffects.data), + 
StringAttribute("enzyme_ta_norecur"), + ], + ) end continue end - if EnzymeRules.is_inactive_from_sig(specTypes; world, method_table, caller) && has_method(Tuple{typeof(EnzymeRules.inactive), specTypes.parameters...}, world, method_table) - handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree"), StringAttribute("enzyme_no_escaping_allocation"), StringAttribute("enzyme_ta_norecur")]) + if EnzymeRules.is_inactive_from_sig(specTypes; world, method_table, caller) && + has_method( + Tuple{typeof(EnzymeRules.inactive),specTypes.parameters...}, + world, + method_table, + ) + handleCustom( + llvmfn, + "enz_noop", + [ + StringAttribute("enzyme_inactive"), + EnumAttribute("nofree"), + StringAttribute("enzyme_no_escaping_allocation"), + StringAttribute("enzyme_ta_norecur"), + ], + ) continue end - if EnzymeRules.is_inactive_noinl_from_sig(specTypes; world, method_table, caller) && has_method(Tuple{typeof(EnzymeRules.inactive_noinl), specTypes.parameters...}, world, method_table) - handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree"), StringAttribute("enzyme_no_escaping_allocation"), StringAttribute("enzyme_ta_norecur")], false, false) + if EnzymeRules.is_inactive_noinl_from_sig(specTypes; world, method_table, caller) && + has_method( + Tuple{typeof(EnzymeRules.inactive_noinl),specTypes.parameters...}, + world, + method_table, + ) + handleCustom( + llvmfn, + "enz_noop", + [ + StringAttribute("enzyme_inactive"), + EnumAttribute("nofree"), + StringAttribute("enzyme_no_escaping_allocation"), + StringAttribute("enzyme_ta_norecur"), + ], + false, + false, + ) for bb in blocks(llvmfn) for inst in instructions(bb) if isa(inst, LLVM.CallInst) - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), StringAttribute("no_escaping_allocation")) - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, 
LLVM.API.LLVMAttributeFunctionIndex), StringAttribute("enzyme_inactive")) - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), EnumAttribute("nofree")) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + StringAttribute("no_escaping_allocation"), + ) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + StringAttribute("enzyme_inactive"), + ) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + EnumAttribute("nofree"), + ) end end end continue end if func === typeof(Base.match) - handleCustom(llvmfn, "base_match", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree"), StringAttribute("enzyme_no_escaping_allocation")], false, false) + handleCustom( + llvmfn, + "base_match", + [ + StringAttribute("enzyme_inactive"), + EnumAttribute("nofree"), + StringAttribute("enzyme_no_escaping_allocation"), + ], + false, + false, + ) for bb in blocks(llvmfn) for inst in instructions(bb) if isa(inst, LLVM.CallInst) - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), StringAttribute("no_escaping_allocation")) - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), StringAttribute("enzyme_inactive")) - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), EnumAttribute("nofree")) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + StringAttribute("no_escaping_allocation"), + ) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + 
LLVM.API.LLVMAttributeFunctionIndex, + ), + StringAttribute("enzyme_inactive"), + ) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + EnumAttribute("nofree"), + ) end end end continue end - if func == typeof(Base.enq_work) && length(sparam_vals) == 1 && first(sparam_vals) <: Task + if func == typeof(Base.enq_work) && + length(sparam_vals) == 1 && + first(sparam_vals) <: Task handleCustom(llvmfn, "jl_enq_work", [StringAttribute("enzyme_ta_norecur")]) continue end @@ -5943,7 +7741,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end continue end - + name, toinject, T = find_math_method(func, sparam_vals) if name === nothing continue @@ -5956,17 +7754,25 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; # If sret, force lower of primitive math fn sret = get_return_info(k.ci.rettype)[2] !== nothing if sret - cur = llvmfn == primalf - llvmfn, _, boxedArgs, loweredArgs = lower_convention(mi.specTypes, mod, llvmfn, k.ci.rettype, Duplicated, nothing, params.run_enzyme) - if cur - primalf = llvmfn - lowerConvention = false - end - k_name = LLVM.name(llvmfn) + cur = llvmfn == primalf + llvmfn, _, boxedArgs, loweredArgs = lower_convention( + mi.specTypes, + mod, + llvmfn, + k.ci.rettype, + Duplicated, + nothing, + params.run_enzyme, + ) + if cur + primalf = llvmfn + lowerConvention = false + end + k_name = LLVM.name(llvmfn) end name = string(name) - name = T == Float32 ? name*"f" : name + name = T == Float32 ? 
name * "f" : name attrs = if LLVM.version().major <= 15 [LLVM.EnumAttribute("readnone"), StringAttribute("enzyme_shouldrecompute")] @@ -5985,13 +7791,18 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; llvmfn = primalf FT = LLVM.function_type(llvmfn) - wrapper_f = LLVM.Function(mod, safe_name(LLVM.name(llvmfn)*"mustwrap"), FT) + wrapper_f = LLVM.Function(mod, safe_name(LLVM.name(llvmfn) * "mustwrap"), FT) let builder = IRBuilder() entry = BasicBlock(wrapper_f, "entry") position!(builder, entry) - res = call!(builder, LLVM.function_type(llvmfn), llvmfn, collect(parameters(wrapper_f))) + res = call!( + builder, + LLVM.function_type(llvmfn), + llvmfn, + collect(parameters(wrapper_f)), + ) sretkind = kind(if LLVM.version().major >= 12 TypeAttribute("sret", LLVM.Int32Type()) @@ -6001,7 +7812,11 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; for idx in length(collect(parameters(llvmfn))) for attr in collect(parameter_attributes(llvmfn, idx)) if kind(attr) == sretkind - LLVM.API.LLVMAddCallSiteAttribute(res, LLVM.API.LLVMAttributeIndex(idx), attr) + LLVM.API.LLVMAddCallSiteAttribute( + res, + LLVM.API.LLVMAttributeIndex(idx), + attr, + ) end end end @@ -6017,8 +7832,14 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; attributes = function_attributes(wrapper_f) push!(attributes, StringAttribute("enzymejl_world", string(job.world))) mi, rt = enzyme_custom_extract_mi(primalf) - push!(attributes, StringAttribute("enzymejl_mi", string(convert(UInt, pointer_from_objref(mi))))) - push!(attributes, StringAttribute("enzymejl_rt", string(convert(UInt, unsafe_to_pointer(rt))))) + push!( + attributes, + StringAttribute("enzymejl_mi", string(convert(UInt, pointer_from_objref(mi)))), + ) + push!( + attributes, + StringAttribute("enzymejl_rt", string(convert(UInt, unsafe_to_pointer(rt)))), + ) primalf = wrapper_f end @@ -6027,10 +7848,18 @@ function GPUCompiler.codegen(output::Symbol, 
job::CompilerJob{<:EnzymeTarget}; primalf, returnRoots = primalf, false - if lowerConvention - primalf, returnRoots, boxedArgs, loweredArgs = lower_convention(source_sig, mod, primalf, actualRetType, job.config.params.rt, TT, params.run_enzyme) + if lowerConvention + primalf, returnRoots, boxedArgs, loweredArgs = lower_convention( + source_sig, + mod, + primalf, + actualRetType, + job.config.params.rt, + TT, + params.run_enzyme, + ) end - + if primal_job.config.target isa GPUCompiler.NativeCompilerTarget target_machine = JIT.get_tm() else @@ -6042,13 +7871,13 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; device_module = false if parent_job !== nothing if parent_job.config.target isa GPUCompiler.PTXCompilerTarget || - parent_job.config.target isa GPUCompiler.GCNCompilerTarget || - parent_job.config.target isa GPUCompiler.MetalCompilerTarget + parent_job.config.target isa GPUCompiler.GCNCompilerTarget || + parent_job.config.target isa GPUCompiler.MetalCompilerTarget parallel = true device_module = true end if parent_job.config.target isa GPUCompiler.GCNCompilerTarget || - parent_job.config.target isa GPUCompiler.MetalCompilerTarget + parent_job.config.target isa GPUCompiler.MetalCompilerTarget process_module = true end end @@ -6074,7 +7903,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if process_module GPUCompiler.optimize_module!(parent_job, mod) end - + for name in ("gpu_report_exception", "report_exception") if haskey(functions(mod), name) exc = functions(mod)[name] @@ -6092,42 +7921,51 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; for f in functions(mod), bb in blocks(f), inst in instructions(bb) fn = isa(inst, LLVM.CallInst) ? 
LLVM.called_operand(inst) : nothing - if !API.HasFromStack(inst) && isa(inst, LLVM.CallInst) && (!isa(fn, LLVM.Function) || isempty(blocks(fn))) - legal, source_typ = abs_typeof(inst) + if !API.HasFromStack(inst) && + isa(inst, LLVM.CallInst) && + (!isa(fn, LLVM.Function) || isempty(blocks(fn))) + legal, source_typ, byref = abs_typeof(inst) codegen_typ = value_type(inst) if legal typ = if codegen_typ isa LLVM.PointerType - llvm_source_typ = convert(LLVMType, source_typ; allow_boxed=true) + llvm_source_typ = convert(LLVMType, source_typ; allow_boxed = true) # pointers are used for multiple kinds of arguments # - literal pointer values if source_typ <: Ptr || source_typ <: Core.LLVMPtr source_typ - elseif llvm_source_typ isa LLVM.PointerType - #if llvm_source_typ != codegen_typ - # throw(AssertionError("llvmtype ($llvm_source_typ) is not codegen_typ ($codegen_typ), source_typ = $source_typ within $(string(inst))")) - #end - # push!(args, (cc=MUT_REF, typ=source_typ, name=source_name, idx=codegen_i)) + elseif byref == GPUCompiler.MUT_REF || byref == GPUCompiler.BITS_REF Ptr{source_typ} - # - references to aggregates else - @assert llvm_source_typ != codegen_typ - # push!(args, (cc=BITS_REF, typ=source_typ, name=source_name, idx=codegen_i)) - Ptr{source_typ} + println(string(mod)) + @show legal, source_typ, byref, llvm_source_typ, codegen_typ, string(inst) + @assert false end else source_typ end if isa(inst, LLVM.CallInst) - LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeReturnIndex, StringAttribute("enzyme_type", string(typetree(typ, ctx, dl, seen)))) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + LLVM.API.LLVMAttributeReturnIndex, + StringAttribute( + "enzyme_type", + string(typetree(typ, ctx, dl, seen)), + ), + ) else metadata(inst)["enzyme_type"] = to_md(typetree(typ, ctx, dl, seen), ctx) end elseif codegen_typ == T_prjlvalue if isa(inst, LLVM.CallInst) - LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeReturnIndex, 
StringAttribute("enzyme_type", "{[-1]:Pointer}")) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + LLVM.API.LLVMAttributeReturnIndex, + StringAttribute("enzyme_type", "{[-1]:Pointer}"), + ) else - metadata(inst)["enzyme_type"] = to_md(typetree(Ptr{Cvoid}, ctx, dl, seen), ctx) + metadata(inst)["enzyme_type"] = + to_md(typetree(Ptr{Cvoid}, ctx, dl, seen), ctx) end end end @@ -6139,20 +7977,36 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if length(blocks(fn)) != 0 continue end - + intr = LLVM.API.LLVMGetIntrinsicID(fn) - if intr == LLVM.Intrinsic("llvm.memcpy").id || intr == LLVM.Intrinsic("llvm.memmove").id || intr == LLVM.Intrinsic("llvm.memset").id - legal, jTy = abs_typeof(operands(inst)[1]) - sz = if intr == LLVM.Intrinsic("llvm.memcpy").id || intr == LLVM.Intrinsic("llvm.memmove").id - operands(inst)[3] - else - operands(inst)[3] - end + if intr == LLVM.Intrinsic("llvm.memcpy").id || + intr == LLVM.Intrinsic("llvm.memmove").id || + intr == LLVM.Intrinsic("llvm.memset").id + legal, jTy, byref = abs_typeof(operands(inst)[1]) + sz = + if intr == LLVM.Intrinsic("llvm.memcpy").id || + intr == LLVM.Intrinsic("llvm.memmove").id + operands(inst)[3] + else + operands(inst)[3] + end if legal && Base.isconcretetype(jTy) - if !(jTy isa UnionAll || jTy isa Union || jTy == Union{} || jTy === Tuple || (is_concrete_tuple(jTy) && any(T2 isa Core.TypeofVararg for T2 in jTy.parameters))) + if !( + jTy isa UnionAll || + jTy isa Union || + jTy == Union{} || + jTy === Tuple || + ( + is_concrete_tuple(jTy) && + any(T2 isa Core.TypeofVararg for T2 in jTy.parameters) + ) + ) if isa(sz, LLVM.ConstantInt) && sizeof(jTy) == convert(Int, sz) - metadata(inst)["enzyme_truetype"] = to_fullmd(jTy) + md = to_fullmd(jTy) + @assert byref == GPUCompiler.BITS_REF || + byref == GPUCompiler.MUT_REF + metadata(inst)["enzyme_truetype"] = md end end end @@ -6164,15 +8018,19 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; continue end - 
legal, jTy = abs_typeof(inst, true) + legal, jTy, byref = abs_typeof(inst, true) if !legal continue end if !guaranteed_const_nongen(jTy, world) continue - end + end if isa(inst, LLVM.CallInst) - LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeReturnIndex, StringAttribute("enzyme_inactive")) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + LLVM.API.LLVMAttributeReturnIndex, + StringAttribute("enzyme_inactive"), + ) else metadata(inst)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) end @@ -6186,10 +8044,17 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; Ty = eltype(FT) reg = active_reg_inner(Ty, (), world) if reg == DupState || reg == MixedState - swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(primalf, i)))) for i in 1:length(collect(parameters(primalf)))) - todo = LLVM.Value[parameters(primalf)[1+swiftself]] - done = Set{LLVM.Value}() - doneInst = Set{LLVM.Instruction}() + swiftself = any( + any( + map( + k -> kind(k) == kind(EnumAttribute("swiftself")), + collect(parameter_attributes(primalf, i)), + ), + ) for i = 1:length(collect(parameters(primalf))) + ) + todo = LLVM.Value[parameters(primalf)[1+swiftself]] + done = Set{LLVM.Value}() + doneInst = Set{LLVM.Instruction}() while length(todo) != 0 cur = pop!(todo) if cur in done @@ -6206,7 +8071,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end if !mayWriteToMemory(user) - slegal , foundv = abs_typeof(user) + slegal, foundv, byref = abs_typeof(user) if slegal reg2 = active_reg_inner(foundv, (), world) if reg2 == ActiveState || reg2 == AnyState @@ -6221,7 +8086,9 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; # we are capturing the variable if operands(user)[1] == cur base = operands(user)[2] - while isa(base, LLVM.BitCastInst) || isa(base, LLVM.AddrSpaceCastInst) || isa(base, LLVM.GetElementPtrInst) + while isa(base, LLVM.BitCastInst) || + isa(base, 
LLVM.AddrSpaceCastInst) || + isa(base, LLVM.GetElementPtrInst) base = operands(base)[1] end if isa(base, LLVM.AllocaInst) @@ -6232,7 +8099,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end # we are storing into the variable if operands(user)[2] == cur - slegal , foundv = abs_typeof(operands(user)[1]) + slegal, foundv, byref = abs_typeof(operands(user)[1]) if slegal reg2 = active_reg_inner(foundv, (), world) if reg2 == AnyState @@ -6253,13 +8120,16 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end nm = LLVM.name(called) - if nm == "ijl_alloc_array_1d" || nm == "jl_alloc_array_1d" || - nm == "ijl_alloc_array_2d" || nm == "jl_alloc_array_2d" || - nm == "ijl_alloc_array_3d" || nm == "jl_alloc_array_3d" + if nm == "ijl_alloc_array_1d" || + nm == "jl_alloc_array_1d" || + nm == "ijl_alloc_array_2d" || + nm == "jl_alloc_array_2d" || + nm == "ijl_alloc_array_3d" || + nm == "jl_alloc_array_3d" continue end if is_readonly(called) - slegal , foundv = abs_typeof(user) + slegal, foundv, byref = abs_typeof(user) if slegal reg2 = active_reg_inner(foundv, (), world) if reg2 == ActiveState || reg2 == AnyState @@ -6269,13 +8139,15 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; push!(todo, user) continue end - if !isempty(blocks(called)) && length(collect(LLVM.uses(called))) == 1 - for (parm, op) in zip(LLVM.parameters(called), operands(user)[1:end-1]) + if !isempty(blocks(called)) && + length(collect(LLVM.uses(called))) == 1 + for (parm, op) in + zip(LLVM.parameters(called), operands(user)[1:end-1]) if op == cur push!(todo, parm) end end - slegal , foundv = abs_typeof(user) + slegal, foundv, byref = abs_typeof(user) if slegal reg2 = active_reg_inner(foundv, (), world) if reg2 == ActiveState || reg2 == AnyState @@ -6290,10 +8162,14 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; builder = LLVM.IRBuilder() position!(builder, user) - resstr = 
"Function argument passed to autodiff cannot be proven readonly.\nIf the the function argument cannot contain derivative data, instead call autodiff(Mode, Const(f), ...)\nSee https://enzyme.mit.edu/index.fcgi/julia/stable/faq/#Activity-of-temporary-storage for more information.\nThe potentially writing call is "*string(user)*", using "*string(cur) - slegal , foundv = absint(cur) + resstr = + "Function argument passed to autodiff cannot be proven readonly.\nIf the the function argument cannot contain derivative data, instead call autodiff(Mode, Const(f), ...)\nSee https://enzyme.mit.edu/index.fcgi/julia/stable/faq/#Activity-of-temporary-storage for more information.\nThe potentially writing call is " * + string(user) * + ", using " * + string(cur) + slegal, foundv = absint(cur) if slegal - resstr *= "of type "*string(foundv) + resstr *= "of type " * string(foundv) end emit_error(builder, user, resstr, EnzymeMutabilityException) end @@ -6339,7 +8215,10 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; cf = LLVM.called_operand(tmp) if isa(cf, LLVM.Function) nm = LLVM.name(cf) - if nm == "gpu_signal_exception" || nm == "gpu_report_exception" || nm == "ijl_throw" || nm == "jl_throw" + if nm == "gpu_signal_exception" || + nm == "gpu_report_exception" || + nm == "ijl_throw" || + nm == "jl_throw" shouldemit = false break end @@ -6350,14 +8229,28 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if shouldemit b = IRBuilder() position!(b, term) - emit_error(b, term, "Enzyme: The original primal code hits this error condition, thus differentiating it does not make sense") + emit_error( + b, + term, + "Enzyme: The original primal code hits this error condition, thus differentiating it does not make sense", + ) end end end - if !any(map(k->kind(k)==kind(EnumAttribute("alwaysinline")), collect(function_attributes(f)))) + if !any( + map( + k -> kind(k) == kind(EnumAttribute("alwaysinline")), + 
collect(function_attributes(f)), + ), + ) continue end - if !any(map(k->kind(k)==kind(EnumAttribute("returns_twice")), collect(function_attributes(f)))) + if !any( + map( + k -> kind(k) == kind(EnumAttribute("returns_twice")), + collect(function_attributes(f)), + ), + ) push!(function_attributes(f), EnumAttribute("returns_twice")) push!(toremove, name(f)) end @@ -6369,7 +8262,14 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; for fname in toremove if haskey(functions(mod), fname) f = functions(mod)[fname] - LLVM.API.LLVMRemoveEnumAttributeAtIndex(f, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), kind(EnumAttribute("returns_twice"))) + LLVM.API.LLVMRemoveEnumAttributeAtIndex( + f, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + kind(EnumAttribute("returns_twice")), + ) end end else @@ -6378,86 +8278,192 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end LLVM.ModulePassManager() do pm - API.AddPreserveNVVMPass!(pm, #=Begin=#false) + API.AddPreserveNVVMPass!(pm, false) #=Begin=# LLVM.run!(pm, mod) end if parent_job !== nothing if parent_job.config.target isa GPUCompiler.PTXCompilerTarget - arg1 = ("sin", "cos", "tan", "log2", "exp", "exp2", - "exp10", "cosh", "sinh", "tanh", "atan", - "asin", "acos", "log", "log10", "log1p", "acosh", - "asinh", "atanh", "expm1", "cbrt", - "rcbrt", "j0", "j1", "y0", "y1", - "erf", "erfinv", "erfc", "erfcx", "erfcinv", - "remquo", "tgamma", - "round", "fdim", "logb", "isinf", - "sqrt", "fabs", "atan2", ) - # isinf, finite "modf", "fmod", "remainder", - # "rnorm3d", "norm4d", "rnorm4d", "norm", "rnorm", - # "hypot", "rhypot", - # "yn", "jn", "norm3d", "ilogb", powi - # "normcdfinv", "normcdf", "lgamma", "ldexp", "scalbn", "frexp", - # arg1 = ("atan2", "fmax", "pow") - for n in arg1, (T, pf, lpf) in ((LLVM.DoubleType(), "", "f64"), (LLVM.FloatType(), "f", "f32")) - fname = "__nv_"*n*pf - if 
!haskey(functions(mod), fname) - FT = LLVM.FunctionType(T, [T], vararg=false) - wrapper_f = LLVM.Function(mod, fname, FT) - llname = "llvm."*n*"."*lpf - push!(function_attributes(wrapper_f), StringAttribute("implements", llname)) - end - end - end + arg1 = ( + "sin", + "cos", + "tan", + "log2", + "exp", + "exp2", + "exp10", + "cosh", + "sinh", + "tanh", + "atan", + "asin", + "acos", + "log", + "log10", + "log1p", + "acosh", + "asinh", + "atanh", + "expm1", + "cbrt", + "rcbrt", + "j0", + "j1", + "y0", + "y1", + "erf", + "erfinv", + "erfc", + "erfcx", + "erfcinv", + "remquo", + "tgamma", + "round", + "fdim", + "logb", + "isinf", + "sqrt", + "fabs", + "atan2", + ) + # isinf, finite "modf", "fmod", "remainder", + # "rnorm3d", "norm4d", "rnorm4d", "norm", "rnorm", + # "hypot", "rhypot", + # "yn", "jn", "norm3d", "ilogb", powi + # "normcdfinv", "normcdf", "lgamma", "ldexp", "scalbn", "frexp", + # arg1 = ("atan2", "fmax", "pow") + for n in arg1, + (T, pf, lpf) in + ((LLVM.DoubleType(), "", "f64"), (LLVM.FloatType(), "f", "f32")) + + fname = "__nv_" * n * pf + if !haskey(functions(mod), fname) + FT = LLVM.FunctionType(T, [T], vararg = false) + wrapper_f = LLVM.Function(mod, fname, FT) + llname = "llvm." * n * "." 
* lpf + push!( + function_attributes(wrapper_f), + StringAttribute("implements", llname), + ) + end + end + end if parent_job.config.target isa GPUCompiler.GCNCompilerTarget - arg1 = ("acos", "acosh", "asin", - "asinh", "atan2", "atan", - "atanh", "cbrt", "ceil", - "copysign", "cos", "native_cos", - "cosh", "cospi", "i0", - "i1", "erfc", "erfcinv", - "erfcx", "erf", "erfinv", - "exp10", "native_exp10", "exp2", - "exp", "native_exp", "expm1", - "fabs", "fdim", "floor", - "fma", "fmax", "fmin", - "fmod", "frexp", "hypot", - "ilogb", "isfinite", "isinf", - "isnan", "j0", "j1", - "ldexp", "lgamma", "log10", - "native_log10", "log1p", "log2", - "log2", "logb", "log", - "native_log", "modf", "nearbyint", - "nextafter", "len3", "len4", - "ncdf", "ncdfinv", "pow", - "pown", "rcbrt", "remainder", - "remquo", "rhypot", "rint", - "rlen3", "rlen4", "round", - "rsqrt", "scalb", "scalbn", - "signbit", "sincos", "sincospi", - "sin", "native_sin", "sinh", - "sinpi", "sqrt", "native_sqrt", - "tan", "tanh", "tgamma", - "trunc", "y0", "y1") - for n in arg1, (T, pf, lpf) in ((LLVM.DoubleType(), "", "f64"), (LLVM.FloatType(), "f", "f32")) - fname = "__ocml_"*n*"_"*lpf + arg1 = ( + "acos", + "acosh", + "asin", + "asinh", + "atan2", + "atan", + "atanh", + "cbrt", + "ceil", + "copysign", + "cos", + "native_cos", + "cosh", + "cospi", + "i0", + "i1", + "erfc", + "erfcinv", + "erfcx", + "erf", + "erfinv", + "exp10", + "native_exp10", + "exp2", + "exp", + "native_exp", + "expm1", + "fabs", + "fdim", + "floor", + "fma", + "fmax", + "fmin", + "fmod", + "frexp", + "hypot", + "ilogb", + "isfinite", + "isinf", + "isnan", + "j0", + "j1", + "ldexp", + "lgamma", + "log10", + "native_log10", + "log1p", + "log2", + "log2", + "logb", + "log", + "native_log", + "modf", + "nearbyint", + "nextafter", + "len3", + "len4", + "ncdf", + "ncdfinv", + "pow", + "pown", + "rcbrt", + "remainder", + "remquo", + "rhypot", + "rint", + "rlen3", + "rlen4", + "round", + "rsqrt", + "scalb", + "scalbn", + "signbit", + 
"sincos", + "sincospi", + "sin", + "native_sin", + "sinh", + "sinpi", + "sqrt", + "native_sqrt", + "tan", + "tanh", + "tgamma", + "trunc", + "y0", + "y1", + ) + for n in arg1, + (T, pf, lpf) in + ((LLVM.DoubleType(), "", "f64"), (LLVM.FloatType(), "f", "f32")) + + fname = "__ocml_" * n * "_" * lpf if !haskey(functions(mod), fname) - FT = LLVM.FunctionType(T, [T], vararg=false) + FT = LLVM.FunctionType(T, [T], vararg = false) wrapper_f = LLVM.Function(mod, fname, FT) - llname = "llvm."*n*"."*lpf - push!(function_attributes(wrapper_f), StringAttribute("implements", llname)) + llname = "llvm." * n * "." * lpf + push!( + function_attributes(wrapper_f), + StringAttribute("implements", llname), + ) end end end - end + end for (name, fnty) in fnsToInject - for (T, JT, pf) in ((LLVM.DoubleType(), Float64, ""), (LLVM.FloatType(), Float32, "f")) - fname = String(name)*pf + for (T, JT, pf) in + ((LLVM.DoubleType(), Float64, ""), (LLVM.FloatType(), Float32, "f")) + fname = String(name) * pf if haskey(functions(mod), fname) funcspec = GPUCompiler.methodinstance(fnty, Tuple{JT}, world) llvmf = nested_codegen!(mode, mod, funcspec, world) push!(function_attributes(llvmf), StringAttribute("implements", fname)) end - end + end end API.EnzymeReplaceFunctionImplementation(mod) @@ -6478,7 +8484,8 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end end end - for fname in ["__enzyme_float", "__enzyme_double", "__enzyme_integer", "__enzyme_pointer"] + for fname in + ["__enzyme_float", "__enzyme_double", "__enzyme_integer", "__enzyme_pointer"] haskey(functions(mod), fname) || continue f = functions(mod)[fname] for u in uses(f) @@ -6504,7 +8511,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if parent_job !== nothing reinsert_gcmarker!(adjointf) augmented_primalf !== nothing && reinsert_gcmarker!(augmented_primalf) - post_optimze!(mod, target_machine, #=machine=#false) + post_optimze!(mod, target_machine, false) 
#=machine=# end adjointf = functions(mod)[adjointf_name] @@ -6525,34 +8532,127 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; use_primal = mode == API.DEM_ReverseModePrimal entry = use_primal ? augmented_primalf : adjointf - return mod, (;adjointf, augmented_primalf, entry, compiled=meta.compiled, TapeType) + return mod, (; adjointf, augmented_primalf, entry, compiled = meta.compiled, TapeType) end # Compiler result -struct CompileResult{AT, PT} +struct CompileResult{AT,PT} adjoint::AT primal::PT TapeType::Type end -@inline (thunk::PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal})(fn::FA, args...) where {PT, FA, RT, TT, Width, ReturnPrimal} = -enzyme_call(Val(false), thunk.adjoint, PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal}, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...) - -@inline (thunk::CombinedAdjointThunk{PT, FA, RT, TT, Width, ReturnPrimal})(fn::FA, args...) where {PT, FA, Width, RT, TT, ReturnPrimal} = -enzyme_call(Val(false), thunk.adjoint, CombinedAdjointThunk{PT, FA, RT, TT, Width, ReturnPrimal}, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...) +@inline (thunk::PrimalErrorThunk{PT,FA,RT,TT,Width,ReturnPrimal})( + fn::FA, + args..., +) where {PT,FA,RT,TT,Width,ReturnPrimal} = enzyme_call( + Val(false), + thunk.adjoint, + PrimalErrorThunk{PT,FA,RT,TT,Width,ReturnPrimal}, + Val(Width), + Val(ReturnPrimal), + TT, + RT, + fn, + Cvoid, + args..., +) -@inline (thunk::ForwardModeThunk{PT, FA, RT, TT, Width, ReturnPrimal})(fn::FA, args...) where {PT, FA, Width, RT, TT, ReturnPrimal} = -enzyme_call(Val(false), thunk.adjoint, ForwardModeThunk{PT, FA, RT, TT, Width, ReturnPrimal}, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...) 
+@inline (thunk::CombinedAdjointThunk{PT,FA,RT,TT,Width,ReturnPrimal})( + fn::FA, + args..., +) where {PT,FA,Width,RT,TT,ReturnPrimal} = enzyme_call( + Val(false), + thunk.adjoint, + CombinedAdjointThunk{PT,FA,RT,TT,Width,ReturnPrimal}, + Val(Width), + Val(ReturnPrimal), + TT, + RT, + fn, + Cvoid, + args..., +) -@inline (thunk::AdjointThunk{PT, FA, RT, TT, Width, TapeT})(fn::FA, args...) where {PT, FA, Width, RT, TT, TapeT} = -enzyme_call(Val(false), thunk.adjoint, AdjointThunk{PT, FA, RT, TT, Width, TapeT}, Val(Width), #=ReturnPrimal=#Val(false), TT, RT, fn, TapeT, args...) -@inline raw_enzyme_call(thunk::AdjointThunk{PT, FA, RT, TT, Width, TapeT}, fn::FA, args...) where {PT, FA, Width, RT, TT, TapeT} = -enzyme_call(Val(true), thunk.adjoint, AdjointThunk, Val(Width), #=ReturnPrimal=#Val(false), TT, RT, fn, TapeT, args...) +@inline (thunk::ForwardModeThunk{PT,FA,RT,TT,Width,ReturnPrimal})( + fn::FA, + args..., +) where {PT,FA,Width,RT,TT,ReturnPrimal} = enzyme_call( + Val(false), + thunk.adjoint, + ForwardModeThunk{PT,FA,RT,TT,Width,ReturnPrimal}, + Val(Width), + Val(ReturnPrimal), + TT, + RT, + fn, + Cvoid, + args..., +) -@inline (thunk::AugmentedForwardThunk{PT, FA, RT, TT, Width, ReturnPrimal, TapeT})(fn::FA, args...) where {PT, FA, Width, RT, TT, ReturnPrimal, TapeT} = -enzyme_call(Val(false), thunk.primal, AugmentedForwardThunk, Val(Width), Val(ReturnPrimal), TT, RT, fn, TapeT, args...) -@inline raw_enzyme_call(thunk::AugmentedForwardThunk{PT, FA, RT, TT, Width, ReturnPrimal, TapeT}, fn::FA, args...) where {PT, FA, Width, RT, TT, ReturnPrimal, TapeT} = -enzyme_call(Val(true), thunk.primal, AugmentedForwardThunk, Val(Width), Val(ReturnPrimal), TT, RT, fn, TapeT, args...) 
+@inline (thunk::AdjointThunk{PT,FA,RT,TT,Width,TapeT})( + fn::FA, + args..., +) where {PT,FA,Width,RT,TT,TapeT} = enzyme_call( + Val(false), + thunk.adjoint, + AdjointThunk{PT,FA,RT,TT,Width,TapeT}, + Val(Width), + Val(false), + TT, + RT, + fn, + TapeT, + args..., +) #=ReturnPrimal=# +@inline raw_enzyme_call( + thunk::AdjointThunk{PT,FA,RT,TT,Width,TapeT}, + fn::FA, + args..., +) where {PT,FA,Width,RT,TT,TapeT} = enzyme_call( + Val(true), + thunk.adjoint, + AdjointThunk, + Val(Width), + Val(false), + TT, + RT, + fn, + TapeT, + args..., +) #=ReturnPrimal=# + +@inline (thunk::AugmentedForwardThunk{PT,FA,RT,TT,Width,ReturnPrimal,TapeT})( + fn::FA, + args..., +) where {PT,FA,Width,RT,TT,ReturnPrimal,TapeT} = enzyme_call( + Val(false), + thunk.primal, + AugmentedForwardThunk, + Val(Width), + Val(ReturnPrimal), + TT, + RT, + fn, + TapeT, + args..., +) +@inline raw_enzyme_call( + thunk::AugmentedForwardThunk{PT,FA,RT,TT,Width,ReturnPrimal,TapeT}, + fn::FA, + args..., +) where {PT,FA,Width,RT,TT,ReturnPrimal,TapeT} = enzyme_call( + Val(true), + thunk.primal, + AugmentedForwardThunk, + Val(Width), + Val(ReturnPrimal), + TT, + RT, + fn, + TapeT, + args..., +) function jl_set_typeof(v::Ptr{Cvoid}, T) @@ -6561,7 +8661,7 @@ function jl_set_typeof(v::Ptr{Cvoid}, T) return nothing end -@generated function splatnew(::Type{T}, args::TT) where {T,TT <: Tuple} +@generated function splatnew(::Type{T}, args::TT) where {T,TT<:Tuple} return quote Base.@_inline_meta $(Expr(:splatnew, :T, :args)) @@ -6570,7 +8670,12 @@ end # Recursively return x + f(y), where y is active, otherwise x -@inline function recursive_add(x::T, y::T, f::F=identity, forcelhs::F2=guaranteed_const) where {T, F, F2} +@inline function recursive_add( + x::T, + y::T, + f::F = identity, + forcelhs::F2 = guaranteed_const, +) where {T,F,F2} if forcelhs(T) return x end @@ -6582,31 +8687,41 @@ end end) end -@inline function recursive_add(x::T, y::T, f::F=identity, forcelhs::F2=guaranteed_const) where {T<:AbstractFloat, F, 
F2} +@inline function recursive_add( + x::T, + y::T, + f::F = identity, + forcelhs::F2 = guaranteed_const, +) where {T<:AbstractFloat,F,F2} if forcelhs(T) return x end return x + f(y) end -@inline function recursive_add(x::T, y::T, f::F=identity, forcelhs::F2=guaranteed_const) where {T<:Complex, F, F2} +@inline function recursive_add( + x::T, + y::T, + f::F = identity, + forcelhs::F2 = guaranteed_const, +) where {T<:Complex,F,F2} if forcelhs(T) return x end return x + f(y) end -@inline mutable_register(::Type{T}) where T <: Integer = true -@inline mutable_register(::Type{T}) where T <: AbstractFloat = false -@inline mutable_register(::Type{Complex{T}}) where T <: AbstractFloat = false -@inline mutable_register(::Type{T}) where T <: Tuple = false -@inline mutable_register(::Type{T}) where T <: NamedTuple = false +@inline mutable_register(::Type{T}) where {T<:Integer} = true +@inline mutable_register(::Type{T}) where {T<:AbstractFloat} = false +@inline mutable_register(::Type{Complex{T}}) where {T<:AbstractFloat} = false +@inline mutable_register(::Type{T}) where {T<:Tuple} = false +@inline mutable_register(::Type{T}) where {T<:NamedTuple} = false @inline mutable_register(::Type{Core.Box}) = true -@inline mutable_register(::Type{T}) where T <: Array = true -@inline mutable_register(::Type{T}) where T = ismutabletype(T) +@inline mutable_register(::Type{T}) where {T<:Array} = true +@inline mutable_register(::Type{T}) where {T} = ismutabletype(T) # Recursively In-place accumulate(aka +=). E.g. generalization of x .+= f(y) -@inline function recursive_accumulate(x::Array{T}, y::Array{T}, f::F=identity) where {T, F} +@inline function recursive_accumulate(x::Array{T}, y::Array{T}, f::F = identity) where {T,F} if !mutable_register(T) for I in eachindex(x) prev = x[I] @@ -6617,16 +8732,16 @@ end # Recursively In-place accumulate(aka +=). E.g. 
generalization of x .+= f(y) -@inline function recursive_accumulate(x::Core.Box, y::Core.Box, f::F=identity) where {F} +@inline function recursive_accumulate(x::Core.Box, y::Core.Box, f::F = identity) where {F} recursive_accumulate(x.contents, y.contents, seen, f) end -@inline function recursive_accumulate(x::T, y::T, f::F=identity) where {T, F} +@inline function recursive_accumulate(x::T, y::T, f::F = identity) where {T,F} @assert !Base.isabstracttype(T) @assert Base.isconcretetype(T) nf = fieldcount(T) - for i in 1:nf + for i = 1:nf if isdefined(x, i) xi = getfield(x, i) ST = Core.Typeof(xi) @@ -6646,9 +8761,13 @@ end elseif T <: AbstractFloat return one(T) elseif T <: Complex - error("Attempted to use automatic pullback (differential return value) deduction on a either a type unstable function returning an active complex number, or autodiff_deferred returning an active complex number. For the first case, please type stabilize your code, e.g. by specifying autodiff(Reverse, f->f(x)::Complex, ...). For the second case, please use regular non-deferred autodiff") + error( + "Attempted to use automatic pullback (differential return value) deduction on a either a type unstable function returning an active complex number, or autodiff_deferred returning an active complex number. For the first case, please type stabilize your code, e.g. by specifying autodiff(Reverse, f->f(x)::Complex, ...). For the second case, please use regular non-deferred autodiff", + ) else - error("Active return values with automatic pullback (differential return value) deduction only supported for floating-like values and not type $T. If mutable memory, please use Duplicated. Otherwise, you can explicitly specify a pullback by using split mode, e.g. autodiff_thunk(ReverseSplitWithPrimal, ...)") + error( + "Active return values with automatic pullback (differential return value) deduction only supported for floating-like values and not type $T. If mutable memory, please use Duplicated. 
Otherwise, you can explicitly specify a pullback by using split mode, e.g. autodiff_thunk(ReverseSplitWithPrimal, ...)", + ) end end @@ -6656,56 +8775,72 @@ function add_one_in_place(x) if x isa Base.RefValue x[] = recursive_add(x[], default_adjoint(eltype(Core.Typeof(x)))) else - error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string(x)) + error( + "Enzyme Mutability Error: Cannot add one in place to immutable value " * + string(x), + ) end return nothing end -@generated function enzyme_call(::Val{RawCall}, fptr::PT, ::Type{CC}, ::Val{width}, ::Val{returnPrimal}, tt::Type{T}, - rt::Type{RT}, fn::FA, ::Type{TapeType}, args::Vararg{Any, N}) where {RawCall, PT, FA, T, RT, TapeType, N, CC, width, returnPrimal} +@generated function enzyme_call( + ::Val{RawCall}, + fptr::PT, + ::Type{CC}, + ::Val{width}, + ::Val{returnPrimal}, + tt::Type{T}, + rt::Type{RT}, + fn::FA, + ::Type{TapeType}, + args::Vararg{Any,N}, +) where {RawCall,PT,FA,T,RT,TapeType,N,CC,width,returnPrimal} JuliaContext() do ctx Base.@_inline_meta F = eltype(FA) - is_forward = CC <: AugmentedForwardThunk || CC <: ForwardModeThunk || CC <: PrimalErrorThunk + is_forward = + CC <: AugmentedForwardThunk || CC <: ForwardModeThunk || CC <: PrimalErrorThunk is_adjoint = CC <: AdjointThunk || CC <: CombinedAdjointThunk - is_split = CC <: AdjointThunk || CC <: AugmentedForwardThunk + is_split = CC <: AdjointThunk || CC <: AugmentedForwardThunk needs_tape = CC <: AdjointThunk - argtt = tt.parameters[1] - rettype = rt.parameters[1] + argtt = tt.parameters[1] + rettype = rt.parameters[1] argtypes = DataType[argtt.parameters...] 
- argexprs = Union{Expr, Symbol}[:(args[$i]) for i in 1:N] + argexprs = Union{Expr,Symbol}[:(args[$i]) for i = 1:N] if false && CC <: PrimalErrorThunk - primargs = [quote - convert($(eltype(T)), $(argexprs[i]).val) - end for (i, T) in enumerate(argtypes)] + primargs = [ + quote + convert($(eltype(T)), $(argexprs[i]).val) + end for (i, T) in enumerate(argtypes) + ] return quote fn.val($(primargs...)) - error("Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up") + error( + "Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up", + ) end end if !RawCall && !(CC <: PrimalErrorThunk) - if rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated + if rettype <: Active || + rettype <: MixedDuplicated || + rettype <: BatchMixedDuplicated if length(argtypes) + is_adjoint + needs_tape != length(argexprs) return quote - @show $width - @show $(length(argtypes)), $is_adjoint, $needs_tape, $(length(argexprs)) - @show $argtypes - @show $argexprs throw(MethodError($CC(fptr), (fn, args...))) end end elseif rettype <: Const - if length(argtypes) + needs_tape != length(argexprs) + if length(argtypes) + needs_tape != length(argexprs) return quote throw(MethodError($CC(fptr), (fn, args...))) end end else - if length(argtypes) + needs_tape != length(argexprs) + if length(argtypes) + needs_tape != length(argexprs) return quote throw(MethodError($CC(fptr), (fn, args...))) end @@ -6715,14 +8850,18 @@ end types = DataType[] - if !(rettype <: Const) && (isghostty(eltype(rettype)) || Core.Compiler.isconstType(eltype(rettype)) || eltype(rettype) === DataType) + if !(rettype <: Const) && ( + isghostty(eltype(rettype)) || + Core.Compiler.isconstType(eltype(rettype)) || + eltype(rettype) === DataType + ) rrt = eltype(rettype) error("Return type `$rrt` not marked Const, but is ghost or const type.") end - sret_types = [] # Julia types of all returned variables + 
sret_types = [] # Julia types of all returned variables # By ref values we create and need to preserve - ccexprs = Union{Expr, Symbol}[] # The expressions passed to the `llvmcall` + ccexprs = Union{Expr,Symbol}[] # The expressions passed to the `llvmcall` if !isghostty(F) && !Core.Compiler.isconstType(F) isboxed = GPUCompiler.deserves_argbox(F) @@ -6745,8 +8884,8 @@ end push!(types, Any) elseif width == 1 push!(types, F) - else - push!(types, NTuple{width, F}) + else + push!(types, NTuple{width,F}) end push!(ccexprs, argexpr) end @@ -6759,7 +8898,7 @@ end source_typ = eltype(T) expr = argexprs[i] - i+=1 + i += 1 if isghostty(source_typ) || Core.Compiler.isconstType(source_typ) @assert T <: Const if is_adjoint @@ -6798,13 +8937,13 @@ end if width == 1 push!(ActiveRetTypes, source_typ) else - push!(ActiveRetTypes, NTuple{width, source_typ}) + push!(ActiveRetTypes, NTuple{width,source_typ}) end end elseif T <: Duplicated || T <: DuplicatedNoNeed if RawCall argexpr = argexprs[i] - i+=1 + i += 1 else argexpr = Expr(:., expr, QuoteNode(:dval)) end @@ -6820,15 +8959,15 @@ end elseif T <: BatchDuplicated || T <: BatchDuplicatedNoNeed if RawCall argexpr = argexprs[i] - i+=1 + i += 1 else argexpr = Expr(:., expr, QuoteNode(:dval)) end - isboxedvec = GPUCompiler.deserves_argbox(NTuple{width, source_typ}) + isboxedvec = GPUCompiler.deserves_argbox(NTuple{width,source_typ}) if isboxedvec push!(types, Any) else - push!(types, NTuple{width, source_typ}) + push!(types, NTuple{width,source_typ}) end if is_adjoint push!(ActiveRetTypes, Nothing) @@ -6837,7 +8976,7 @@ end elseif T <: MixedDuplicated if RawCall argexpr = argexprs[i] - i+=1 + i += 1 else argexpr = Expr(:., expr, QuoteNode(:dval)) end @@ -6845,19 +8984,20 @@ end if is_adjoint push!(ActiveRetTypes, Nothing) end - push!(ccexprs, argexpr) + push!(ccexprs, argexpr) elseif T <: BatchMixedDuplicated if RawCall argexpr = argexprs[i] - i+=1 + i += 1 else argexpr = Expr(:., expr, QuoteNode(:dval)) end - isboxedvec = 
GPUCompiler.deserves_argbox(NTuple{width, Base.RefValue{source_typ}}) + isboxedvec = + GPUCompiler.deserves_argbox(NTuple{width,Base.RefValue{source_typ}}) if isboxedvec push!(types, Any) else - push!(types, NTuple{width, Base.RefValue{source_typ}}) + push!(types, NTuple{width,Base.RefValue{source_typ}}) end if is_adjoint push!(ActiveRetTypes, Nothing) @@ -6870,8 +9010,8 @@ end jlRT = eltype(rettype) if typeof(jlRT) == UnionAll - # Future improvement, add type assertion on load - jlRT = DataType + # Future improvement, add type assertion on load + jlRT = DataType end if is_sret_union(jlRT) @@ -6879,20 +9019,22 @@ end end # API.DFT_OUT_DIFF - if is_adjoint - if rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated + if is_adjoint + if rettype <: Active || + rettype <: MixedDuplicated || + rettype <: BatchMixedDuplicated # TODO handle batch width - if rettype <: Active + if rettype <: Active @assert allocatedinline(jlRT) end j_drT = if width == 1 jlRT else - NTuple{width, jlRT} + NTuple{width,jlRT} end push!(types, j_drT) push!(ccexprs, argexprs[i]) - i+=1 + i += 1 end end @@ -6901,20 +9043,23 @@ end push!(types, TapeType) push!(ccexprs, argexprs[i]) end - i+=1 + i += 1 end if is_adjoint NT = Tuple{ActiveRetTypes...} - if any(any_jltypes(convert(LLVM.LLVMType, b; allow_boxed=true)) for b in ActiveRetTypes) + if any( + any_jltypes(convert(LLVM.LLVMType, b; allow_boxed = true)) for + b in ActiveRetTypes + ) NT = AnonymousStruct(NT) end push!(sret_types, NT) end - + if !(CC <: PrimalErrorThunk) - @assert i == length(argexprs)+1 + @assert i == length(argexprs) + 1 end # Tape @@ -6922,7 +9067,7 @@ end push!(sret_types, TapeType) end - if returnPrimal && !(CC <: ForwardModeThunk) + if returnPrimal && !(CC <: ForwardModeThunk) push!(sret_types, jlRT) end if is_forward @@ -6934,9 +9079,9 @@ end elseif rettype <: MixedDuplicated push!(sret_types, Base.RefValue{jlRT}) elseif rettype <: BatchDuplicated || rettype <: BatchDuplicatedNoNeed - 
push!(sret_types, AnonymousStruct(NTuple{width, jlRT})) + push!(sret_types, AnonymousStruct(NTuple{width,jlRT})) elseif rettype <: BatchMixedDuplicated - push!(sret_types, AnonymousStruct(NTuple{width, Base.RefValue{jlRT}})) + push!(sret_types, AnonymousStruct(NTuple{width,Base.RefValue{jlRT}})) elseif CC <: AugmentedForwardThunk push!(sret_types, Nothing) elseif rettype <: Const @@ -6946,17 +9091,21 @@ end end end - if returnPrimal && (CC <: ForwardModeThunk) + if returnPrimal && (CC <: ForwardModeThunk) push!(sret_types, jlRT) end # calls fptr - llvmtys = LLVMType[convert(LLVMType, x; allow_boxed=true) for x in types] + llvmtys = LLVMType[convert(LLVMType, x; allow_boxed = true) for x in types] T_void = convert(LLVMType, Nothing) - combinedReturn = (CC <: PrimalErrorThunk && eltype(rettype) == Union{}) ? Union{} : Tuple{sret_types...} - if any(any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types) + combinedReturn = + (CC <: PrimalErrorThunk && eltype(rettype) == Union{}) ? 
Union{} : + Tuple{sret_types...} + if any( + any_jltypes(convert(LLVM.LLVMType, T; allow_boxed = true)) for T in sret_types + ) combinedReturn = AnonymousStruct(combinedReturn) end uses_sret = is_sret(combinedReturn) @@ -6999,7 +9148,10 @@ end if returnRoots tracked = CountTrackedPointers(jltype) - pushfirst!(callparams, alloca!(builder, LLVM.ArrayType(T_prjlvalue, tracked.count))) + pushfirst!( + callparams, + alloca!(builder, LLVM.ArrayType(T_prjlvalue, tracked.count)), + ) pushfirst!(callparams, alloca!(builder, jltype)) end @@ -7007,7 +9159,11 @@ end tape = callparams[end] if TapeType <: EnzymeTapeToLoad llty = from_tape_type(eltype(TapeType)) - tape = bitcast!(builder, tape, LLVM.PointerType(llty, LLVM.addrspace(value_type(tape)))) + tape = bitcast!( + builder, + tape, + LLVM.PointerType(llty, LLVM.addrspace(value_type(tape))), + ) tape = load!(builder, llty, tape) API.SetMustCache!(tape) callparams[end] = tape @@ -7018,10 +9174,13 @@ end end if !(GPUCompiler.isghosttype(PT) || Core.Compiler.isconstType(PT)) - FT = LLVM.FunctionType(returnRoots ? T_void : T_ret, [value_type(x) for x in callparams]) + FT = LLVM.FunctionType( + returnRoots ? 
T_void : T_ret, + [value_type(x) for x in callparams], + ) lfn = inttoptr!(builder, lfn, LLVM.PointerType(FT)) else - val_inner(::Type{Val{V}}) where V = V + val_inner(::Type{Val{V}}) where {V} = V submod, subname = val_inner(PT) # TODO, consider optimization # However, julia will optimize after this, so no need @@ -7032,7 +9191,7 @@ end end r = call!(builder, FT, lfn, callparams) - + if returnRoots attr = if LLVM.version().major >= 12 TypeAttribute("sret", jltype) @@ -7049,7 +9208,7 @@ end ret!(builder) end reinsert_gcmarker!(llvm_f) - + ir = string(mod) fn = LLVM.name(llvm_f) @@ -7058,16 +9217,23 @@ end if !(GPUCompiler.isghosttype(PT) || Core.Compiler.isconstType(PT)) return quote Base.@_inline_meta - Base.llvmcall(($ir, $fn), $combinedReturn, - Tuple{$PT, $(types...)}, - fptr, $(ccexprs...)) + Base.llvmcall( + ($ir, $fn), + $combinedReturn, + Tuple{$PT,$(types...)}, + fptr, + $(ccexprs...), + ) end else return quote Base.@_inline_meta - Base.llvmcall(($ir, $fn), $combinedReturn, - Tuple{$(types...)}, - $(ccexprs...)) + Base.llvmcall( + ($ir, $fn), + $combinedReturn, + Tuple{$(types...)}, + $(ccexprs...), + ) end end end @@ -7079,24 +9245,38 @@ end function _link(job, (mod, adjoint_name, primal_name, TapeType)) if job.config.params.ABI <: InlineABI - return CompileResult(Val((Symbol(mod), Symbol(adjoint_name))), Val((Symbol(mod), Symbol(primal_name))), TapeType) + return CompileResult( + Val((Symbol(mod), Symbol(adjoint_name))), + Val((Symbol(mod), Symbol(primal_name))), + TapeType, + ) end # Now invoke the JIT jitted_mod = JIT.add!(mod) adjoint_addr = JIT.lookup(jitted_mod, adjoint_name) - adjoint_ptr = pointer(adjoint_addr) + adjoint_ptr = pointer(adjoint_addr) if adjoint_ptr === C_NULL - throw(GPUCompiler.InternalCompilerError(job, "Failed to compile Enzyme thunk, adjoint not found")) + throw( + GPUCompiler.InternalCompilerError( + job, + "Failed to compile Enzyme thunk, adjoint not found", + ), + ) end if primal_name === nothing primal_ptr = C_NULL else 
primal_addr = JIT.lookup(jitted_mod, primal_name) - primal_ptr = pointer(primal_addr) + primal_ptr = pointer(primal_addr) if primal_ptr === C_NULL - throw(GPUCompiler.InternalCompilerError(job, "Failed to compile Enzyme thunk, primal not found")) + throw( + GPUCompiler.InternalCompilerError( + job, + "Failed to compile Enzyme thunk, primal not found", + ), + ) end end @@ -7106,8 +9286,8 @@ end const DumpPostOpt = Ref(false) # actual compilation -function _thunk(job, postopt::Bool=true) - mod, meta = codegen(:llvm, job; optimize=false) +function _thunk(job, postopt::Bool = true) + mod, meta = codegen(:llvm, job; optimize = false) adjointf, augmented_primalf = meta.adjointf, meta.augmented_primalf adjoint_name = name(adjointf) @@ -7117,14 +9297,14 @@ function _thunk(job, postopt::Bool=true) else primal_name = nothing end - + LLVM.ModulePassManager() do pm add!(pm, FunctionPass("ReinsertGCMarker", reinsert_gcmarker_pass!)) LLVM.run!(pm, mod) end # Run post optimization pipeline - if postopt + if postopt if job.config.params.ABI <: FFIABI || job.config.params.ABI <: NonGenABI post_optimze!(mod, JIT.get_tm()) if DumpPostOpt[] @@ -7137,7 +9317,7 @@ function _thunk(job, postopt::Bool=true) return (mod, adjoint_name, primal_name, meta.TapeType) end -const cache = Dict{UInt, CompileResult}() +const cache = Dict{UInt,CompileResult}() const cache_lock = ReentrantLock() @inline function cached_compilation(@nospecialize(job::CompilerJob))::CompileResult @@ -7167,20 +9347,65 @@ end @inline remove_innerty(::Type{<:MixedDuplicated}) = MixedDuplicated @inline remove_innerty(::Type{<:BatchMixedDuplicated}) = MixedDuplicated -@inline function thunkbase(ctx, mi::Core.MethodInstance, ::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}, ::Val{RuntimeActivity}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, 
ABI, ErrIfFuncWritten, RuntimeActivity} +@inline function thunkbase( + ctx, + mi::Core.MethodInstance, + ::Val{World}, + ::Type{FA}, + ::Type{A}, + tt::Type{TT}, + ::Val{Mode}, + ::Val{width}, + ::Val{ModifiedBetween}, + ::Val{ReturnPrimal}, + ::Val{ShadowInit}, + ::Type{ABI}, + ::Val{ErrIfFuncWritten}, + ::Val{RuntimeActivity}, +) where { + FA<:Annotation, + A<:Annotation, + TT, + Mode, + ModifiedBetween, + width, + ReturnPrimal, + ShadowInit, + World, + ABI, + ErrIfFuncWritten, + RuntimeActivity, +} target = Compiler.EnzymeTarget() - params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI, ErrIfFuncWritten, RuntimeActivity) - tmp_job = if World isa Nothing - Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false)) - else - Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) - end + params = Compiler.EnzymeCompilerParams( + Tuple{FA,TT.parameters...}, + Mode, + width, + remove_innerty(A), + true, + true, + ModifiedBetween, + ReturnPrimal, + ShadowInit, + UnknownTapeType, + ABI, + ErrIfFuncWritten, + RuntimeActivity, + ) #=abiwrap=# + tmp_job = if World isa Nothing + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false)) + else + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false), World) + end interp = GPUCompiler.get_interpreter(tmp_job) # TODO check compile return here, early # rrt = Core.Compiler.return_type(f, primal.tt) # nothing - rrt = something(Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals), Any) + rrt = something( + Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals), + Any, + ) rrt = Core.Compiler.typeinf_ext_toplevel(interp, mi).rettype run_enzyme = true @@ -7188,12 +9413,12 @@ end A2 = if rrt == Union{} run_enzyme = false Const - else + else A end - + if run_enzyme && !(A2 <: Const) && 
guaranteed_const_nongen(rrt, World) - estr = "Return type `$rrt` not marked Const, but type is guaranteed to be constant" + estr = "Return type `$rrt` not marked Const, but type is guaranteed to be constant" return error(estr) end @@ -7207,13 +9432,27 @@ end # @assert eltype(A) == rrt A2 end - - params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, run_enzyme, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI, ErrIfFuncWritten, RuntimeActivity) - job = if World isa Nothing - Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false)) + + params = Compiler.EnzymeCompilerParams( + Tuple{FA,TT.parameters...}, + Mode, + width, + rt2, + run_enzyme, + true, + ModifiedBetween, + ReturnPrimal, + ShadowInit, + UnknownTapeType, + ABI, + ErrIfFuncWritten, + RuntimeActivity, + ) #=abiwrap=# + job = if World isa Nothing + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false)) else - Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) - end + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false), World) + end # We need to use primal as the key, to lookup the right method # but need to mixin the hash of the adjoint to avoid cache collisions # This is counter-intuitive since we would expect the cache to be split @@ -7223,7 +9462,7 @@ end compile_result = cached_compilation(job) if !run_enzyme - ErrT = PrimalErrorThunk{typeof(compile_result.adjoint), FA, rt2, TT, width, ReturnPrimal} + ErrT = PrimalErrorThunk{typeof(compile_result.adjoint),FA,rt2,TT,width,ReturnPrimal} if Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient return (ErrT(compile_result.adjoint), ErrT(compile_result.adjoint)) else @@ -7231,71 +9470,227 @@ end end elseif Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient TapeType = compile_result.TapeType - AugT = AugmentedForwardThunk{typeof(compile_result.primal), FA, rt2, 
Tuple{params.TT.parameters[2:end]...}, width, ReturnPrimal, TapeType} - AdjT = AdjointThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, TapeType} + AugT = AugmentedForwardThunk{ + typeof(compile_result.primal), + FA, + rt2, + Tuple{params.TT.parameters[2:end]...}, + width, + ReturnPrimal, + TapeType, + } + AdjT = AdjointThunk{ + typeof(compile_result.adjoint), + FA, + rt2, + Tuple{params.TT.parameters[2:end]...}, + width, + TapeType, + } return (AugT(compile_result.primal), AdjT(compile_result.adjoint)) elseif Mode == API.DEM_ReverseModeCombined - CAdjT = CombinedAdjointThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, ReturnPrimal} + CAdjT = CombinedAdjointThunk{ + typeof(compile_result.adjoint), + FA, + rt2, + Tuple{params.TT.parameters[2:end]...}, + width, + ReturnPrimal, + } return CAdjT(compile_result.adjoint) elseif Mode == API.DEM_ForwardMode - FMT = ForwardModeThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, ReturnPrimal} + FMT = ForwardModeThunk{ + typeof(compile_result.adjoint), + FA, + rt2, + Tuple{params.TT.parameters[2:end]...}, + width, + ReturnPrimal, + } return FMT(compile_result.adjoint) else @assert false end end -@inline function thunk(mi::Core.MethodInstance, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}, ::Val{RuntimeActivity}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, ABI, ErrIfFuncWritten, RuntimeActivity} - ts_ctx = JuliaContext() - ctx = context(ts_ctx) - activate(ctx) - try - return thunkbase(ctx, mi, Val(#=World=#nothing), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) - finally - deactivate(ctx) - dispose(ts_ctx) - end -end - -@inline @generated 
function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}, ::Val{ErrIfFuncWritten}, ::Val{RuntimeActivity}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI, ErrIfFuncWritten, RuntimeActivity} - mi = fspec(eltype(FA), TT, World) - ts_ctx = JuliaContext() - ctx = context(ts_ctx) - activate(ctx) - res = try - thunkbase(ctx, mi, Val(World), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI, Val(ErrIfFuncWritten), Val(RuntimeActivity)) - finally - deactivate(ctx) - dispose(ts_ctx) - end - return quote - Base.@_inline_meta - return $(res) - end +@inline function thunk( + mi::Core.MethodInstance, + ::Type{FA}, + ::Type{A}, + tt::Type{TT}, + ::Val{Mode}, + ::Val{width}, + ::Val{ModifiedBetween}, + ::Val{ReturnPrimal}, + ::Val{ShadowInit}, + ::Type{ABI}, + ::Val{ErrIfFuncWritten}, + ::Val{RuntimeActivity}, +) where { + FA<:Annotation, + A<:Annotation, + TT, + Mode, + ModifiedBetween, + width, + ReturnPrimal, + ShadowInit, + ABI, + ErrIfFuncWritten, + RuntimeActivity, +} + ts_ctx = JuliaContext() + ctx = context(ts_ctx) + activate(ctx) + try + return thunkbase( + ctx, + mi, + Val(nothing), + FA, + A, + TT, + Val(Mode), + Val(width), + Val(ModifiedBetween), + Val(ReturnPrimal), + Val(ShadowInit), + ABI, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) #=World=# + finally + deactivate(ctx) + dispose(ts_ctx) + end +end + +@inline @generated function thunk( + ::Val{World}, + ::Type{FA}, + ::Type{A}, + tt::Type{TT}, + ::Val{Mode}, + ::Val{width}, + ::Val{ModifiedBetween}, + ::Val{ReturnPrimal}, + ::Val{ShadowInit}, + ::Type{ABI}, + ::Val{ErrIfFuncWritten}, + ::Val{RuntimeActivity}, +) where { + FA<:Annotation, + A<:Annotation, + TT, + Mode, + ModifiedBetween, + width, + ReturnPrimal, + ShadowInit, + World, + ABI, + ErrIfFuncWritten, + RuntimeActivity, +} + mi = 
fspec(eltype(FA), TT, World) + ts_ctx = JuliaContext() + ctx = context(ts_ctx) + activate(ctx) + res = try + thunkbase( + ctx, + mi, + Val(World), + FA, + A, + TT, + Val(Mode), + Val(width), + Val(ModifiedBetween), + Val(ReturnPrimal), + Val(ShadowInit), + ABI, + Val(ErrIfFuncWritten), + Val(RuntimeActivity), + ) + finally + deactivate(ctx) + dispose(ts_ctx) + end + return quote + Base.@_inline_meta + return $(res) + end end import GPUCompiler: deferred_codegen_jobs -@generated function deferred_codegen(::Val{World}, ::Type{FA}, ::Val{TT}, ::Val{A},::Val{Mode}, - ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal},::Val{ShadowInit},::Type{ExpectedTapeType}, ::Val{ErrIfFuncWritten}, ::Val{RuntimeActivity}) where {World, FA<:Annotation,TT, A, Mode, width, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, ErrIfFuncWritten, RuntimeActivity} +@generated function deferred_codegen( + ::Val{World}, + ::Type{FA}, + ::Val{TT}, + ::Val{A}, + ::Val{Mode}, + ::Val{width}, + ::Val{ModifiedBetween}, + ::Val{ReturnPrimal}, + ::Val{ShadowInit}, + ::Type{ExpectedTapeType}, + ::Val{ErrIfFuncWritten}, + ::Val{RuntimeActivity}, +) where { + World, + FA<:Annotation, + TT, + A, + Mode, + width, + ModifiedBetween, + ReturnPrimal, + ShadowInit, + ExpectedTapeType, + ErrIfFuncWritten, + RuntimeActivity, +} JuliaContext() do ctx Base.@_inline_meta mi = fspec(eltype(FA), TT, World) target = EnzymeTarget() - rt2 = if A isa UnionAll - params = EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, FFIABI, ErrIfFuncWritten, RuntimeActivity) - tmp_job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) - + rt2 = if A isa UnionAll + params = EnzymeCompilerParams( + Tuple{FA,TT.parameters...}, + Mode, + width, + remove_innerty(A), + true, + true, + ModifiedBetween, + ReturnPrimal, + ShadowInit, + ExpectedTapeType, + FFIABI, + ErrIfFuncWritten, + 
RuntimeActivity, + ) #=abiwrap=# + tmp_job = Compiler.CompilerJob( + mi, + CompilerConfig(target, params; kernel = false), + World, + ) + interp = GPUCompiler.get_interpreter(tmp_job) - rrt = something(Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals), Any) + rrt = something( + Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals), + Any, + ) # Don't error here but default to nothing return since in cuda context we don't use the device overrides if rrt == Union{} rrt = Nothing end - + if !(A <: Const) && guaranteed_const_nongen(rrt, World) estr = "Return type `$rrt` not marked Const, but type is guaranteed to be constant" return quote @@ -7307,9 +9702,24 @@ import GPUCompiler: deferred_codegen_jobs @assert A isa DataType A end - - params = EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, FFIABI, ErrIfFuncWritten, RuntimeActivity) - job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) + + params = EnzymeCompilerParams( + Tuple{FA,TT.parameters...}, + Mode, + width, + rt2, + true, + true, + ModifiedBetween, + ReturnPrimal, + ShadowInit, + ExpectedTapeType, + FFIABI, + ErrIfFuncWritten, + RuntimeActivity, + ) #=abiwrap=# + job = + Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false), World) addr = get_trampoline(job) id = Base.reinterpret(Int, pointer(addr)) @@ -7317,7 +9727,13 @@ import GPUCompiler: deferred_codegen_jobs quote Base.@_inline_meta - ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $(reinterpret(Ptr{Cvoid}, id))) + ccall( + "extern deferred_codegen", + llvmcall, + Ptr{Cvoid}, + (Ptr{Cvoid},), + $(reinterpret(Ptr{Cvoid}, id)), + ) end end end diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index c167581c3a9..4d48297ae5e 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -1,6 +1,12 @@ module 
Interpreter import Enzyme: API -using Core.Compiler: AbstractInterpreter, InferenceResult, InferenceParams, InferenceState, OptimizationParams, MethodInstance +using Core.Compiler: + AbstractInterpreter, + InferenceResult, + InferenceParams, + InferenceState, + OptimizationParams, + MethodInstance using GPUCompiler: @safe_debug if VERSION < v"1.11.0-DEV.1552" using GPUCompiler: CodeCache, WorldView, @safe_debug @@ -18,11 +24,11 @@ else end struct EnzymeInterpreter <: AbstractInterpreter -@static if HAS_INTEGRATED_CACHE - token::Any -else - code_cache::CodeCache -end + @static if HAS_INTEGRATED_CACHE + token::Any + else + code_cache::CodeCache + end method_table::Union{Nothing,Core.MethodTable} # Cache of inference results for this particular interpreter @@ -37,11 +43,16 @@ end mode::API.CDerivativeMode end -function EnzymeInterpreter(cache_or_token, mt::Union{Nothing,Core.MethodTable}, world::UInt, mode::API.CDerivativeMode) +function EnzymeInterpreter( + cache_or_token, + mt::Union{Nothing,Core.MethodTable}, + world::UInt, + mode::API.CDerivativeMode, +) @assert world <= Base.get_world_counter() parms = @static if VERSION < v"1.12" - InferenceParams(unoptimize_throw_blocks=false) + InferenceParams(unoptimize_throw_blocks = false) else InferenceParams() end @@ -57,9 +68,9 @@ function EnzymeInterpreter(cache_or_token, mt::Union{Nothing,Core.MethodTable}, world, # parameters for inference and optimization - parms, + parms, OptimizationParams(), - mode + mode, ) end @@ -70,7 +81,8 @@ Core.Compiler.get_inference_cache(interp::EnzymeInterpreter) = interp.local_cach @static if HAS_INTEGRATED_CACHE Core.Compiler.cache_owner(interp::EnzymeInterpreter) = interp.token else - Core.Compiler.code_cache(interp::EnzymeInterpreter) = WorldView(interp.code_cache, interp.world) + Core.Compiler.code_cache(interp::EnzymeInterpreter) = + WorldView(interp.code_cache, interp.world) end # No need to do any locking since we're not putting our results into the runtime cache @@ -87,14 +99,14 
@@ Core.Compiler.may_discard_trees(::EnzymeInterpreter) = false Core.Compiler.verbose_stmt_info(::EnzymeInterpreter) = false if isdefined(Base.Experimental, Symbol("@overlay")) -Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) = - Core.Compiler.OverlayMethodTable(interp.world, interp.method_table) + Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) = + Core.Compiler.OverlayMethodTable(interp.world, interp.method_table) else -# On 1.6- CUDA.jl will poison the method table at the end of the world -# using GPUCompiler: WorldOverlayMethodTable -# Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) = -# WorldOverlayMethodTable(interp.world) + # On 1.6- CUDA.jl will poison the method table at the end of the world + # using GPUCompiler: WorldOverlayMethodTable + # Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) = + # WorldOverlayMethodTable(interp.world) end function is_alwaysinline_func(@nospecialize(TT)) @@ -114,8 +126,11 @@ function is_primitive_func(@nospecialize(TT)) end # FIXME(@wsmoses): For which types should we not inline? 
- if ft === typeof(Base.wait) || ft === typeof(Base._wait) || ft === typeof(Base.enq_work) || - ft === typeof(Base.Threads.threadid) || ft == typeof(Base.Threads.nthreads) || + if ft === typeof(Base.wait) || + ft === typeof(Base._wait) || + ft === typeof(Base.enq_work) || + ft === typeof(Base.Threads.threadid) || + ft == typeof(Base.Threads.nthreads) || ft === typeof(Base.Threads.threading_run) return true end @@ -123,7 +138,7 @@ function is_primitive_func(@nospecialize(TT)) end function isKWCallSignature(@nospecialize(TT)) - return TT <: Tuple{typeof(Core.kwcall), Any, Any, Vararg} + return TT <: Tuple{typeof(Core.kwcall),Any,Any,Vararg} end function simplify_kw(@nospecialize specTypes) @@ -137,27 +152,46 @@ end import Core.Compiler: CallInfo struct NoInlineCallInfo <: CallInfo info::CallInfo # wrapped call - tt # ::Type + tt::Any # ::Type kind::Symbol - NoInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt), kind::Symbol) = new(info, tt, kind) + NoInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt), kind::Symbol) = + new(info, tt, kind) end Core.Compiler.nsplit_impl(info::NoInlineCallInfo) = Core.Compiler.nsplit(info.info) -Core.Compiler.getsplit_impl(info::NoInlineCallInfo, idx::Int) = Core.Compiler.getsplit(info.info, idx) -Core.Compiler.getresult_impl(info::NoInlineCallInfo, idx::Int) = Core.Compiler.getresult(info.info, idx) +Core.Compiler.getsplit_impl(info::NoInlineCallInfo, idx::Int) = + Core.Compiler.getsplit(info.info, idx) +Core.Compiler.getresult_impl(info::NoInlineCallInfo, idx::Int) = + Core.Compiler.getresult(info.info, idx) struct AlwaysInlineCallInfo <: CallInfo info::CallInfo # wrapped call - tt # ::Type + tt::Any # ::Type AlwaysInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt)) = new(info, tt) end Core.Compiler.nsplit_impl(info::AlwaysInlineCallInfo) = Core.Compiler.nsplit(info.info) -Core.Compiler.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int) = Core.Compiler.getsplit(info.info, idx) 
-Core.Compiler.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) = Core.Compiler.getresult(info.info, idx) +Core.Compiler.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int) = + Core.Compiler.getsplit(info.info, idx) +Core.Compiler.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) = + Core.Compiler.getresult(info.info, idx) using Core.Compiler: ArgInfo, StmtInfo, AbsIntState -function Core.Compiler.abstract_call_gf_by_type(interp::EnzymeInterpreter, @nospecialize(f), - arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype), sv::AbsIntState, max_methods::Int) - ret = @invoke Core.Compiler.abstract_call_gf_by_type(interp::AbstractInterpreter, f::Any, - arginfo::ArgInfo, si::StmtInfo, atype::Any, sv::AbsIntState, max_methods::Int) +function Core.Compiler.abstract_call_gf_by_type( + interp::EnzymeInterpreter, + @nospecialize(f), + arginfo::ArgInfo, + si::StmtInfo, + @nospecialize(atype), + sv::AbsIntState, + max_methods::Int, +) + ret = @invoke Core.Compiler.abstract_call_gf_by_type( + interp::AbstractInterpreter, + f::Any, + arginfo::ArgInfo, + si::StmtInfo, + atype::Any, + sv::AbsIntState, + max_methods::Int, + ) callinfo = ret.info method_table = Core.Compiler.method_table(interp) specTypes = simplify_kw(atype) @@ -175,21 +209,43 @@ function Core.Compiler.abstract_call_gf_by_type(interp::EnzymeInterpreter, @nosp callinfo = NoInlineCallInfo(callinfo, atype, :rrule) end @static if VERSION ≥ v"1.11-" - return Core.Compiler.CallMeta(ret.rt, ret.exct, ret.effects, callinfo) + return Core.Compiler.CallMeta(ret.rt, ret.exct, ret.effects, callinfo) else - return Core.Compiler.CallMeta(ret.rt, ret.effects, callinfo) + return Core.Compiler.CallMeta(ret.rt, ret.effects, callinfo) end end let # overload `inlining_policy` @static if VERSION ≥ v"1.11.0-DEV.879" - sigs_ex = :(interp::EnzymeInterpreter, @nospecialize(src), @nospecialize(info::Core.Compiler.CallInfo), stmt_flag::UInt32) - args_ex = :(interp::AbstractInterpreter, src::Any, info::Core.Compiler.CallInfo, 
stmt_flag::UInt32) + sigs_ex = :( + interp::EnzymeInterpreter, + @nospecialize(src), + @nospecialize(info::Core.Compiler.CallInfo), + stmt_flag::UInt32, + ) + args_ex = :( + interp::AbstractInterpreter, + src::Any, + info::Core.Compiler.CallInfo, + stmt_flag::UInt32, + ) else - sigs_ex = :(interp::EnzymeInterpreter, - @nospecialize(src), @nospecialize(info::Core.Compiler.CallInfo), stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) - args_ex = :(interp::AbstractInterpreter, - src::Any, info::Core.Compiler.CallInfo, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) + sigs_ex = :( + interp::EnzymeInterpreter, + @nospecialize(src), + @nospecialize(info::Core.Compiler.CallInfo), + stmt_flag::UInt8, + mi::MethodInstance, + argtypes::Vector{Any}, + ) + args_ex = :( + interp::AbstractInterpreter, + src::Any, + info::Core.Compiler.CallInfo, + stmt_flag::UInt8, + mi::MethodInstance, + argtypes::Vector{Any}, + ) end @eval function Core.Compiler.inlining_policy($(sigs_ex.args...)) if info isa NoInlineCallInfo @@ -212,20 +268,36 @@ let # overload `inlining_policy` end end -import Core.Compiler: abstract_call, abstract_call_known, ArgInfo, StmtInfo, AbsIntState, get_max_methods, - CallMeta, Effects, NoCallInfo, widenconst, mapany, MethodResultPure +import Core.Compiler: + abstract_call, + abstract_call_known, + ArgInfo, + StmtInfo, + AbsIntState, + get_max_methods, + CallMeta, + Effects, + NoCallInfo, + widenconst, + mapany, + MethodResultPure struct AutodiffCallInfo <: CallInfo # ... 
info::CallInfo end -function abstract_call_known(interp::EnzymeInterpreter, @nospecialize(f), - arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, - max_methods::Int = get_max_methods(interp, f, sv)) +function abstract_call_known( + interp::EnzymeInterpreter, + @nospecialize(f), + arginfo::ArgInfo, + si::StmtInfo, + sv::AbsIntState, + max_methods::Int = get_max_methods(interp, f, sv), +) (; fargs, argtypes) = arginfo - + if f === Enzyme.within_autodiff if length(argtypes) != 1 @static if VERSION < v"1.11.0-" @@ -235,26 +307,48 @@ function abstract_call_known(interp::EnzymeInterpreter, @nospecialize(f), end end @static if VERSION < v"1.11.0-" - return CallMeta(Core.Const(true), Core.Compiler.EFFECTS_TOTAL, MethodResultPure()) + return CallMeta( + Core.Const(true), + Core.Compiler.EFFECTS_TOTAL, + MethodResultPure(), + ) else - return CallMeta(Core.Const(true), Union{}, Core.Compiler.EFFECTS_TOTAL, MethodResultPure()) + return CallMeta( + Core.Const(true), + Union{}, + Core.Compiler.EFFECTS_TOTAL, + MethodResultPure(), + ) end end if f === Enzyme.autodiff && length(argtypes) >= 4 - if widenconst(argtypes[2]) <: Enzyme.Mode && widenconst(argtypes[3]) <: Enzyme.Annotation && widenconst(argtypes[4]) <: Type{<:Enzyme.Annotation} - arginfo2 = ArgInfo( - fargs isa Nothing ? nothing : [:(Enzyme.autodiff_deferred), fargs[2:end]...], - [Core.Const(Enzyme.autodiff_deferred), argtypes[2:end]...] - ) - return abstract_call_known( - interp, Enzyme.autodiff_deferred, arginfo2, - si, sv, max_methods) - end + if widenconst(argtypes[2]) <: Enzyme.Mode && + widenconst(argtypes[3]) <: Enzyme.Annotation && + widenconst(argtypes[4]) <: Type{<:Enzyme.Annotation} + arginfo2 = ArgInfo( + fargs isa Nothing ? 
nothing : + [:(Enzyme.autodiff_deferred), fargs[2:end]...], + [Core.Const(Enzyme.autodiff_deferred), argtypes[2:end]...], + ) + return abstract_call_known( + interp, + Enzyme.autodiff_deferred, + arginfo2, + si, + sv, + max_methods, + ) + end end return Base.@invoke abstract_call_known( - interp::AbstractInterpreter, f, arginfo::ArgInfo, - si::StmtInfo, sv::AbsIntState, max_methods::Int) + interp::AbstractInterpreter, + f, + arginfo::ArgInfo, + si::StmtInfo, + sv::AbsIntState, + max_methods::Int, + ) end end diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 2e3e8194c97..d11daaa0b3e 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -15,24 +15,52 @@ struct PipelineConfig cleanup::Cint end -const RunAttributor = Ref(true) - -function pipeline_options(; lower_intrinsics=true, dump_native=false, external_use=false, llvm_only=false, always_inline=true, enable_early_simplifications=true, - enable_early_optimizations=true, - enable_scalar_optimizations=true, - enable_loop_optimizations=true, - enable_vector_pipeline=true, - remove_ni=true, - cleanup=true, Size=0, Speedup=3) - return PipelineConfig(Speedup, Size, lower_intrinsics, dump_native, external_use, llvm_only, always_inline, enable_early_simplifications, enable_early_optimizations, enable_scalar_optimizations, enable_loop_optimizations, enable_vector_pipeline, remove_ni, cleanup) +const RunAttributor = Ref(true) + +function pipeline_options(; + lower_intrinsics = true, + dump_native = false, + external_use = false, + llvm_only = false, + always_inline = true, + enable_early_simplifications = true, + enable_early_optimizations = true, + enable_scalar_optimizations = true, + enable_loop_optimizations = true, + enable_vector_pipeline = true, + remove_ni = true, + cleanup = true, + Size = 0, + Speedup = 3, +) + return PipelineConfig( + Speedup, + Size, + lower_intrinsics, + dump_native, + external_use, + llvm_only, + always_inline, + enable_early_simplifications, + 
enable_early_optimizations, + enable_scalar_optimizations, + enable_loop_optimizations, + enable_vector_pipeline, + remove_ni, + cleanup, + ) end -function run_jl_pipeline(pm, tm; kwargs...) - config = Ref(pipeline_options(;kwargs...)) +function run_jl_pipeline(pm, tm; kwargs...) + config = Ref(pipeline_options(; kwargs...)) function jl_pipeline(m) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm - @ccall jl_build_newpm_pipeline(mpm.ref::Ptr{Cvoid}, pb.ref::Ptr{Cvoid}, config::Ptr{PipelineConfig})::Cvoid + @ccall jl_build_newpm_pipeline( + mpm.ref::Ptr{Cvoid}, + pb.ref::Ptr{Cvoid}, + config::Ptr{PipelineConfig}, + )::Cvoid end LLVM.run!(mpm, m, tm) end @@ -53,10 +81,10 @@ end else function gc_invariant_verifier_tm!(pm, tm, cond) function gc_invariant_verifier(mod) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm - add!(fpm, GCInvariantVerifierPass(;strong=cond)) + add!(fpm, GCInvariantVerifierPass(; strong = cond)) end end run!(pb, mod) @@ -74,7 +102,7 @@ end else function propagate_julia_addrsp_tm!(pm, tm) function prop_julia_addr(mod) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm add!(fpm, PropagateJuliaAddrspacesPass()) @@ -95,7 +123,7 @@ end else function alloc_opt_tm!(pm, tm) function alloc_opt(mod) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm add!(fpm, AllocOptPass()) @@ -116,7 +144,7 @@ end else function remove_ni_tm!(pm, tm) function remove_ni(mod) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, RemoveNIPass()) end @@ -135,7 +163,7 @@ end else 
function julia_licm_tm!(pm, tm) function julia_licm(mod) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm add!(fpm, NewPMLoopPassManager()) do lpm @@ -159,7 +187,7 @@ end else function lower_simdloop_tm!(pm, tm) function lower_simdloop(mod) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm add!(fpm, NewPMLoopPassManager()) do lpm @@ -181,13 +209,28 @@ function loop_optimizations_tm!(pm, tm) @static if true || VERSION < v"1.11-" lower_simdloop_tm!(pm, tm) licm!(pm) - if LLVM.version() >= v"15" + if LLVM.version() >= v"15" simple_loop_unswitch_legacy!(pm) else loop_unswitch!(pm) end else - run_jl_pipeline(pm, tm; lower_intrinsics=false, dump_native=false, external_use=false, llvm_only=false, always_inline=false, enable_early_simplifications=false, enable_early_optimizations=false, enable_scalar_optimizations=false, enable_loop_optimizations=true, enable_vector_pipeline=false, remove_ni=false, cleanup=false) + run_jl_pipeline( + pm, + tm; + lower_intrinsics = false, + dump_native = false, + external_use = false, + llvm_only = false, + always_inline = false, + enable_early_simplifications = false, + enable_early_optimizations = false, + enable_scalar_optimizations = false, + enable_loop_optimizations = true, + enable_vector_pipeline = false, + remove_ni = false, + cleanup = false, + ) end end @@ -205,7 +248,7 @@ function more_loop_optimizations_tm!(pm, tm) # Subsequent passes not stripping metadata from terminator instruction_combining!(pm) # TODO: createInstSimplifyLegacy jl_inst_simplify!(pm) - + ind_var_simplify!(pm) loop_deletion!(pm) loop_unroll!(pm) # TODO: in Julia createSimpleLoopUnroll @@ -224,7 +267,22 @@ function more_loop_optimizations_tm!(pm, tm) # IndVarSimplifyPass # LoopDeletionPass # LoopFullUnrollPass - 
run_jl_pipeline(pm, tm; lower_intrinsics=false, dump_native=false, external_use=false, llvm_only=false, always_inline=false, enable_early_simplifications=false, enable_early_optimizations=false, enable_scalar_optimizations=false, enable_loop_optimizations=true, enable_vector_pipeline=false, remove_ni=false, cleanup=false) + run_jl_pipeline( + pm, + tm; + lower_intrinsics = false, + dump_native = false, + external_use = false, + llvm_only = false, + always_inline = false, + enable_early_simplifications = false, + enable_early_optimizations = false, + enable_scalar_optimizations = false, + enable_loop_optimizations = true, + enable_vector_pipeline = false, + remove_ni = false, + cleanup = false, + ) end end @@ -235,7 +293,7 @@ end else function demote_float16_tm!(pm, tm) function demote_float16(mod) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm add!(fpm, DemoteFloat16Pass()) @@ -256,7 +314,7 @@ end else function lower_exc_handlers_tm!(pm, tm) function lower_exc_handlers(mod) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm add!(fpm, LowerExcHandlersPass()) @@ -277,7 +335,7 @@ end else function lower_ptls_tm!(pm, tm, dump_native) function lower_ptls(mod) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, LowerPTLSPass()) end @@ -296,7 +354,7 @@ end else function combine_mul_add_tm!(pm, tm) function combine_mul_add(mod) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm add!(fpm, CombineMulAddPass()) @@ -317,7 +375,7 @@ end else function late_lower_gc_frame_tm!(pm, tm) function late_lower_gc_frame(mod) - @dispose pb=NewPMPassBuilder() begin 
+ @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm add!(fpm, LateLowerGCPass()) @@ -338,7 +396,7 @@ end else function final_lower_gc_tm!(pm, tm) function final_lower_gc(mod) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, NewPMFunctionPassManager()) do fpm add!(fpm, FinalLowerGCPass()) @@ -355,17 +413,17 @@ end @static if VERSION < v"1.11-" function cpu_features_tm!(pm, tm) @static if isdefined(LLVM.Interop, :cpu_features!) - LLVM.Interop.cpu_features!(pm) + LLVM.Interop.cpu_features!(pm) else - @static if isdefined(GPUCompiler, :cpu_features!) + @static if isdefined(GPUCompiler, :cpu_features!) GPUCompiler.cpu_features!(pm) - end + end end end else function cpu_features_tm!(pm, tm) function cpu_features(mod) - @dispose pb=NewPMPassBuilder() begin + @dispose pb = NewPMPassBuilder() begin add!(pb, NewPMModulePassManager()) do mpm add!(mpm, CPUFeaturesPass()) end @@ -379,7 +437,7 @@ end function addNA(inst, node::LLVM.Metadata, MD) md = metadata(inst) - next = nothing + next = nothing if haskey(md, MD) next = LLVM.MDNode(Metadata[node, operands(md[MD])...]) else @@ -405,7 +463,7 @@ function addr13NoAlias(mod::LLVM.Module) end end elseif isa(inst, LLVM.LoadInst) - ty =value_type(inst) + ty = value_type(inst) if isa(ty, LLVM.PointerType) if addrspace(ty) == 13 addNA(inst, aliasscope, LLVM.MD_alias_scope) @@ -432,7 +490,7 @@ end # turn this into load/store, as this is more # amenable to caching analysis infrastructure function memcpy_alloca_to_loadstore(mod::LLVM.Module) - dl = datalayout(mod) + dl = datalayout(mod) for f in functions(mod) if length(blocks(f)) != 0 bb = first(blocks(f)) @@ -441,21 +499,24 @@ function memcpy_alloca_to_loadstore(mod::LLVM.Module) if !isa(alloca, LLVM.AllocaInst) continue end - todo = Tuple{LLVM.Instruction, LLVM.Value}[(alloca, alloca)] + todo = 
Tuple{LLVM.Instruction,LLVM.Value}[(alloca, alloca)] copy = nothing legal = true elty = LLVM.LLVMType(LLVM.API.LLVMGetAllocatedType(alloca)) lifetimestarts = LLVM.Instruction[] while length(todo) > 0 cur, prev = pop!(todo) - if isa(cur, LLVM.AllocaInst) || isa(cur, LLVM.AddrSpaceCastInst) || isa(cur, LLVM.BitCastInst) + if isa(cur, LLVM.AllocaInst) || + isa(cur, LLVM.AddrSpaceCastInst) || + isa(cur, LLVM.BitCastInst) for u in LLVM.uses(cur) u = LLVM.user(u) push!(todo, (u, cur)) end continue end - if isa(cur, LLVM.CallInst) && isa(LLVM.called_operand(cur), LLVM.Function) + if isa(cur, LLVM.CallInst) && + isa(LLVM.called_operand(cur), LLVM.Function) intr = LLVM.API.LLVMGetIntrinsicID(LLVM.called_operand(cur)) if intr == LLVM.Intrinsic("llvm.lifetime.start").id push!(lifetimestarts, cur) @@ -466,7 +527,9 @@ function memcpy_alloca_to_loadstore(mod::LLVM.Module) end if intr == LLVM.Intrinsic("llvm.memcpy").id sz = operands(cur)[3] - if operands(cur)[1] == prev && isa(sz, LLVM.ConstantInt) && convert(Int, sz) == sizeof(dl, elty) + if operands(cur)[1] == prev && + isa(sz, LLVM.ConstantInt) && + convert(Int, sz) == sizeof(dl, elty) if copy === nothing || copy == cur copy = cur continue @@ -479,13 +542,16 @@ function memcpy_alloca_to_loadstore(mod::LLVM.Module) if isa(cur, LLVM.LoadInst) continue end - if isa(cur, LLVM.CallInst) && isa(LLVM.called_operand(cur), LLVM.Function) + if isa(cur, LLVM.CallInst) && + isa(LLVM.called_operand(cur), LLVM.Function) legalc = true for (i, ci) in enumerate(operands(cur)[1:end-1]) if ci == prev nocapture = false readonly = false - for a in collect(parameter_attributes(LLVM.called_operand(cur), i)) + for a in collect( + parameter_attributes(LLVM.called_operand(cur), i), + ) if kind(a) == kind(EnumAttribute("readonly")) readonly = true end @@ -510,21 +576,35 @@ function memcpy_alloca_to_loadstore(mod::LLVM.Module) legal = false break end - + if legal && copy !== nothing B = LLVM.IRBuilder() position!(B, copy) dst = operands(copy)[1] src = 
operands(copy)[2] - dst0 = bitcast!(B, dst, LLVM.PointerType(LLVM.IntType(8), addrspace(value_type(dst)))) + dst0 = bitcast!( + B, + dst, + LLVM.PointerType(LLVM.IntType(8), addrspace(value_type(dst))), + ) - dst = bitcast!(B, dst, LLVM.PointerType(elty, addrspace(value_type(dst)))) - src = bitcast!(B, src, LLVM.PointerType(elty, addrspace(value_type(src)))) + dst = + bitcast!(B, dst, LLVM.PointerType(elty, addrspace(value_type(dst)))) + src = + bitcast!(B, src, LLVM.PointerType(elty, addrspace(value_type(src)))) src = load!(B, elty, src) - FT = LLVM.FunctionType(LLVM.VoidType(), [LLVM.IntType(64), value_type(dst0)]) + FT = LLVM.FunctionType( + LLVM.VoidType(), + [LLVM.IntType(64), value_type(dst0)], + ) lifetimestart, _ = get_function!(mod, "llvm.lifetime.start.p0i8", FT) - call!(B, FT, lifetimestart, LLVM.Value[LLVM.ConstantInt(Int64(sizeof(dl, elty))), dst0]) + call!( + B, + FT, + lifetimestart, + LLVM.Value[LLVM.ConstantInt(Int64(sizeof(dl, elty))), dst0], + ) store!(B, src, dst) push!(todel, copy) end @@ -601,289 +681,376 @@ function nodecayed_phis!(mod::LLVM.Module) end end - offty = LLVM.IntType(8*sizeof(Int)) + offty = LLVM.IntType(8 * sizeof(Int)) i8 = LLVM.IntType(8) for addr in (11, 13) - nextvs = Dict{LLVM.PHIInst, LLVM.PHIInst}() - mtodo = Vector{LLVM.PHIInst}[] - goffsets = Dict{LLVM.PHIInst, LLVM.PHIInst}() - nonphis = LLVM.Instruction[] - anyV = false - for bb in blocks(f) - todo = LLVM.PHIInst[] - nonphi = nothing - for inst in instructions(bb) - if !isa(inst, LLVM.PHIInst) - nonphi = inst - break - end - ty = value_type(inst) - if !isa(ty, LLVM.PointerType) - continue - end - if addrspace(ty) != addr - continue - end - if addr == 11 - all_args = true - addrtodo = Value[inst] - seen = Set{LLVM.Value}() + nextvs = Dict{LLVM.PHIInst,LLVM.PHIInst}() + mtodo = Vector{LLVM.PHIInst}[] + goffsets = Dict{LLVM.PHIInst,LLVM.PHIInst}() + nonphis = LLVM.Instruction[] + anyV = false + for bb in blocks(f) + todo = LLVM.PHIInst[] + nonphi = nothing + for inst in 
instructions(bb) + if !isa(inst, LLVM.PHIInst) + nonphi = inst + break + end + ty = value_type(inst) + if !isa(ty, LLVM.PointerType) + continue + end + if addrspace(ty) != addr + continue + end + if addr == 11 + all_args = true + addrtodo = Value[inst] + seen = Set{LLVM.Value}() - while length(addrtodo) != 0 - v = pop!(addrtodo) - base = get_base_object(v) - if in(base, seen) - continue - end - push!(seen, base) - if isa(base, LLVM.Argument) && addrspace(value_type(base)) == 11 - continue - end - if isa(base, LLVM.PHIInst) - for (v, _) in LLVM.incoming(base) - push!(addrtodo, v) + while length(addrtodo) != 0 + v = pop!(addrtodo) + base = get_base_object(v) + if in(base, seen) + continue end + push!(seen, base) + if isa(base, LLVM.Argument) && addrspace(value_type(base)) == 11 + continue + end + if isa(base, LLVM.PHIInst) + for (v, _) in LLVM.incoming(base) + push!(addrtodo, v) + end + continue + end + all_args = false + break + end + if all_args continue end - all_args = false - break end - if all_args - continue + + push!(todo, inst) + nb = IRBuilder() + position!(nb, inst) + el_ty = if addr == 11 + eltype(ty) + else + LLVM.StructType(LLVM.LLVMType[]) end - end - - push!(todo, inst) - nb = IRBuilder() - position!(nb, inst) - el_ty = if addr == 11 - eltype(ty) - else - LLVM.StructType(LLVM.LLVMType[]) - end - nphi = phi!(nb, LLVM.PointerType(el_ty, 10), "nodecayed." * LLVM.name(inst)) - nextvs[inst] = nphi - anyV = true + nphi = phi!( + nb, + LLVM.PointerType(el_ty, 10), + "nodecayed." * LLVM.name(inst), + ) + nextvs[inst] = nphi + anyV = true - goffsets[inst] = phi!(nb, offty, "nodecayedoff." * LLVM.name(inst)) + goffsets[inst] = phi!(nb, offty, "nodecayedoff." 
* LLVM.name(inst)) + end + push!(mtodo, todo) + push!(nonphis, nonphi) end - push!(mtodo, todo) - push!(nonphis, nonphi) - end - for (bb, todo, nonphi) in zip(blocks(f), mtodo, nonphis) + for (bb, todo, nonphi) in zip(blocks(f), mtodo, nonphis) - for inst in todo - ty = value_type(inst) - el_ty = if addr == 11 - eltype(ty) - else - LLVM.StructType(LLVM.LLVMType[]) - end - nvs = Tuple{LLVM.Value, LLVM.BasicBlock}[] - offsets = Tuple{LLVM.Value, LLVM.BasicBlock}[] - for (v, pb) in LLVM.incoming(inst) - done = false - for ((nv, pb0), (offset, pb1)) in zip(nvs, offsets) - if pb0 == pb - push!(nvs, (nv, pb)) - push!(offsets, (offset, pb)) - done = true - break - end - end - if done - continue - end - b = IRBuilder() - position!(b, terminator(pb)) + for inst in todo + ty = value_type(inst) + el_ty = if addr == 11 + eltype(ty) + else + LLVM.StructType(LLVM.LLVMType[]) + end + nvs = Tuple{LLVM.Value,LLVM.BasicBlock}[] + offsets = Tuple{LLVM.Value,LLVM.BasicBlock}[] + for (v, pb) in LLVM.incoming(inst) + done = false + for ((nv, pb0), (offset, pb1)) in zip(nvs, offsets) + if pb0 == pb + push!(nvs, (nv, pb)) + push!(offsets, (offset, pb)) + done = true + break + end + end + if done + continue + end + b = IRBuilder() + position!(b, terminator(pb)) - v0 = v - @inline function getparent(v, offset, hasload) - if addr == 11 && addrspace(value_type(v)) == 10 - return v, offset, hasload - end - if addr == 13 && hasload && addrspace(value_type(v)) == 10 - return v, offset, hasload - end - if addr == 13 && isa(v, LLVM.LoadInst) && !hasload - return getparent(operands(v)[1], offset, true) - end + v0 = v + @inline function getparent(v, offset, hasload) + if addr == 11 && addrspace(value_type(v)) == 10 + return v, offset, hasload + end + if addr == 13 && hasload && addrspace(value_type(v)) == 10 + return v, offset, hasload + end + if addr == 13 && isa(v, LLVM.LoadInst) && !hasload + return getparent(operands(v)[1], offset, true) + end - if addr == 13 && isa(v, LLVM.ConstantExpr) - if 
opcode(v) == LLVM.API.LLVMAddrSpaceCast - v2 = operands(v)[1] - if addrspace(value_type(v2)) == 0 - if addr == 13 && isa(v, LLVM.ConstantExpr) - v2 = const_addrspacecast(operands(v)[1], LLVM.PointerType(eltype(value_type(v)), 10)) - return v2, offset, hasload + if addr == 13 && isa(v, LLVM.ConstantExpr) + if opcode(v) == LLVM.API.LLVMAddrSpaceCast + v2 = operands(v)[1] + if addrspace(value_type(v2)) == 0 + if addr == 13 && isa(v, LLVM.ConstantExpr) + v2 = const_addrspacecast( + operands(v)[1], + LLVM.PointerType(eltype(value_type(v)), 10), + ) + return v2, offset, hasload + end + end end end - end - end - if addr == 11 && isa(v, LLVM.ConstantExpr) - if opcode(v) == LLVM.API.LLVMAddrSpaceCast - v2 = operands(v)[1] - if addrspace(value_type(v2)) == 10 - return v2, offset, hasload + if addr == 11 && isa(v, LLVM.ConstantExpr) + if opcode(v) == LLVM.API.LLVMAddrSpaceCast + v2 = operands(v)[1] + if addrspace(value_type(v2)) == 10 + return v2, offset, hasload + end + if addrspace(value_type(v2)) == 0 + if addr == 11 + v2 = const_addrspacecast( + v2, + LLVM.PointerType(eltype(value_type(v)), 10), + ) + return v2, offset, hasload + end + end + if LLVM.isnull(v2) + v2 = const_addrspacecast( + v2, + LLVM.PointerType(eltype(value_type(v)), 10), + ) + return v2, offset, hasload + end + end end - if addrspace(value_type(v2)) == 0 - if addr == 11 - v2 = const_addrspacecast(v2, LLVM.PointerType(eltype(value_type(v)), 10)) + + if isa(v, LLVM.AddrSpaceCastInst) + if addrspace(value_type(operands(v)[1])) == 0 + v2 = addrspacecast!( + b, + operands(v)[1], + LLVM.PointerType(eltype(value_type(v)), 10), + ) return v2, offset, hasload end + nv, noffset, nhasload = + getparent(operands(v)[1], offset, hasload) + if eltype(value_type(nv)) != eltype(value_type(v)) + nv = bitcast!( + b, + nv, + LLVM.PointerType( + eltype(value_type(v)), + addrspace(value_type(nv)), + ), + ) + end + return nv, noffset, nhasload end - if LLVM.isnull(v2) - v2 = const_addrspacecast(v2, 
LLVM.PointerType(eltype(value_type(v)), 10)) - return v2, offset, hasload - end - end - end - if isa(v, LLVM.AddrSpaceCastInst) - if addrspace(value_type(operands(v)[1])) == 0 - v2 = addrspacecast!(b, operands(v)[1], LLVM.PointerType(eltype(value_type(v)), 10)) - return v2, offset, hasload - end - nv, noffset, nhasload = getparent(operands(v)[1], offset, hasload) - if eltype(value_type(nv)) != eltype(value_type(v)) - nv = bitcast!(b, nv, LLVM.PointerType(eltype(value_type(v)), addrspace(value_type(nv)))) - end - return nv, noffset, nhasload - end + if isa(v, LLVM.BitCastInst) + v2, offset, skipload = + getparent(operands(v)[1], offset, hasload) + v2 = bitcast!( + b, + v2, + LLVM.PointerType( + eltype(value_type(v)), + addrspace(value_type(v2)), + ), + ) + @assert eltype(value_type(v2)) == eltype(value_type(v)) + return v2, offset, skipload + end - if isa(v, LLVM.BitCastInst) - v2, offset, skipload = getparent(operands(v)[1], offset, hasload) - v2 = bitcast!(b, v2, LLVM.PointerType(eltype(value_type(v)), addrspace(value_type(v2)))) - @assert eltype(value_type(v2)) == eltype(value_type(v)) - return v2, offset, skipload - end + if isa(v, LLVM.GetElementPtrInst) && all( + x -> (isa(x, LLVM.ConstantInt) && convert(Int, x) == 0), + operands(v)[2:end], + ) + v2, offset, skipload = + getparent(operands(v)[1], offset, hasload) + v2 = bitcast!( + b, + v2, + LLVM.PointerType( + eltype(value_type(v)), + addrspace(value_type(v2)), + ), + ) + @assert eltype(value_type(v2)) == eltype(value_type(v)) + return v2, offset, skipload + end - if isa(v, LLVM.GetElementPtrInst) && all(x->(isa(x, LLVM.ConstantInt) && convert(Int, x) == 0), operands(v)[2:end]) - v2, offset, skipload = getparent(operands(v)[1], offset, hasload) - v2 = bitcast!(b, v2, LLVM.PointerType(eltype(value_type(v)), addrspace(value_type(v2)))) - @assert eltype(value_type(v2)) == eltype(value_type(v)) - return v2, offset, skipload - end + if isa(v, LLVM.GetElementPtrInst) && !hasload + v2, offset, skipload = + 
getparent(operands(v)[1], offset, hasload) + offset = nuwadd!( + b, + offset, + API.EnzymeComputeByteOffsetOfGEP(b, v, offty), + ) + v2 = bitcast!( + b, + v2, + LLVM.PointerType( + eltype(value_type(v)), + addrspace(value_type(v2)), + ), + ) + @assert eltype(value_type(v2)) == eltype(value_type(v)) + return v2, offset, skipload + end - if isa(v, LLVM.GetElementPtrInst) && !hasload - v2, offset, skipload = getparent(operands(v)[1], offset, hasload) - offset = nuwadd!(b, offset, API.EnzymeComputeByteOffsetOfGEP(b, v, offty)) - v2 = bitcast!(b, v2, LLVM.PointerType(eltype(value_type(v)), addrspace(value_type(v2)))) - @assert eltype(value_type(v2)) == eltype(value_type(v)) - return v2, offset, skipload - end + if isa(v, LLVM.ConstantExpr) && + opcode(v) == LLVM.API.LLVMGetElementPtr && + !hasload + v2, offset, skipload = + getparent(operands(v)[1], offset, hasload) + offset = nuwadd!( + b, + offset, + API.EnzymeComputeByteOffsetOfGEP(b, v, offty), + ) + v2 = bitcast!( + b, + v2, + LLVM.PointerType( + eltype(value_type(v)), + addrspace(value_type(v2)), + ), + ) + @assert eltype(value_type(v2)) == eltype(value_type(v)) + return v2, offset, skipload + end - if isa(v, LLVM.ConstantExpr) && opcode(v) == LLVM.API.LLVMGetElementPtr && !hasload - v2, offset, skipload = getparent(operands(v)[1], offset, hasload) - offset = nuwadd!(b, offset, API.EnzymeComputeByteOffsetOfGEP(b, v, offty)) - v2 = bitcast!(b, v2, LLVM.PointerType(eltype(value_type(v)), addrspace(value_type(v2)))) - @assert eltype(value_type(v2)) == eltype(value_type(v)) - return v2, offset, skipload - end + undeforpoison = isa(v, LLVM.UndefValue) + @static if LLVM.version() >= v"12" + undeforpoison |= isa(v, LLVM.PoisonValue) + end + if undeforpoison + return LLVM.UndefValue( + LLVM.PointerType(eltype(value_type(v)), 10), + ), + offset, + addr == 13 + end - undeforpoison = isa(v, LLVM.UndefValue) - @static if LLVM.version() >= v"12" - undeforpoison |= isa(v, LLVM.PoisonValue) - end - if undeforpoison - return 
LLVM.UndefValue(LLVM.PointerType(eltype(value_type(v)),10)), offset, addr == 13 - end + if isa(v, LLVM.PHIInst) && !hasload && haskey(goffsets, v) + offset = nuwadd!(b, offset, goffsets[v]) + nv = nextvs[v] + return nv, offset, addr == 13 + end - if isa(v, LLVM.PHIInst) && !hasload && haskey(goffsets, v) - offset = nuwadd!(b, offset, goffsets[v]) - nv = nextvs[v] - return nv, offset, addr == 13 - end + if isa(v, LLVM.SelectInst) + lhs_v, lhs_offset, lhs_skipload = + getparent(operands(v)[2], offset, hasload) + rhs_v, rhs_offset, rhs_skipload = + getparent(operands(v)[3], offset, hasload) + if value_type(lhs_v) != value_type(rhs_v) || + value_type(lhs_offset) != value_type(rhs_offset) || + lhs_skipload != rhs_skipload + msg = sprint() do io + println( + io, + "Could not analyze [select] garbage collection behavior of", + ) + println(io, " v0: ", string(v0)) + println(io, " v: ", string(v)) + println(io, " offset: ", string(offset)) + println(io, " hasload: ", string(hasload)) + println(io, " lhs_v", lhs_v) + println(io, " rhs_v", rhs_v) + println(io, " lhs_offset", lhs_offset) + println(io, " rhs_offset", rhs_offset) + println(io, " lhs_skipload", lhs_skipload) + println(io, " rhs_skipload", rhs_skipload) + end + bt = GPUCompiler.backtrace(inst) + throw(EnzymeInternalError(msg, string(f), bt)) + end + return select!(b, operands(v)[1], lhs_v, rhs_v), + select!(b, operands(v)[1], lhs_offset, rhs_offset), + lhs_skipload + end - if isa(v, LLVM.SelectInst) - lhs_v, lhs_offset, lhs_skipload = getparent(operands(v)[2], offset, hasload) - rhs_v, rhs_offset, rhs_skipload = getparent(operands(v)[3], offset, hasload) - if value_type(lhs_v) != value_type(rhs_v) || value_type(lhs_offset) != value_type(rhs_offset) || lhs_skipload != rhs_skipload msg = sprint() do io - println(io, "Could not analyze [select] garbage collection behavior of") + println(io, "Could not analyze garbage collection behavior of") + println(io, " inst: ", string(inst)) println(io, " v0: ", string(v0)) 
println(io, " v: ", string(v)) println(io, " offset: ", string(offset)) println(io, " hasload: ", string(hasload)) - println(io, " lhs_v", lhs_v) - println(io, " rhs_v", rhs_v) - println(io, " lhs_offset", lhs_offset) - println(io, " rhs_offset", rhs_offset) - println(io, " lhs_skipload", lhs_skipload) - println(io, " rhs_skipload", rhs_skipload) end bt = GPUCompiler.backtrace(inst) throw(EnzymeInternalError(msg, string(f), bt)) end - return select!(b, operands(v)[1], lhs_v, rhs_v), select!(b, operands(v)[1], lhs_offset, rhs_offset), lhs_skipload - end - msg = sprint() do io - println(io, "Could not analyze garbage collection behavior of") - println(io, " inst: ", string(inst)) - println(io, " v0: ", string(v0)) - println(io, " v: ", string(v)) - println(io, " offset: ", string(offset)) - println(io, " hasload: ", string(hasload)) - end - bt = GPUCompiler.backtrace(inst) - throw(EnzymeInternalError(msg, string(f), bt)) - end + v, offset, hadload = getparent(v, LLVM.ConstantInt(offty, 0), false) - v, offset, hadload = getparent(v, LLVM.ConstantInt(offty, 0), false) - - if addr == 13 - @assert hadload - end + if addr == 13 + @assert hadload + end - if eltype(value_type(v)) != el_ty - v = bitcast!(b, v, LLVM.PointerType(el_ty, addrspace(value_type(v)))) - end - push!(nvs, (v, pb)) - push!(offsets, (offset, pb)) - end + if eltype(value_type(v)) != el_ty + v = bitcast!( + b, + v, + LLVM.PointerType(el_ty, addrspace(value_type(v))), + ) + end + push!(nvs, (v, pb)) + push!(offsets, (offset, pb)) + end - nb = IRBuilder() - position!(nb, inst) - - offset = goffsets[inst] - append!(LLVM.incoming(offset), offsets) - if all(x->x[1]==offsets[1][1], offsets) - offset = offsets[1][1] - end + nb = IRBuilder() + position!(nb, inst) - nphi = nextvs[inst] - if !all(x->x[1]==nvs[1][1], nvs) - append!(LLVM.incoming(nphi), nvs) - else - replace_uses!(nphi, nvs[1][1]) - LLVM.API.LLVMInstructionEraseFromParent(nphi) - nphi = nvs[1][1] - end + offset = goffsets[inst] + 
append!(LLVM.incoming(offset), offsets) + if all(x -> x[1] == offsets[1][1], offsets) + offset = offsets[1][1] + end - position!(nb, nonphi) - if addr == 13 - nphi = bitcast!(nb, nphi, LLVM.PointerType(ty, 10)) - nphi = addrspacecast!(nb, nphi, LLVM.PointerType(ty, 11)) - nphi = load!(nb, ty, nphi) - else - nphi = addrspacecast!(nb, nphi, ty) - end - if !isa(offset, LLVM.ConstantInt) || convert(Int64, offset) != 0 - nphi = bitcast!(nb, nphi, LLVM.PointerType(i8, addrspace(ty))) - nphi = gep!(nb, i8, nphi, [offset]) - nphi = bitcast!(nb, nphi, ty) + nphi = nextvs[inst] + if !all(x -> x[1] == nvs[1][1], nvs) + append!(LLVM.incoming(nphi), nvs) + else + replace_uses!(nphi, nvs[1][1]) + LLVM.API.LLVMInstructionEraseFromParent(nphi) + nphi = nvs[1][1] + end + + position!(nb, nonphi) + if addr == 13 + nphi = bitcast!(nb, nphi, LLVM.PointerType(ty, 10)) + nphi = addrspacecast!(nb, nphi, LLVM.PointerType(ty, 11)) + nphi = load!(nb, ty, nphi) + else + nphi = addrspacecast!(nb, nphi, ty) + end + if !isa(offset, LLVM.ConstantInt) || convert(Int64, offset) != 0 + nphi = bitcast!(nb, nphi, LLVM.PointerType(i8, addrspace(ty))) + nphi = gep!(nb, i8, nphi, [offset]) + nphi = bitcast!(nb, nphi, ty) + end + replace_uses!(inst, nphi) + end + for inst in todo + LLVM.API.LLVMInstructionEraseFromParent(inst) + end end - replace_uses!(inst, nphi) - end - for inst in todo - LLVM.API.LLVMInstructionEraseFromParent(inst) end - end - end end return nothing end @@ -910,56 +1077,58 @@ function fix_decayaddr!(mod::LLVM.Module) temp = nothing for u in LLVM.uses(inst) st = LLVM.user(u) - # Storing _into_ the decay addr is okay - # we just cannot store the decayed addr into - # somewhere - if isa(st, LLVM.StoreInst) - if operands(st)[2] == inst - LLVM.API.LLVMSetOperand(st, 2-1, operands(inst)[1]) - continue - end - end - if isa(st, LLVM.LoadInst) - LLVM.API.LLVMSetOperand(st, 1-1, operands(inst)[1]) + # Storing _into_ the decay addr is okay + # we just cannot store the decayed addr into + # 
somewhere + if isa(st, LLVM.StoreInst) + if operands(st)[2] == inst + LLVM.API.LLVMSetOperand(st, 2 - 1, operands(inst)[1]) continue - end - # if isa(st, LLVM.InsertValueInst) - # if operands(st)[1] == inst - # push!(invalid, st) - # LLVM.API.LLVMSetOperand(st, 1-1, LLVM.UndefValue(value_type(inst))) - # continue - # end - # if operands(st)[2] == inst - # push!(invalid, st) - # LLVM.API.LLVMSetOperand(st, 2-1, LLVM.UndefValue(value_type(inst))) - # continue - # end - # end - if !isa(st, LLVM.CallInst) - bt = GPUCompiler.backtrace(st) - msg = sprint() do io::IO - println(io, string(f)) - println(io, inst) - println(io, st) - print(io, "Illegal decay of nonnull\n") - if bt !== nothing - print(io,"\nCaused by:") - Base.show_backtrace(io, bt) - println(io) - end - end - throw(AssertionError(msg)) - end - + end + end + if isa(st, LLVM.LoadInst) + LLVM.API.LLVMSetOperand(st, 1 - 1, operands(inst)[1]) + continue + end + # if isa(st, LLVM.InsertValueInst) + # if operands(st)[1] == inst + # push!(invalid, st) + # LLVM.API.LLVMSetOperand(st, 1-1, LLVM.UndefValue(value_type(inst))) + # continue + # end + # if operands(st)[2] == inst + # push!(invalid, st) + # LLVM.API.LLVMSetOperand(st, 2-1, LLVM.UndefValue(value_type(inst))) + # continue + # end + # end + if !isa(st, LLVM.CallInst) + bt = GPUCompiler.backtrace(st) + msg = sprint() do io::IO + println(io, string(f)) + println(io, inst) + println(io, st) + print(io, "Illegal decay of nonnull\n") + if bt !== nothing + print(io, "\nCaused by:") + Base.show_backtrace(io, bt) + println(io) + end + end + throw(AssertionError(msg)) + end + fop = operands(st)[end] - + intr = LLVM.API.LLVMGetIntrinsicID(fop) - if intr == LLVM.Intrinsic("llvm.memcpy").id || intr == LLVM.Intrinsic("llvm.memmove").id || intr == LLVM.Intrinsic("llvm.memset").id + if intr == LLVM.Intrinsic("llvm.memcpy").id || + intr == LLVM.Intrinsic("llvm.memmove").id || + intr == LLVM.Intrinsic("llvm.memset").id newvs = LLVM.Value[] for (i, v) in 
enumerate(operands(st)[1:end-1]) if v == inst - LLVM.API.LLVMSetOperand(st, i-1, operands(inst)[1]) + LLVM.API.LLVMSetOperand(st, i - 1, operands(inst)[1]) push!(newvs, operands(inst)[1]) continue end @@ -976,22 +1145,36 @@ function fix_decayaddr!(mod::LLVM.Module) newi = memset!(nb, newvs[1], newvs[2], newvs[3], 0) end - for idx = [LLVM.API.LLVMAttributeFunctionIndex, LLVM.API.LLVMAttributeReturnIndex, [LLVM.API.LLVMAttributeIndex(i) for i in 1:(length(operands(st))-1)]...] + for idx in [ + LLVM.API.LLVMAttributeFunctionIndex, + LLVM.API.LLVMAttributeReturnIndex, + [ + LLVM.API.LLVMAttributeIndex(i) for + i = 1:(length(operands(st))-1) + ]..., + ] idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) - count = LLVM.API.LLVMGetCallSiteAttributeCount(st, idx); - - Attrs = Base.unsafe_convert(Ptr{LLVM.API.LLVMAttributeRef}, Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef)*count)) + count = LLVM.API.LLVMGetCallSiteAttributeCount(st, idx) + + Attrs = Base.unsafe_convert( + Ptr{LLVM.API.LLVMAttributeRef}, + Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), + ) LLVM.API.LLVMGetCallSiteAttributes(st, idx, Attrs) - for j in 1:count - LLVM.API.LLVMAddCallSiteAttribute(newi, idx, unsafe_load(Attrs, j)) + for j = 1:count + LLVM.API.LLVMAddCallSiteAttribute( + newi, + idx, + unsafe_load(Attrs, j), + ) end Libc.free(Attrs) end - + API.EnzymeCopyMetadata(newi, st) - + LLVM.API.LLVMInstructionEraseFromParent(st) - continue + continue end mayread = false maywrite = false @@ -1051,7 +1234,7 @@ function fix_decayaddr!(mod::LLVM.Module) end throw(AssertionError(msg)) end - + elt = eltype(value_type(inst)) if temp === nothing nb = IRBuilder() @@ -1089,12 +1272,13 @@ function pre_attr!(mod::LLVM.Module) if isempty(blocks(fn)) continue end - if linkage(fn) != LLVM.API.LLVMInternalLinkage && linkage(fn) != LLVM.API.LLVMPrivateLinkage + if linkage(fn) != LLVM.API.LLVMInternalLinkage && + linkage(fn) != LLVM.API.LLVMPrivateLinkage continue end - + fty = LLVM.FunctionType(fn) - nfn = 
LLVM.Function(mod, "enzyme_attr_prev_"*LLVM.name(enzymefn), fty) + nfn = LLVM.Function(mod, "enzyme_attr_prev_" * LLVM.name(enzymefn), fty) LLVM.IRBuilder() do builder entry = BasicBlock(nfn, "entry") position!(builder, entry) @@ -1111,74 +1295,89 @@ function pre_attr!(mod::LLVM.Module) end function jl_inst_simplify!(PM) - ccall((:LLVMAddJLInstSimplifyPass, API.libEnzyme), Cvoid, (LLVM.API.LLVMPassManagerRef,), PM) + ccall( + (:LLVMAddJLInstSimplifyPass, API.libEnzyme), + Cvoid, + (LLVM.API.LLVMPassManagerRef,), + PM, + ) end -function post_attr!(mod::LLVM.Module) -end +function post_attr!(mod::LLVM.Module) end function prop_global!(g) newfns = String[] changed = false - todo = Tuple{Vector{Cuint},LLVM.Value}[] - for u in LLVM.uses(g) - u = LLVM.user(u) - push!(todo, (Cuint[],u)) - end - while length(todo) > 0 - path, var = pop!(todo) - if isa(var, LLVM.LoadInst) - B = IRBuilder() - position!(B, var) - res = LLVM.initializer(g) - for p in path - res = extract_value!(B, res, p) - end - changed = true - for u in LLVM.uses(var) - u = LLVM.user(u) - if isa(u, LLVM.CallInst) - f2 = LLVM.called_operand(u) - if isa(f2, LLVM.Function) - push!(newfns, LLVM.name(f2)) - end - end - end - replace_uses!(var, res) - eraseInst(LLVM.parent(var), var) - continue + todo = Tuple{Vector{Cuint},LLVM.Value}[] + for u in LLVM.uses(g) + u = LLVM.user(u) + push!(todo, (Cuint[], u)) + end + while length(todo) > 0 + path, var = pop!(todo) + if isa(var, LLVM.LoadInst) + B = IRBuilder() + position!(B, var) + res = LLVM.initializer(g) + for p in path + res = extract_value!(B, res, p) end - if isa(var, LLVM.AddrSpaceCastInst) - for u in LLVM.uses(var) - u = LLVM.user(u) - push!(todo, (path, u)) + changed = true + for u in LLVM.uses(var) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) + f2 = LLVM.called_operand(u) + if isa(f2, LLVM.Function) + push!(newfns, LLVM.name(f2)) + end end - continue end - if isa(var, LLVM.ConstantExpr) && opcode(var) == LLVM.API.LLVMAddrSpaceCast - for u in LLVM.uses(var) 
- u = LLVM.user(u) - push!(todo, (path, u)) - end - continue + replace_uses!(var, res) + eraseInst(LLVM.parent(var), var) + continue + end + if isa(var, LLVM.AddrSpaceCastInst) + for u in LLVM.uses(var) + u = LLVM.user(u) + push!(todo, (path, u)) end - if isa(var, LLVM.GetElementPtrInst) - if all(isa(v, LLVM.ConstantInt) for v in operands(var)[2:end]) - if convert(Cuint, operands(var)[2]) == 0 - for u in LLVM.uses(var) - u = LLVM.user(u) - push!(todo, (vcat(path,collect((convert(Cuint, v) for v in operands(var)[3:end]))), u)) - end + continue + end + if isa(var, LLVM.ConstantExpr) && opcode(var) == LLVM.API.LLVMAddrSpaceCast + for u in LLVM.uses(var) + u = LLVM.user(u) + push!(todo, (path, u)) + end + continue + end + if isa(var, LLVM.GetElementPtrInst) + if all(isa(v, LLVM.ConstantInt) for v in operands(var)[2:end]) + if convert(Cuint, operands(var)[2]) == 0 + for u in LLVM.uses(var) + u = LLVM.user(u) + push!( + todo, + ( + vcat( + path, + collect(( + convert(Cuint, v) for v in operands(var)[3:end] + )), + ), + u, + ), + ) end - continue end + continue end end + end return changed, newfns end # From https://llvm.org/doxygen/IR_2Instruction_8cpp_source.html#l00959 -function mayWriteToMemory(inst::LLVM.Instruction; err_is_readonly=false)::Bool +function mayWriteToMemory(inst::LLVM.Instruction; err_is_readonly = false)::Bool # we will ignore fense here if isa(inst, LLVM.StoreInst) return true @@ -1200,11 +1399,14 @@ function mayWriteToMemory(inst::LLVM.Instruction; err_is_readonly=false)::Bool end if isa(inst, LLVM.CallInst) || isa(inst, LLVM.InvokeInst) || isa(inst, LLVM.CallBrInst) idx = reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex) - count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx); - - Attrs = Base.unsafe_convert(Ptr{LLVM.API.LLVMAttributeRef}, Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef)*count)) + count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx) + + Attrs = Base.unsafe_convert( + Ptr{LLVM.API.LLVMAttributeRef}, 
+ Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), + ) LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) - for j in 1:count + for j = 1:count attr = LLVM.Attribute(unsafe_load(Attrs, j)) if kind(attr) == kind(EnumAttribute("readnone")) return false @@ -1298,14 +1500,14 @@ function remove_readonly_unused_calls!(fn::LLVM.Function, next::Set{String}) end end end - + changed = set_readonly!(fn) if length(calls) == 0 || hasUser return changed end - for c in calls + for c in calls parentf = LLVM.parent(LLVM.parent(c)) push!(next, LLVM.name(parentf)) LLVM.API.LLVMInstructionEraseFromParent(c) @@ -1317,7 +1519,8 @@ end function propagate_returned!(mod::LLVM.Module) globs = LLVM.GlobalVariable[] for g in globals(mod) - if linkage(g) == LLVM.API.LLVMInternalLinkage || linkage(g) == LLVM.API.LLVMPrivateLinkage + if linkage(g) == LLVM.API.LLVMInternalLinkage || + linkage(g) == LLVM.API.LLVMPrivateLinkage if !isconstant(g) continue end @@ -1344,19 +1547,33 @@ function propagate_returned!(mod::LLVM.Module) changed = true end attrs = collect(function_attributes(fn)) - prevent = any(kind(attr) == kind(StringAttribute("enzyme_preserve_primal")) for attr in attrs) + prevent = any( + kind(attr) == kind(StringAttribute("enzyme_preserve_primal")) for + attr in attrs + ) # if any(kind(attr) == kind(EnumAttribute("noinline")) for attr in attrs) # continue # end argn = nothing toremove = Int64[] for (i, arg) in enumerate(parameters(fn)) - if any(kind(attr) == kind(EnumAttribute("returned")) for attr in collect(parameter_attributes(fn, i))) + if any( + kind(attr) == kind(EnumAttribute("returned")) for + attr in collect(parameter_attributes(fn, i)) + ) argn = i end # remove unused sret-like - if !prevent && (linkage(fn) == LLVM.API.LLVMInternalLinkage || linkage(fn) == LLVM.API.LLVMPrivateLinkage) && any(kind(attr) == kind(EnumAttribute("nocapture")) for attr in collect(parameter_attributes(fn, i))) + if !prevent && + ( + linkage(fn) == LLVM.API.LLVMInternalLinkage || + 
linkage(fn) == LLVM.API.LLVMPrivateLinkage + ) && + any( + kind(attr) == kind(EnumAttribute("nocapture")) for + attr in collect(parameter_attributes(fn, i)) + ) val = nothing illegalUse = false torem = LLVM.Instruction[] @@ -1454,7 +1671,10 @@ function propagate_returned!(mod::LLVM.Module) end # interprocedural const prop from callers of arg - if !prevent && (linkage(fn) == LLVM.API.LLVMInternalLinkage || linkage(fn) == LLVM.API.LLVMPrivateLinkage) + if !prevent && ( + linkage(fn) == LLVM.API.LLVMInternalLinkage || + linkage(fn) == LLVM.API.LLVMPrivateLinkage + ) val = nothing illegalUse = false for u in LLVM.uses(fn) @@ -1479,9 +1699,9 @@ function propagate_returned!(mod::LLVM.Module) continue end @static if LLVM.version() >= v"12" - if isa(ops[i], LLVM.PoisonValue) - continue - end + if isa(ops[i], LLVM.PoisonValue) + continue + end end if ops[i] == arg continue @@ -1546,10 +1766,13 @@ function propagate_returned!(mod::LLVM.Module) end end if !baduse - push!(toremove, i-1) + push!(toremove, i - 1) end end - illegalUse = !(linkage(fn) == LLVM.API.LLVMInternalLinkage || linkage(fn) == LLVM.API.LLVMPrivateLinkage) + illegalUse = !( + linkage(fn) == LLVM.API.LLVMInternalLinkage || + linkage(fn) == LLVM.API.LLVMPrivateLinkage + ) hasAnyUse = false for u in LLVM.uses(fn) un = LLVM.user(u) @@ -1588,7 +1811,9 @@ function propagate_returned!(mod::LLVM.Module) end end #if the function return has no users whatsoever, remove it - if argn === nothing && !hasAnyUse && LLVM.return_type(LLVM.function_type(fn)) != LLVM.VoidType() + if argn === nothing && + !hasAnyUse && + LLVM.return_type(LLVM.function_type(fn)) != LLVM.VoidType() argn = -1 end if argn === nothing && length(toremove) == 0 @@ -1605,7 +1830,9 @@ function propagate_returned!(mod::LLVM.Module) un = LLVM.user(u) push!(next, LLVM.name(LLVM.parent(LLVM.parent(un)))) end - nfn = LLVM.Function(API.EnzymeCloneFunctionWithoutReturnOrArgs(fn, keepret, toremove)) + nfn = LLVM.Function( + 
API.EnzymeCloneFunctionWithoutReturnOrArgs(fn, keepret, toremove), + ) for u in LLVM.uses(fn) un = LLVM.user(u) push!(todo, un) @@ -1620,7 +1847,7 @@ function propagate_returned!(mod::LLVM.Module) eraseInst(mod, fn) changed = true catch - break + break end end if !changed @@ -1629,7 +1856,8 @@ function propagate_returned!(mod::LLVM.Module) todo = LLVM.Function[] for name in next fn = functions(mod)[name] - if linkage(fn) == LLVM.API.LLVMInternalLinkage || linkage(fn) == LLVM.API.LLVMPrivateLinkage + if linkage(fn) == LLVM.API.LLVMInternalLinkage || + linkage(fn) == LLVM.API.LLVMPrivateLinkage has_user = false for u in LLVM.uses(fn) has_user = true @@ -1651,11 +1879,11 @@ function detect_writeonly!(mod::LLVM.Module) end for (i, a) in enumerate(parameters(f)) if isa(value_type(a), LLVM.PointerType) - todo = Tuple{LLVM.Value, LLVM.Instruction}[] + todo = Tuple{LLVM.Value,LLVM.Instruction}[] for u in LLVM.uses(a) push!(todo, (a, LLVM.user(u))) end - seen = Set{Tuple{LLVM.Value, LLVM.Instruction}}() + seen = Set{Tuple{LLVM.Value,LLVM.Instruction}}() mayread = false maywrite = false while length(todo) > 0 @@ -1665,20 +1893,22 @@ function detect_writeonly!(mod::LLVM.Module) end push!(seen, cur) curv, curi = cur - + if isa(curi, LLVM.StoreInst) if operands(curi)[1] != curv maywrite = true continue end end - + if isa(curi, LLVM.LoadInst) mayread = true continue end - if isa(curi, LLVM.GetElementPtrInst) || isa(curi, LLVM.BitCastInst) || isa(curi, LLVM.AddrSpaceCastInst) + if isa(curi, LLVM.GetElementPtrInst) || + isa(curi, LLVM.BitCastInst) || + isa(curi, LLVM.AddrSpaceCastInst) for u in LLVM.uses(curi) push!(todo, (curi, LLVM.user(u))) end @@ -1687,20 +1917,47 @@ function detect_writeonly!(mod::LLVM.Module) mayread = true maywrite = true end - if any(map(k->kind(k)==kind(EnumAttribute("readnone")), collect(parameter_attributes(f, i)))) + if any( + map( + k -> kind(k) == kind(EnumAttribute("readnone")), + collect(parameter_attributes(f, i)), + ), + ) mayread = false 
maywrite = false end - if any(map(k->kind(k)==kind(EnumAttribute("readonly")), collect(parameter_attributes(f, i)))) + if any( + map( + k -> kind(k) == kind(EnumAttribute("readonly")), + collect(parameter_attributes(f, i)), + ), + ) maywrite = false end - if any(map(k->kind(k)==kind(EnumAttribute("writeonly")), collect(parameter_attributes(f, i)))) + if any( + map( + k -> kind(k) == kind(EnumAttribute("writeonly")), + collect(parameter_attributes(f, i)), + ), + ) mayread = false end - - LLVM.API.LLVMRemoveEnumAttributeAtIndex(f, LLVM.API.LLVMAttributeIndex(i), kind(EnumAttribute("readnone"))) - LLVM.API.LLVMRemoveEnumAttributeAtIndex(f, LLVM.API.LLVMAttributeIndex(i), kind(EnumAttribute("readonly"))) - LLVM.API.LLVMRemoveEnumAttributeAtIndex(f, LLVM.API.LLVMAttributeIndex(i), kind(EnumAttribute("writeonly"))) + + LLVM.API.LLVMRemoveEnumAttributeAtIndex( + f, + LLVM.API.LLVMAttributeIndex(i), + kind(EnumAttribute("readnone")), + ) + LLVM.API.LLVMRemoveEnumAttributeAtIndex( + f, + LLVM.API.LLVMAttributeIndex(i), + kind(EnumAttribute("readonly")), + ) + LLVM.API.LLVMRemoveEnumAttributeAtIndex( + f, + LLVM.API.LLVMAttributeIndex(i), + kind(EnumAttribute("writeonly")), + ) if !mayread && !maywrite push!(parameter_attributes(f, i), LLVM.EnumAttribute("readnone", 0)) @@ -1752,17 +2009,26 @@ function validate_return_roots!(mod) if length(enzyme_srets) >= 1 && length(srets) == 0 @assert enzyme_srets[1] == 1 VT = LLVM.VoidType() - if length(enzyme_srets) == 1 && LLVM.return_type(LLVM.function_type(f)) == VT && length(enzyme_srets_v) == 0 + if length(enzyme_srets) == 1 && + LLVM.return_type(LLVM.function_type(f)) == VT && + length(enzyme_srets_v) == 0 # Upgrading to sret requires writeonly - if !any(kind(attr) == kind(EnumAttribute("writeonly")) for attr in collect(parameter_attributes(f, 1))) - msg = sprint() do io::IO - println(io, "Enzyme internal error (not writeonly sret)") - println(io, string(f)) - println(io, "collect(parameter_attributes(f, 1))=", 
collect(parameter_attributes(f, 1))) - end - throw(AssertionError(msg)) - end - + if !any( + kind(attr) == kind(EnumAttribute("writeonly")) for + attr in collect(parameter_attributes(f, 1)) + ) + msg = sprint() do io::IO + println(io, "Enzyme internal error (not writeonly sret)") + println(io, string(f)) + println( + io, + "collect(parameter_attributes(f, 1))=", + collect(parameter_attributes(f, 1)), + ) + end + throw(AssertionError(msg)) + end + alty = nothing for u in LLVM.uses(f) u = LLVM.user(u) @@ -1770,13 +2036,13 @@ function validate_return_roots!(mod) @assert LLVM.called_operand(u) == f alop = operands(u)[1] if !isa(alop, LLVM.AllocaInst) - msg = sprint() do io::IO - println(io, "Enzyme internal error (!isa(alop, LLVM.AllocaInst))") - println(io, "alop=", alop) - println(io, "u=", u) - println(io, "f=", string(f)) - end - throw(AssertionError(msg)) + msg = sprint() do io::IO + println(io, "Enzyme internal error (!isa(alop, LLVM.AllocaInst))") + println(io, "alop=", alop) + println(io, "u=", u) + println(io, "f=", string(f)) + end + throw(AssertionError(msg)) end @assert isa(alop, LLVM.AllocaInst) @@ -1791,8 +2057,17 @@ function validate_return_roots!(mod) else EnumAttribute("sret") end - LLVM.API.LLVMAddCallSiteAttribute(u, LLVM.API.LLVMAttributeIndex(1), attr) - LLVM.API.LLVMRemoveCallSiteStringAttribute(u, LLVM.API.LLVMAttributeIndex(1), "enzyme_sret", length("enzyme_sret")) + LLVM.API.LLVMAddCallSiteAttribute( + u, + LLVM.API.LLVMAttributeIndex(1), + attr, + ) + LLVM.API.LLVMRemoveCallSiteStringAttribute( + u, + LLVM.API.LLVMAttributeIndex(1), + "enzyme_sret", + length("enzyme_sret"), + ) end @assert alty !== nothing attr = if LLVM.version().major >= 12 @@ -1806,7 +2081,7 @@ function validate_return_roots!(mod) srets = [(1, attr)] enzyme_srets = Int[] else - + enzyme_srets2 = Int[] for idx in enzyme_srets alty = nothing @@ -1821,10 +2096,18 @@ function validate_return_roots!(mod) if any_jltypes(nty) bad = true end - 
LLVM.API.LLVMRemoveCallSiteStringAttribute(u, LLVM.API.LLVMAttributeIndex(idx), "enzyme_sret", length("enzyme_sret")) + LLVM.API.LLVMRemoveCallSiteStringAttribute( + u, + LLVM.API.LLVMAttributeIndex(idx), + "enzyme_sret", + length("enzyme_sret"), + ) end if !bad - delete!(parameter_attributes(f, idx), StringAttribute("enzyme_sret")) + delete!( + parameter_attributes(f, idx), + StringAttribute("enzyme_sret"), + ) else push!(enzyme_srets2, idx) end @@ -1832,16 +2115,16 @@ function validate_return_roots!(mod) enzyme_srets = enzyme_srets2 if length(enzyme_srets) != 0 - msg = sprint() do io::IO - println(io, "Enzyme internal error (length(enzyme_srets) != 0)") - println(io, "f=", string(f)) - println(io, "enzyme_srets=", enzyme_srets) - println(io, "enzyme_srets_v=", enzyme_srets_v) - println(io, "srets=", srets) - println(io, "rroots=", rroots) - println(io, "rroots_v=", rroots_v) - end - throw(AssertionError(msg)) + msg = sprint() do io::IO + println(io, "Enzyme internal error (length(enzyme_srets) != 0)") + println(io, "f=", string(f)) + println(io, "enzyme_srets=", enzyme_srets) + println(io, "enzyme_srets_v=", enzyme_srets_v) + println(io, "srets=", srets) + println(io, "rroots=", rroots) + println(io, "rroots_v=", rroots_v) + end + throw(AssertionError(msg)) end end end @@ -1860,7 +2143,7 @@ function validate_return_roots!(mod) end end -function checkNoAssumeFalse(mod, shouldshow=false) +function checkNoAssumeFalse(mod, shouldshow = false) for f in functions(mod) for bb in blocks(f), inst in instructions(bb) if !isa(inst, LLVM.CallInst) @@ -1885,7 +2168,8 @@ function checkNoAssumeFalse(mod, shouldshow=false) end end if isa(op, LLVM.ICmpInst) - if predicate_int(op) == LLVM.API.LLVMIntNE && operands(op)[1] == operands(op)[2] + if predicate_int(op) == LLVM.API.LLVMIntNE && + operands(op)[1] == operands(op)[2] msg = sprint() do io println(io, "Enzyme Internal Error: non-icmp assume condition") println(io, "mod=", string(mod)) @@ -1913,17 +2197,55 @@ function 
removeDeadArgs!(mod::LLVM.Module, tm) LLVM.run!(pm, mod) end # Prevent dead-arg-elimination of functions which we may require args for in the derivative - funcT = LLVM.FunctionType(LLVM.VoidType(), LLVMType[], vararg=true) + funcT = LLVM.FunctionType(LLVM.VoidType(), LLVMType[], vararg = true) if LLVM.version().major <= 15 - func, _ = get_function!(mod, "llvm.enzymefakeuse", funcT, [EnumAttribute("readnone"), EnumAttribute("nofree")]) - rfunc, _ = get_function!(mod, "llvm.enzymefakeread", funcT, [EnumAttribute("readonly"), EnumAttribute("nofree"), EnumAttribute("argmemonly")]) - sfunc, _ = get_function!(mod, "llvm.enzyme.sret_use", funcT, [EnumAttribute("readonly"), EnumAttribute("nofree"), EnumAttribute("argmemonly")]) + func, _ = get_function!( + mod, + "llvm.enzymefakeuse", + funcT, + [EnumAttribute("readnone"), EnumAttribute("nofree")], + ) + rfunc, _ = get_function!( + mod, + "llvm.enzymefakeread", + funcT, + [ + EnumAttribute("readonly"), + EnumAttribute("nofree"), + EnumAttribute("argmemonly"), + ], + ) + sfunc, _ = get_function!( + mod, + "llvm.enzyme.sret_use", + funcT, + [ + EnumAttribute("readonly"), + EnumAttribute("nofree"), + EnumAttribute("argmemonly"), + ], + ) else - func, _ = get_function!(mod, "llvm.enzymefakeuse", funcT, [EnumAttribute("memory", NoEffects.data), EnumAttribute("nofree")]) - rfunc, _ = get_function!(mod, "llvm.enzymefakeread", funcT, [EnumAttribute("memory", ReadOnlyArgMemEffects.data), EnumAttribute("nofree")]) - sfunc, _ = get_function!(mod, "llvm.enzyme.sret_use", funcT, [EnumAttribute("memory", ReadOnlyArgMemEffects.data), EnumAttribute("nofree")]) + func, _ = get_function!( + mod, + "llvm.enzymefakeuse", + funcT, + [EnumAttribute("memory", NoEffects.data), EnumAttribute("nofree")], + ) + rfunc, _ = get_function!( + mod, + "llvm.enzymefakeread", + funcT, + [EnumAttribute("memory", ReadOnlyArgMemEffects.data), EnumAttribute("nofree")], + ) + sfunc, _ = get_function!( + mod, + "llvm.enzyme.sret_use", + funcT, + 
[EnumAttribute("memory", ReadOnlyArgMemEffects.data), EnumAttribute("nofree")], + ) end - + for fn in functions(mod) if isempty(blocks(fn)) continue @@ -1933,7 +2255,12 @@ function removeDeadArgs!(mod::LLVM.Module, tm) # active both can occur on 4. If the original sret is removed (at index 1) we no longer need # to preserve this. for idx in (2, 3, 4) - if length(collect(parameters(fn))) >= idx && any( ( kind(attr) == kind(StringAttribute("enzymejl_returnRoots")) || kind(attr) == kind(StringAttribute("enzymejl_returnRoots_v"))) for attr in collect(parameter_attributes(fn, idx))) + if length(collect(parameters(fn))) >= idx && any( + ( + kind(attr) == kind(StringAttribute("enzymejl_returnRoots")) || + kind(attr) == kind(StringAttribute("enzymejl_returnRoots_v")) + ) for attr in collect(parameter_attributes(fn, idx)) + ) for u in LLVM.uses(fn) u = LLVM.user(u) @assert isa(u, LLVM.CallInst) @@ -1944,7 +2271,9 @@ function removeDeadArgs!(mod::LLVM.Module, tm) cl = call!(B, funcT, rfunc, LLVM.Value[inp]) if isa(value_type(inp), LLVM.PointerType) LLVM.API.LLVMAddCallSiteAttribute( - cl, LLVM.API.LLVMAttributeIndex(1), EnumAttribute("nocapture") + cl, + LLVM.API.LLVMAttributeIndex(1), + EnumAttribute("nocapture"), ) end end @@ -1960,7 +2289,13 @@ function removeDeadArgs!(mod::LLVM.Module, tm) continue end attrs = collect(parameter_attributes(fn, idx)) - if any( ( kind(attr) == sretkind || kind(attr) == kind(StringAttribute("enzyme_sret")) || kind(attr) == kind(StringAttribute("enzyme_sret_v")) ) for attr in attrs) + if any( + ( + kind(attr) == sretkind || + kind(attr) == kind(StringAttribute("enzyme_sret")) || + kind(attr) == kind(StringAttribute("enzyme_sret_v")) + ) for attr in attrs + ) for u in LLVM.uses(fn) u = LLVM.user(u) if isa(u, LLVM.ConstantExpr) @@ -1977,14 +2312,18 @@ function removeDeadArgs!(mod::LLVM.Module, tm) cl = call!(B, funcT, sfunc, LLVM.Value[inp]) if isa(value_type(inp), LLVM.PointerType) LLVM.API.LLVMAddCallSiteAttribute( - cl, 
LLVM.API.LLVMAttributeIndex(1), EnumAttribute("nocapture") + cl, + LLVM.API.LLVMAttributeIndex(1), + EnumAttribute("nocapture"), ) end end end end attrs = collect(function_attributes(fn)) - prevent = any(kind(attr) == kind(StringAttribute("enzyme_preserve_primal")) for attr in attrs) + prevent = any( + kind(attr) == kind(StringAttribute("enzyme_preserve_primal")) for attr in attrs + ) # && any(kind(attr) == kind(StringAttribute("enzyme_math")) for attr in attrs) if prevent B = IRBuilder() @@ -2009,7 +2348,7 @@ function removeDeadArgs!(mod::LLVM.Module, tm) API.EnzymeAddAttributorLegacyPass(pm) LLVM.run!(pm, mod) end - end + end end propagate_returned!(mod) ModulePassManager() do pm @@ -2088,7 +2427,7 @@ function optimize!(mod::LLVM.Module, tm) alloc_opt_tm!(pm, tm) loop_idiom!(pm) loop_rotate!(pm) - + loop_optimizations_tm!(pm, tm) instruction_combining!(pm) @@ -2099,7 +2438,7 @@ function optimize!(mod::LLVM.Module, tm) alloc_opt_tm!(pm, tm) scalar_repl_aggregates_ssa!(pm) # SSA variant? gvn!(pm) - + # This InstCombine needs to be after GVN # Otherwise it will generate load chains in GPU code... instruction_combining!(pm) @@ -2117,7 +2456,7 @@ function optimize!(mod::LLVM.Module, tm) jump_threading!(pm) correlated_value_propagation!(pm) # SLP_Vectorizer -- not for Enzyme - + LLVM.run!(pm, mod) aggressive_dce!(pm) @@ -2243,7 +2582,7 @@ function addMachinePasses!(pm, tm) gvn!(pm) end -function addJuliaLegalizationPasses!(pm, tm, lower_intrinsics=true) +function addJuliaLegalizationPasses!(pm, tm, lower_intrinsics = true) if lower_intrinsics # LowerPTLS removes an indirect call. 
As a result, it is likely to trigger # LLVM's devirtualization heuristics, which would result in the entire @@ -2267,7 +2606,7 @@ function addJuliaLegalizationPasses!(pm, tm, lower_intrinsics=true) sccp!(pm) # Remove dead use of ptls dce!(pm) - lower_ptls_tm!(pm, tm, #=dump_native=# false) + lower_ptls_tm!(pm, tm, false) #=dump_native=# instruction_combining!(pm) jl_inst_simplify!(pm) # Clean up write barrier and ptls lowering @@ -2278,7 +2617,7 @@ function addJuliaLegalizationPasses!(pm, tm, lower_intrinsics=true) end end -function post_optimze!(mod, tm, machine=true) +function post_optimze!(mod, tm, machine = true) addr13NoAlias(mod) removeDeadArgs!(mod, tm) for f in collect(functions(mod)) @@ -2289,7 +2628,14 @@ function post_optimze!(mod, tm, machine=true) end out_error = Ref{Cstring}() if LLVM.API.LLVMVerifyModule(mod, LLVM.API.LLVMReturnStatusAction, out_error) != 0 - throw(LLVM.LLVMException("broken gc calling conv fix\n"*string(unsafe_string(out_error[]))*"\n"*string(mod))) + throw( + LLVM.LLVMException( + "broken gc calling conv fix\n" * + string(unsafe_string(out_error[])) * + "\n" * + string(mod), + ), + ) end LLVM.ModulePassManager() do pm addTargetPasses!(pm, tm, LLVM.triple(mod)) diff --git a/src/compiler/orcv2.jl b/src/compiler/orcv2.jl index 482a961b52d..4b8f2d202a3 100644 --- a/src/compiler/orcv2.jl +++ b/src/compiler/orcv2.jl @@ -3,7 +3,7 @@ module JIT using LLVM using Libdl -import LLVM:TargetMachine +import LLVM: TargetMachine import GPUCompiler import ..Compiler @@ -13,8 +13,8 @@ export get_trampoline struct CompilerInstance jit::LLVM.JuliaOJIT - lctm::Union{LLVM.LazyCallThroughManager, Nothing} - ism::Union{LLVM.IndirectStubsManager, Nothing} + lctm::Union{LLVM.LazyCallThroughManager,Nothing} + ism::Union{LLVM.IndirectStubsManager,Nothing} end function LLVM.dispose(ci::CompilerInstance) @@ -35,24 +35,24 @@ get_tm() = tm[] get_jit() = jit[].jit function absolute_symbol_materialization(name, ptr) - address = 
LLVM.API.LLVMOrcJITTargetAddress(reinterpret(UInt, ptr)) - flags = LLVM.API.LLVMJITSymbolFlags(LLVM.API.LLVMJITSymbolGenericFlagsExported, 0) - symbol = LLVM.API.LLVMJITEvaluatedSymbol(address, flags) - gv = if LLVM.version() >= v"15" - LLVM.API.LLVMOrcCSymbolMapPair(name, symbol) - else - LLVM.API.LLVMJITCSymbolMapPair(name, symbol) - end - return LLVM.absolute_symbols(Ref(gv)) + address = LLVM.API.LLVMOrcJITTargetAddress(reinterpret(UInt, ptr)) + flags = LLVM.API.LLVMJITSymbolFlags(LLVM.API.LLVMJITSymbolGenericFlagsExported, 0) + symbol = LLVM.API.LLVMJITEvaluatedSymbol(address, flags) + gv = if LLVM.version() >= v"15" + LLVM.API.LLVMOrcCSymbolMapPair(name, symbol) + else + LLVM.API.LLVMJITCSymbolMapPair(name, symbol) + end + return LLVM.absolute_symbols(Ref(gv)) end function define_absolute_symbol(jd, name) - ptr = LLVM.find_symbol(name) - if ptr !== C_NULL - LLVM.define(jd, absolute_symbol_materialization(name, ptr)) - return true - end - return false + ptr = LLVM.find_symbol(name) + if ptr !== C_NULL + LLVM.define(jd, absolute_symbol_materialization(name, ptr)) + return true + end + return false end function __init__() @@ -68,7 +68,7 @@ function __init__() tempTM = LLVM.JITTargetMachine(LLVM.triple(), cpu_name(), cpu_features(); optlevel) LLVM.asm_verbosity!(tempTM, true) tm[] = tempTM - + lljit = JuliaOJIT() jd_main = JITDylib(lljit) @@ -77,10 +77,10 @@ function __init__() dg = LLVM.CreateDynamicLibrarySearchGeneratorForProcess(prefix) LLVM.add!(jd_main, dg) - if Sys.iswindows() && Int === Int64 - # TODO can we check isGNU? - define_absolute_symbol(jd_main, mangle(lljit, "___chkstk_ms")) - end + if Sys.iswindows() && Int === Int64 + # TODO can we check isGNU? 
+ define_absolute_symbol(jd_main, mangle(lljit, "___chkstk_ms")) + end es = ExecutionSession(lljit) try @@ -95,12 +95,18 @@ function __init__() hnd = unsafe_load(cglobal(:jl_libjulia_handle, Ptr{Cvoid})) for (k, v) in Compiler.JuliaGlobalNameMap ptr = unsafe_load(Base.reinterpret(Ptr{Ptr{Cvoid}}, Libdl.dlsym(hnd, k))) - LLVM.define(jd_main, absolute_symbol_materialization(mangle(lljit, "ejl_"*k), ptr)) + LLVM.define( + jd_main, + absolute_symbol_materialization(mangle(lljit, "ejl_" * k), ptr), + ) end for (k, v) in Compiler.JuliaEnzymeNameMap ptr = Compiler.unsafe_to_ptr(v) - LLVM.define(jd_main, absolute_symbol_materialization(mangle(lljit, "ejl_"*k), ptr)) + LLVM.define( + jd_main, + absolute_symbol_materialization(mangle(lljit, "ejl_" * k), ptr), + ) end atexit() do @@ -123,13 +129,15 @@ end function add_trampoline!(jd, (lljit, lctm, ism), entry, target) flags = LLVM.API.LLVMJITSymbolFlags( - LLVM.API.LLVMJITSymbolGenericFlagsCallable | - LLVM.API.LLVMJITSymbolGenericFlagsExported, 0) + LLVM.API.LLVMJITSymbolGenericFlagsCallable | + LLVM.API.LLVMJITSymbolGenericFlagsExported, + 0, + ) alias = LLVM.API.LLVMOrcCSymbolAliasMapPair( - mangle(lljit, entry), - LLVM.API.LLVMOrcCSymbolAliasMapEntry( - mangle(lljit, target), flags)) + mangle(lljit, entry), + LLVM.API.LLVMOrcCSymbolAliasMapEntry(mangle(lljit, target), flags), + ) mu = LLVM.reexports(lctm, ism, jd, [alias]) LLVM.define(jd, mu) @@ -140,8 +148,8 @@ end function get_trampoline(job) compiler = jit[] lljit = compiler.jit - lctm = compiler.lctm - ism = compiler.ism + lctm = compiler.lctm + ism = compiler.ism if lctm === nothing || ism === nothing error("Delayed compilation not available.") @@ -155,8 +163,7 @@ function get_trampoline(job) sym = String(gensym(:func)) _sym = String(gensym(:func)) - addr = add_trampoline!(jd, (lljit, lctm, ism), - _sym, sym) + addr = add_trampoline!(jd, (lljit, lctm, ism), _sym, sym) # 3. 
add MU that will call back into the compiler function materialize(mr) @@ -193,16 +200,14 @@ function get_trampoline(job) function discard(jd, sym) end flags = LLVM.API.LLVMJITSymbolFlags( - LLVM.API.LLVMJITSymbolGenericFlagsCallable | - LLVM.API.LLVMJITSymbolGenericFlagsExported, 0) + LLVM.API.LLVMJITSymbolGenericFlagsCallable | + LLVM.API.LLVMJITSymbolGenericFlagsExported, + 0, + ) - symbols = [ - LLVM.API.LLVMOrcCSymbolFlagsMapPair( - mangle(lljit, sym), flags), - ] + symbols = [LLVM.API.LLVMOrcCSymbolFlagsMapPair(mangle(lljit, sym), flags)] - mu = LLVM.CustomMaterializationUnit(sym, symbols, - materialize, discard) + mu = LLVM.CustomMaterializationUnit(sym, symbols, materialize, discard) LLVM.define(jd, mu) return addr end diff --git a/src/compiler/passes.jl b/src/compiler/passes.jl index e7e3c3c0a7e..403b2bfa048 100644 --- a/src/compiler/passes.jl +++ b/src/compiler/passes.jl @@ -1,5 +1,5 @@ function reinsert_gcmarker_pass!(fn::LLVM.Function) - reinsert_gcmarker!(fn) + reinsert_gcmarker!(fn) unique_gcmarker!(fn) return true end diff --git a/src/compiler/reflection.jl b/src/compiler/reflection.jl index 583a6f2f68c..304372951c5 100644 --- a/src/compiler/reflection.jl +++ b/src/compiler/reflection.jl @@ -1,11 +1,27 @@ -function get_job(@nospecialize(func), @nospecialize(A), @nospecialize(types); - run_enzyme::Bool=true, mode::API.CDerivativeMode=API.DEM_ReverseModeCombined, dupClosure::Bool=false, argwrap::Bool=true, width::Int=1, modifiedBetween=nothing, returnPrimal::Bool=false, augmentedInit=false, world=nothing, ABI=DefaultABI, ErrIfFuncWritten=false, RuntimeActivity=true, kwargs...) 
+function get_job( + @nospecialize(func), + @nospecialize(A), + @nospecialize(types); + run_enzyme::Bool = true, + mode::API.CDerivativeMode = API.DEM_ReverseModeCombined, + dupClosure::Bool = false, + argwrap::Bool = true, + width::Int = 1, + modifiedBetween = nothing, + returnPrimal::Bool = false, + augmentedInit = false, + world = nothing, + ABI = DefaultABI, + ErrIfFuncWritten = false, + RuntimeActivity = true, + kwargs..., +) - tt = Tuple{map(eltype, types.parameters)...} + tt = Tuple{map(eltype, types.parameters)...} if world === nothing world = codegen_world_age(Core.Typeof(func), tt) end - + primal = fspec(Core.Typeof(func), types, world) rt = Core.Compiler.return_type(func, tt, world) @@ -15,16 +31,40 @@ function get_job(@nospecialize(func), @nospecialize(A), @nospecialize(types); defaultMod = mode != API.DEM_ReverseModeCombined && mode != API.DEM_ForwardMode modifiedBetween = (defaultMod, (defaultMod for _ in types.parameters)...) end - params = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){Core.Typeof(func)}, types.parameters...}, mode, width, rt, run_enzyme, argwrap, modifiedBetween, returnPrimal, augmentedInit, Compiler.UnknownTapeType, ABI, ErrIfFuncWritten, RuntimeActivity) - return Compiler.CompilerJob(primal, CompilerConfig(target, params; kernel=false), world) + params = Compiler.EnzymeCompilerParams( + Tuple{(dupClosure ? Duplicated : Const){Core.Typeof(func)},types.parameters...}, + mode, + width, + rt, + run_enzyme, + argwrap, + modifiedBetween, + returnPrimal, + augmentedInit, + Compiler.UnknownTapeType, + ABI, + ErrIfFuncWritten, + RuntimeActivity, + ) + return Compiler.CompilerJob( + primal, + CompilerConfig(target, params; kernel = false), + world, + ) end -function reflect(@nospecialize(func), @nospecialize(A), @nospecialize(types); - optimize::Bool=true, second_stage::Bool=true, kwargs...) 
+function reflect( + @nospecialize(func), + @nospecialize(A), + @nospecialize(types); + optimize::Bool = true, + second_stage::Bool = true, + kwargs..., +) job = get_job(func, A, types; kwargs...) # Codegen the primal function and all its dependency in one module - mod, meta = Compiler.codegen(:llvm, job; optimize #= validate=false =#) + mod, meta = Compiler.codegen(:llvm, job; optimize) #= validate=false =# if second_stage post_optimze!(mod, JIT.get_tm()) @@ -40,33 +80,62 @@ struct jl_llvmf_dump F::LLVM.API.LLVMValueRef end -function enzyme_code_llvm(io::IO, @nospecialize(func), @nospecialize(A), @nospecialize(types); - optimize::Bool=true, run_enzyme::Bool=true, second_stage::Bool=true, - raw::Bool=false, debuginfo::Symbol=:default, dump_module::Bool=false, mode=API.DEM_ReverseModeCombined) +function enzyme_code_llvm( + io::IO, + @nospecialize(func), + @nospecialize(A), + @nospecialize(types); + optimize::Bool = true, + run_enzyme::Bool = true, + second_stage::Bool = true, + raw::Bool = false, + debuginfo::Symbol = :default, + dump_module::Bool = false, + mode = API.DEM_ReverseModeCombined, +) JuliaContext() do ctx entry_fn, ir = reflect(func, A, types; optimize, run_enzyme, second_stage, mode) ts_mod = ThreadSafeModule(ir) GC.@preserve ts_mod entry_fn begin value = Ref(jl_llvmf_dump(ts_mod.ref, entry_fn.ref)) - str = ccall(:jl_dump_function_ir, Ref{String}, - (Ptr{jl_llvmf_dump}, Bool, Bool, Ptr{UInt8}), - value, !raw, dump_module, debuginfo) + str = ccall( + :jl_dump_function_ir, + Ref{String}, + (Ptr{jl_llvmf_dump}, Bool, Bool, Ptr{UInt8}), + value, + !raw, + dump_module, + debuginfo, + ) end print(io, str) end end -enzyme_code_llvm(@nospecialize(func), @nospecialize(A), @nospecialize(types); kwargs...) = enzyme_code_llvm(stdout, func, A, types; kwargs...) +enzyme_code_llvm(@nospecialize(func), @nospecialize(A), @nospecialize(types); kwargs...) = + enzyme_code_llvm(stdout, func, A, types; kwargs...) 
-function enzyme_code_native(io::IO, @nospecialize(func), @nospecialize(A), @nospecialize(types); mode=API.DEM_ReverseModeCombined) +function enzyme_code_native( + io::IO, + @nospecialize(func), + @nospecialize(A), + @nospecialize(types); + mode = API.DEM_ReverseModeCombined, +) JuliaContext() do ctx _, mod = reflect(func, A, types; mode) str = String(LLVM.emit(JIT.get_tm(), mod, LLVM.API.LLVMAssemblyFile)) print(io, str) end end -enzyme_code_native(@nospecialize(func), @nospecialize(A), @nospecialize(types); kwargs...) = enzyme_code_native(stdout, func, A, types; kwargs...) +enzyme_code_native(@nospecialize(func), @nospecialize(A), @nospecialize(types); kwargs...) = + enzyme_code_native(stdout, func, A, types; kwargs...) -function enzyme_code_typed(@nospecialize(func), @nospecialize(A), @nospecialize(types); kwargs...) +function enzyme_code_typed( + @nospecialize(func), + @nospecialize(A), + @nospecialize(types); + kwargs..., +) job = get_job(func, A, types; kwargs...) GPUCompiler.code_typed(job; kwargs...) 
end diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index cde5d2cade4..8a801067eb3 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -2,16 +2,9 @@ struct MemoryEffect data::UInt32 end -@enum(ModRefInfo, - MRI_NoModRef = 0, - MRI_Ref = 1, - MRI_Mod = 2, - MRI_ModRef = 3) +@enum(ModRefInfo, MRI_NoModRef = 0, MRI_Ref = 1, MRI_Mod = 2, MRI_ModRef = 3) -@enum(IRMemLocation, - ArgMem = 0, - InaccessibleMem = 1, - Other = 2) +@enum(IRMemLocation, ArgMem = 0, InaccessibleMem = 1, Other = 2) const BitsPerLoc = UInt32(2) const LocMask = UInt32((1 << BitsPerLoc) - 1) @@ -27,14 +20,30 @@ end function Base.:&(lhs::ModRefInfo, rhs::ModRefInfo) ModRefInfo(UInt32(lhs) & UInt32(rhs)) end -const AllEffects = MemoryEffect((MRI_ModRef << getLocationPos(ArgMem)) | (MRI_ModRef << getLocationPos(InaccessibleMem)) | (MRI_ModRef << getLocationPos(Other))) -const ReadOnlyEffects = MemoryEffect((MRI_Ref << getLocationPos(ArgMem)) | (MRI_Ref << getLocationPos(InaccessibleMem)) | (MRI_Ref << getLocationPos(Other))) -const ReadOnlyArgMemEffects = MemoryEffect((MRI_Ref << getLocationPos(ArgMem)) | (MRI_NoModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))) -const NoEffects = MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_NoModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))) +const AllEffects = MemoryEffect( + (MRI_ModRef << getLocationPos(ArgMem)) | + (MRI_ModRef << getLocationPos(InaccessibleMem)) | + (MRI_ModRef << getLocationPos(Other)), +) +const ReadOnlyEffects = MemoryEffect( + (MRI_Ref << getLocationPos(ArgMem)) | + (MRI_Ref << getLocationPos(InaccessibleMem)) | + (MRI_Ref << getLocationPos(Other)), +) +const ReadOnlyArgMemEffects = MemoryEffect( + (MRI_Ref << getLocationPos(ArgMem)) | + (MRI_NoModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), +) +const NoEffects = MemoryEffect( + (MRI_NoModRef << getLocationPos(ArgMem)) | + (MRI_NoModRef << 
getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), +) # Get ModRefInfo for any location. function getModRef(effect::MemoryEffect, loc::IRMemLocation)::ModRefInfo - ModRefInfo((effect.data >> getLocationPos(loc)) & LocMask) + ModRefInfo((effect.data >> getLocationPos(loc)) & LocMask) end function getModRef(effect::MemoryEffect)::ModRefInfo @@ -54,7 +63,7 @@ end function setModRef(effect::MemoryEffect)::MemoryEffect for loc in (ArgMem, InaccessibleMem, Other) - effect = setModRef(effect, mri)= getModRef(effect, loc) + effect = setModRef(effect, mri) = getModRef(effect, loc) end return effect end @@ -93,12 +102,12 @@ function is_writeonly(mri::ModRefInfo) end for n in (:is_readonly, :is_readnone, :is_writeonly) -@eval begin - function $n(memeffect::MemoryEffect) - return $n(getModRef(memeffect)) + @eval begin + function $n(memeffect::MemoryEffect) + return $n(getModRef(memeffect)) + end end end -end function is_noreturn(f::LLVM.Function) for attr in collect(function_attributes(f)) @@ -120,7 +129,8 @@ function is_readonly(f::LLVM.Function) if intr == LLVM.Intrinsic("llvm.assume").id return true end - if LLVM.name(f) == "llvm.julia.gc_preserve_begin" || LLVM.name(f) == "llvm.julia.gc_preserve_end" + if LLVM.name(f) == "llvm.julia.gc_preserve_begin" || + LLVM.name(f) == "llvm.julia.gc_preserve_end" return true end for attr in collect(function_attributes(f)) @@ -131,12 +141,12 @@ function is_readonly(f::LLVM.Function) return true end if LLVM.version().major > 15 - if kind(attr) == kind(EnumAttribute("memory")) - if is_readonly(MemoryEffect(value(attr))) - return true + if kind(attr) == kind(EnumAttribute("memory")) + if is_readonly(MemoryEffect(value(attr))) + return true + end end end - end end return false end @@ -152,7 +162,8 @@ function is_readnone(f::LLVM.Function) if intr == LLVM.Intrinsic("llvm.assume").id return true end - if LLVM.name(f) == "llvm.julia.gc_preserve_begin" || LLVM.name(f) == "llvm.julia.gc_preserve_end" + if LLVM.name(f) 
== "llvm.julia.gc_preserve_begin" || + LLVM.name(f) == "llvm.julia.gc_preserve_end" return true end for attr in collect(function_attributes(cur)) @@ -160,12 +171,12 @@ function is_readnone(f::LLVM.Function) return true end if LLVM.version().major > 15 - if kind(attr) == kind(EnumAttribute("memory")) - if is_readnone(MemoryEffect(value(attr))) - return true + if kind(attr) == kind(EnumAttribute("memory")) + if is_readnone(MemoryEffect(value(attr))) + return true + end end end - end end return false end @@ -181,7 +192,8 @@ function is_writeonly(f::LLVM.Function) if intr == LLVM.Intrinsic("llvm.assume").id return true end - if LLVM.name(f) == "llvm.julia.gc_preserve_begin" || LLVM.name(f) == "llvm.julia.gc_preserve_end" + if LLVM.name(f) == "llvm.julia.gc_preserve_begin" || + LLVM.name(f) == "llvm.julia.gc_preserve_end" return true end for attr in collect(function_attributes(cur)) @@ -192,12 +204,12 @@ function is_writeonly(f::LLVM.Function) return true end if LLVM.version().major > 15 - if kind(attr) == kind(EnumAttribute("memory")) - if is_writeonly(MemoryEffect(value(attr))) - return true + if kind(attr) == kind(EnumAttribute("memory")) + if is_writeonly(MemoryEffect(value(attr))) + return true + end end end - end end return false end @@ -205,7 +217,8 @@ end function set_readonly!(fn::LLVM.Function) attrs = collect(function_attributes(fn)) if LLVM.version().major <= 15 - if !any(kind(attr) == kind(EnumAttribute("readonly")) for attr in attrs) && !any(kind(attr) == kind(EnumAttribute("readnone")) for attr in attrs) + if !any(kind(attr) == kind(EnumAttribute("readonly")) for attr in attrs) && + !any(kind(attr) == kind(EnumAttribute("readnone")) for attr in attrs) if any(kind(attr) == kind(EnumAttribute("writeonly")) for attr in attrs) delete!(function_attributes(fn), EnumAttribute("writeonly")) push!(function_attributes(fn), EnumAttribute("readnone")) @@ -224,12 +237,20 @@ function set_readonly!(fn::LLVM.Function) return old != eff end end - 
push!(function_attributes(fn), EnumAttribute("memory", set_readonly(AllEffects).data)) + push!( + function_attributes(fn), + EnumAttribute("memory", set_readonly(AllEffects).data), + ) return true end end -function get_function!(mod::LLVM.Module, name::AbstractString, FT::LLVM.FunctionType, attrs=[]) +function get_function!( + mod::LLVM.Module, + name::AbstractString, + FT::LLVM.FunctionType, + attrs = [], +) if haskey(functions(mod), name) F = functions(mod)[name] PT = LLVM.PointerType(FT) @@ -261,8 +282,12 @@ T_ppjlvalue() = LLVM.PointerType(LLVM.PointerType(LLVM.StructType(LLVMType[]))) return v end -function declare_pgcstack!(mod) - get_function!(mod, "julia.get_pgcstack", LLVM.FunctionType(LLVM.PointerType(T_ppjlvalue()))) +function declare_pgcstack!(mod) + get_function!( + mod, + "julia.get_pgcstack", + LLVM.FunctionType(LLVM.PointerType(T_ppjlvalue())), + ) end function emit_pgcstack(B) @@ -285,14 +310,19 @@ function get_pgcstack(func) return nothing end -function reinsert_gcmarker!(func, PB=nothing) +function reinsert_gcmarker!(func, PB = nothing) for (i, v) in enumerate(parameters(func)) - if any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(func, i)))) + if any( + map( + k -> kind(k) == kind(EnumAttribute("swiftself")), + collect(parameter_attributes(func, i)), + ), + ) return v end end - pgs = get_pgcstack(func) + pgs = get_pgcstack(func) if pgs === nothing context(LLVM.parent(func)) B = IRBuilder() @@ -303,13 +333,13 @@ function reinsert_gcmarker!(func, PB=nothing) position!(B, entry_bb) end emit_pgcstack(B) - else + else entry_bb = first(blocks(func)) fst = first(instructions(entry_bb)) if fst != pgs API.moveBefore(pgs, fst, PB === nothing ? 
C_NULL : PB.ref) end - pgs + pgs end end @@ -332,7 +362,7 @@ function unique_gcmarker!(func) end end if length(found) > 1 - for i in 2:length(found) + for i = 2:length(found) LLVM.replace_uses!(found[i], found[1]) ops = LLVM.collect(operands(found[i])) eraseInst(entry_bb, found[i]) @@ -341,7 +371,8 @@ function unique_gcmarker!(func) return nothing end -@inline AnonymousStruct(::Type{U}) where U<:Tuple = NamedTuple{ntuple(i->Symbol(i), Val(length(U.parameters))), U} +@inline AnonymousStruct(::Type{U}) where {U<:Tuple} = + NamedTuple{ntuple(i -> Symbol(i), Val(length(U.parameters))),U} # recursively compute the eltype type indexed by idx[0], idx[1], ... function recursive_eltype(val::LLVM.Value, idxs::Vector{Cuint}) @@ -359,7 +390,15 @@ end # Fix calling convention within julia that Tuple{Float,Float} ->[2 x float] rather than {float, float} # and that Bool -> i8, not i1 -function calling_conv_fixup(builder, val::LLVM.Value, tape::LLVM.LLVMType, prev::LLVM.Value=LLVM.UndefValue(tape), lidxs::Vector{Cuint}=Cuint[], ridxs::Vector{Cuint}=Cuint[], emesg=nothing)::LLVM.Value +function calling_conv_fixup( + builder, + val::LLVM.Value, + tape::LLVM.LLVMType, + prev::LLVM.Value = LLVM.UndefValue(tape), + lidxs::Vector{Cuint} = Cuint[], + ridxs::Vector{Cuint} = Cuint[], + emesg = nothing, +)::LLVM.Value ctype = recursive_eltype(val, lidxs) if ctype == tape if length(lidxs) != 0 @@ -377,9 +416,9 @@ function calling_conv_fixup(builder, val::LLVM.Value, tape::LLVM.LLVMType, prev: @assert length(ctype) == length(elements(tape)) for (i, ty) in enumerate(elements(tape)) ln = copy(lidxs) - push!(ln, i-1) + push!(ln, i - 1) rn = copy(ridxs) - push!(rn, i-1) + push!(rn, i - 1) prev = calling_conv_fixup(builder, val, ty, prev, ln, rn, emesg) end return prev @@ -388,9 +427,9 @@ function calling_conv_fixup(builder, val::LLVM.Value, tape::LLVM.LLVMType, prev: @assert length(elements(ctype)) == length(elements(tape)) for (i, ty) in enumerate(elements(tape)) ln = copy(lidxs) - push!(ln, 
i-1) + push!(ln, i - 1) rn = copy(ridxs) - push!(rn, i-1) + push!(rn, i - 1) prev = calling_conv_fixup(builder, val, ty, prev, ln, rn, emesg) end return prev @@ -398,29 +437,31 @@ function calling_conv_fixup(builder, val::LLVM.Value, tape::LLVM.LLVMType, prev: elseif isa(tape, LLVM.ArrayType) if isa(ctype, LLVM.ArrayType) @assert length(ctype) == length(tape) - for i in 1:length(tape) + for i = 1:length(tape) ln = copy(lidxs) - push!(ln, i-1) + push!(ln, i - 1) rn = copy(ridxs) - push!(rn, i-1) + push!(rn, i - 1) prev = calling_conv_fixup(builder, val, eltype(tape), prev, ln, rn, emesg) end return prev end if isa(ctype, LLVM.StructType) @assert length(elements(ctype)) == length(tape) - for i in 1:length(tape) + for i = 1:length(tape) ln = copy(lidxs) - push!(ln, i-1) + push!(ln, i - 1) rn = copy(ridxs) - push!(rn, i-1) + push!(rn, i - 1) prev = calling_conv_fixup(builder, val, eltype(tape), prev, ln, rn, emesg) end return prev end end - if isa(tape, LLVM.IntegerType) && LLVM.width(tape) == 1 && LLVM.width(ctype) != LLVM.width(tape) + if isa(tape, LLVM.IntegerType) && + LLVM.width(tape) == 1 && + LLVM.width(ctype) != LLVM.width(tape) if length(lidxs) != 0 val = API.e_extract_value!(builder, val, lidxs) end @@ -431,7 +472,9 @@ function calling_conv_fixup(builder, val::LLVM.Value, tape::LLVM.LLVMType, prev: val end end - if isa(tape, LLVM.PointerType) && isa(ctype, LLVM.PointerType) && LLVM.addrspace(tape) == LLVM.addrspace(ctype) + if isa(tape, LLVM.PointerType) && + isa(ctype, LLVM.PointerType) && + LLVM.addrspace(tape) == LLVM.addrspace(ctype) if length(lidxs) != 0 val = API.e_extract_value!(builder, val, lidxs) end @@ -451,7 +494,7 @@ function calling_conv_fixup(builder, val::LLVM.Value, tape::LLVM.LLVMType, prev: msg2 = sprint() do io println(io, "Enzyme Internal Error: Illegal calling convention fixup") - if emesg !== nothing + if emesg !== nothing emesg(io) end println(io, "ctype = ", ctype) @@ -461,7 +504,11 @@ function calling_conv_fixup(builder, 
val::LLVM.Value, tape::LLVM.LLVMType, prev: println(io, "lidxs = ", lidxs) println(io, "ridxs = ", ridxs) println(io, "tape_type(tape) = ", tape_type(tape)) - println(io, "convert(LLVMType, tape_type(tape)) = ", convert(LLVM.LLVMType, tape_type(tape); allow_boxed=true)) + println( + io, + "convert(LLVMType, tape_type(tape)) = ", + convert(LLVM.LLVMType, tape_type(tape); allow_boxed = true), + ) end throw(AssertionError(msg2)) end diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 3df37be1175..b672c50f57e 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -3,105 +3,139 @@ using ObjectFile using Libdl module FFI - using LLVM - module BLASSupport - # TODO: LAPACK handling - using LinearAlgebra - using ObjectFile - using Libdl - function __init__() - global blas_handle = Libdl.dlopen(BLAS.libblastrampoline) - end - function get_blas_symbols() - symbols = BLAS.get_config().exported_symbols - if BLAS.USE_BLAS64 - return map(n->n*"64_", symbols) - end - return symbols - end - - function lookup_blas_symbol(name) - Libdl.dlsym(blas_handle::Ptr{Cvoid}, name; throw_error=false) - end +using LLVM +module BLASSupport +# TODO: LAPACK handling +using LinearAlgebra +using ObjectFile +using Libdl +function __init__() + global blas_handle = Libdl.dlopen(BLAS.libblastrampoline) +end +function get_blas_symbols() + symbols = BLAS.get_config().exported_symbols + if BLAS.USE_BLAS64 + return map(n -> n * "64_", symbols) end + return symbols +end - const ptr_map = Dict{Ptr{Cvoid},String}() - - function __init__() - known_names = ( - "jl_alloc_array_1d", "jl_alloc_array_2d", "jl_alloc_array_3d", - "ijl_alloc_array_1d", "ijl_alloc_array_2d", "ijl_alloc_array_3d", - "jl_new_array", "ijl_new_array", - "jl_array_copy", "ijl_array_copy", - "jl_alloc_string", - "jl_in_threaded_region", "jl_enter_threaded_region", "jl_exit_threaded_region", "jl_set_task_tid", "jl_new_task", - "malloc", "memmove", "memcpy", "memset", - "jl_array_grow_beg", 
"ijl_array_grow_beg", - "jl_array_grow_end", "ijl_array_grow_end", - "jl_array_grow_at", "ijl_array_grow_at", - "jl_array_del_beg", "ijl_array_del_beg", - "jl_array_del_end", "ijl_array_del_end", - "jl_array_del_at", "ijl_array_del_at", - "jl_array_ptr", "ijl_array_ptr", - "jl_value_ptr", "jl_get_ptls_states", "jl_gc_add_finalizer_th", - "jl_symbol_n", "jl_", "jl_object_id", - "jl_reshape_array","ijl_reshape_array", - "jl_matching_methods", "ijl_matching_methods", - "jl_array_sizehint", "ijl_array_sizehint", - "jl_get_keyword_sorter", "ijl_get_keyword_sorter", - "jl_ptr_to_array", - "jl_box_float32", - "ijl_box_float32", - "jl_box_float64", - "ijl_box_float64", - "jl_ptr_to_array_1d", - "jl_eqtable_get", "ijl_eqtable_get", - "memcmp","memchr", - "jl_get_nth_field_checked", "ijl_get_nth_field_checked", - "jl_stored_inline", - "ijl_stored_inline", - "jl_array_isassigned", "ijl_array_isassigned", - "jl_array_ptr_copy", "ijl_array_ptr_copy", - "jl_array_typetagdata", "ijl_array_typetagdata", - "jl_idtable_rehash" - ) - for name in known_names - sym = LLVM.find_symbol(name) - if sym == C_NULL +function lookup_blas_symbol(name) + Libdl.dlsym(blas_handle::Ptr{Cvoid}, name; throw_error = false) +end +end + +const ptr_map = Dict{Ptr{Cvoid},String}() + +function __init__() + known_names = ( + "jl_alloc_array_1d", + "jl_alloc_array_2d", + "jl_alloc_array_3d", + "ijl_alloc_array_1d", + "ijl_alloc_array_2d", + "ijl_alloc_array_3d", + "jl_new_array", + "ijl_new_array", + "jl_array_copy", + "ijl_array_copy", + "jl_alloc_string", + "jl_in_threaded_region", + "jl_enter_threaded_region", + "jl_exit_threaded_region", + "jl_set_task_tid", + "jl_new_task", + "malloc", + "memmove", + "memcpy", + "memset", + "jl_array_grow_beg", + "ijl_array_grow_beg", + "jl_array_grow_end", + "ijl_array_grow_end", + "jl_array_grow_at", + "ijl_array_grow_at", + "jl_array_del_beg", + "ijl_array_del_beg", + "jl_array_del_end", + "ijl_array_del_end", + "jl_array_del_at", + "ijl_array_del_at", + 
"jl_array_ptr", + "ijl_array_ptr", + "jl_value_ptr", + "jl_get_ptls_states", + "jl_gc_add_finalizer_th", + "jl_symbol_n", + "jl_", + "jl_object_id", + "jl_reshape_array", + "ijl_reshape_array", + "jl_matching_methods", + "ijl_matching_methods", + "jl_array_sizehint", + "ijl_array_sizehint", + "jl_get_keyword_sorter", + "ijl_get_keyword_sorter", + "jl_ptr_to_array", + "jl_box_float32", + "ijl_box_float32", + "jl_box_float64", + "ijl_box_float64", + "jl_ptr_to_array_1d", + "jl_eqtable_get", + "ijl_eqtable_get", + "memcmp", + "memchr", + "jl_get_nth_field_checked", + "ijl_get_nth_field_checked", + "jl_stored_inline", + "ijl_stored_inline", + "jl_array_isassigned", + "ijl_array_isassigned", + "jl_array_ptr_copy", + "ijl_array_ptr_copy", + "jl_array_typetagdata", + "ijl_array_typetagdata", + "jl_idtable_rehash", + ) + for name in known_names + sym = LLVM.find_symbol(name) + if sym == C_NULL + continue + end + if haskey(ptr_map, sym) + # On MacOS memcpy and memmove seem to collide? + if name == "memcpy" continue end - if haskey(ptr_map, sym) - # On MacOS memcpy and memmove seem to collide? 
- if name == "memcpy" - continue - end - end - @assert !haskey(ptr_map, sym) - ptr_map[sym] = name end - for sym in BLASSupport.get_blas_symbols() - ptr = BLASSupport.lookup_blas_symbol(sym) - if ptr !== nothing - if haskey(ptr_map, ptr) - if ptr_map[ptr] != sym - @warn "Duplicated symbol in ptr_map" ptr, sym, ptr_map[ptr] - end - continue + @assert !haskey(ptr_map, sym) + ptr_map[sym] = name + end + for sym in BLASSupport.get_blas_symbols() + ptr = BLASSupport.lookup_blas_symbol(sym) + if ptr !== nothing + if haskey(ptr_map, ptr) + if ptr_map[ptr] != sym + @warn "Duplicated symbol in ptr_map" ptr, sym, ptr_map[ptr] end - ptr_map[ptr] = sym + continue end + ptr_map[ptr] = sym end end +end - function memoize!(ptr, fn) - fn = get(ptr_map, ptr, fn) - if !haskey(ptr_map, ptr) - ptr_map[ptr] = fn - else - @assert ptr_map[ptr] == fn - end - return fn +function memoize!(ptr, fn) + fn = get(ptr_map, ptr, fn) + if !haskey(ptr_map, ptr) + ptr_map[ptr] = fn + else + @assert ptr_map[ptr] == fn end + return fn +end end import GPUCompiler: IRError, InvalidIRError @@ -111,7 +145,15 @@ function restore_lookups(mod::LLVM.Module) for (v, k) in FFI.ptr_map if haskey(functions(mod), k) f = functions(mod)[k] - replace_uses!(f, LLVM.Value(LLVM.API.LLVMConstIntToPtr(ConstantInt(T_size_t, convert(UInt, v)), value_type(f)))) + replace_uses!( + f, + LLVM.Value( + LLVM.API.LLVMConstIntToPtr( + ConstantInt(T_size_t, convert(UInt, v)), + value_type(f), + ), + ), + ) eraseInst(mod, f) end end @@ -128,7 +170,7 @@ end # Rewrite calls with "jl_roots" to only have the jl_value_t attached and not { { {} addrspace(10)*, [1 x [2 x i64]], i64, i64 }, [2 x i64] } %unbox110183_replacementA function rewrite_ccalls!(mod::LLVM.Module) for f in collect(functions(mod)) - replaceAndErase = Tuple{Instruction, Instruction}[] + replaceAndErase = Tuple{Instruction,Instruction}[] for bb in blocks(f), inst in instructions(bb) if isa(inst, LLVM.CallInst) fn = called_operand(inst) @@ -160,17 +202,45 @@ function 
rewrite_ccalls!(mod::LLVM.Module) prevname = LLVM.name(inst) LLVM.name!(inst, "") if !isdefined(LLVM, :OperandBundleDef) - newinst = call!(B, called_type(inst), called_operand(inst), uservals, collect(operand_bundles(inst)), prevname) - else - newinst = call!(B, called_type(inst), called_operand(inst), uservals, collect(map(LLVM.OperandBundleDef, operand_bundles(inst))), prevname) - end - for idx = [LLVM.API.LLVMAttributeFunctionIndex, LLVM.API.LLVMAttributeReturnIndex, [LLVM.API.LLVMAttributeIndex(i) for i in 1:(length(arguments(inst)))]...] + newinst = call!( + B, + called_type(inst), + called_operand(inst), + uservals, + collect(operand_bundles(inst)), + prevname, + ) + else + newinst = call!( + B, + called_type(inst), + called_operand(inst), + uservals, + collect(map(LLVM.OperandBundleDef, operand_bundles(inst))), + prevname, + ) + end + for idx in [ + LLVM.API.LLVMAttributeFunctionIndex, + LLVM.API.LLVMAttributeReturnIndex, + [ + LLVM.API.LLVMAttributeIndex(i) for + i = 1:(length(arguments(inst))) + ]..., + ] idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) - count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx); - Attrs = Base.unsafe_convert(Ptr{LLVM.API.LLVMAttributeRef}, Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef)*count)) + count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx) + Attrs = Base.unsafe_convert( + Ptr{LLVM.API.LLVMAttributeRef}, + Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), + ) LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) - for j in 1:count - LLVM.API.LLVMAddCallSiteAttribute(newinst, idx, unsafe_load(Attrs, j)) + for j = 1:count + LLVM.API.LLVMAddCallSiteAttribute( + newinst, + idx, + unsafe_load(Attrs, j), + ) end Libc.free(Attrs) end @@ -181,26 +251,26 @@ function rewrite_ccalls!(mod::LLVM.Module) continue end if !isdefined(LLVM, :OperandBundleDef) - newbundles = OperandBundle[] - else - newbundles = OperandBundleDef[] - end - for bunduse in operand_bundles(inst) + newbundles = OperandBundle[] + else + 
newbundles = OperandBundleDef[] + end + for bunduse in operand_bundles(inst) if isdefined(LLVM, :OperandBundleDef) - bunduse = LLVM.OperandBundleDef(bunduse) - end + bunduse = LLVM.OperandBundleDef(bunduse) + end if !isdefined(LLVM, :OperandBundleDef) - if LLVM.tag(bunduse) != "jl_roots" - push!(newbundles, bunduse) - continue - end - else - if LLVM.tag_name(bunduse) != "jl_roots" - push!(newbundles, bunduse) - continue - end - end + if LLVM.tag(bunduse) != "jl_roots" + push!(newbundles, bunduse) + continue + end + else + if LLVM.tag_name(bunduse) != "jl_roots" + push!(newbundles, bunduse) + continue + end + end uservals = LLVM.Value[] subchanged = false for lval in LLVM.inputs(bunduse) @@ -228,23 +298,47 @@ function rewrite_ccalls!(mod::LLVM.Module) end changed = true if !isdefined(LLVM, :OperandBundleDef) - push!(newbundles, OperandBundle(LLVM.tag(bunduse), uservals)) + push!(newbundles, OperandBundle(LLVM.tag(bunduse), uservals)) else - push!(newbundles, OperandBundleDef(LLVM.tag_name(bunduse), uservals)) + push!( + newbundles, + OperandBundleDef(LLVM.tag_name(bunduse), uservals), + ) end end changed = false if changed prevname = LLVM.name(inst) LLVM.name!(inst, "") - newinst = call!(B, called_type(inst), called_operand(inst), collect(arguments(inst)), newbundles, prevname) - for idx = [LLVM.API.LLVMAttributeFunctionIndex, LLVM.API.LLVMAttributeReturnIndex, [LLVM.API.LLVMAttributeIndex(i) for i in 1:(length(arguments(inst)))]...] 
+ newinst = call!( + B, + called_type(inst), + called_operand(inst), + collect(arguments(inst)), + newbundles, + prevname, + ) + for idx in [ + LLVM.API.LLVMAttributeFunctionIndex, + LLVM.API.LLVMAttributeReturnIndex, + [ + LLVM.API.LLVMAttributeIndex(i) for + i = 1:(length(arguments(inst))) + ]..., + ] idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) - count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx); - Attrs = Base.unsafe_convert(Ptr{LLVM.API.LLVMAttributeRef}, Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef)*count)) + count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx) + Attrs = Base.unsafe_convert( + Ptr{LLVM.API.LLVMAttributeRef}, + Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), + ) LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) - for j in 1:count - LLVM.API.LLVMAddCallSiteAttribute(newinst, idx, unsafe_load(Attrs, j)) + for j = 1:count + LLVM.API.LLVMAddCallSiteAttribute( + newinst, + idx, + unsafe_load(Attrs, j), + ) end Libc.free(Attrs) end @@ -270,7 +364,11 @@ function check_ir!(job, errors, mod::LLVM.Module) prev_ft = eltype(value_type(f)::LLVM.PointerType)::LLVM.FunctionType - mfn = LLVM.API.LLVMAddFunction(mod, "malloc", LLVM.FunctionType(ptr8, parameters(prev_ft))) + mfn = LLVM.API.LLVMAddFunction( + mod, + "malloc", + LLVM.FunctionType(ptr8, parameters(prev_ft)), + ) replace_uses!(f, LLVM.Value(LLVM.API.LLVMConstPointerCast(mfn, value_type(f)))) eraseInst(mod, f) end @@ -291,11 +389,18 @@ function check_ir!(job, errors, imported, f::LLVM.Function) for bb in blocks(f), inst in instructions(bb) if isa(inst, LLVM.CallInst) push!(calls, inst) - # remove illegal invariant.load and jtbaa_const invariants + # remove illegal invariant.load and jtbaa_const invariants elseif isInline && isa(inst, LLVM.LoadInst) md = metadata(inst) if haskey(md, LLVM.MD_tbaa) - modified = LLVM.Metadata(ccall((:EnzymeMakeNonConstTBAA, API.libEnzyme), LLVM.API.LLVMMetadataRef, (LLVM.API.LLVMMetadataRef,), md[LLVM.MD_tbaa])) + modified = LLVM.Metadata( 
+ ccall( + (:EnzymeMakeNonConstTBAA, API.libEnzyme), + LLVM.API.LLVMMetadataRef, + (LLVM.API.LLVMMetadataRef,), + md[LLVM.MD_tbaa], + ), + ) setindex!(md, modified, LLVM.MD_tbaa) end if haskey(md, LLVM.MD_invariant_load) @@ -314,7 +419,18 @@ end const libjulia = Ref{Ptr{Cvoid}}(C_NULL) # List of methods to location of arg which is the mi/function, then start of args -const generic_method_offsets = Dict{String, Tuple{Int,Int}}(("jl_f__apply_latest" => (2,3), "ijl_f__apply_latest" => (2,3), "jl_f__call_latest" => (2,3), "ijl_f__call_latest" => (2,3), "jl_f_invoke" => (2,3), "jl_invoke" => (1,3), "jl_apply_generic" => (1,2), "ijl_f_invoke" => (2,3), "ijl_invoke" => (1,3), "ijl_apply_generic" => (1,2))) +const generic_method_offsets = Dict{String,Tuple{Int,Int}}(( + "jl_f__apply_latest" => (2, 3), + "ijl_f__apply_latest" => (2, 3), + "jl_f__call_latest" => (2, 3), + "ijl_f__call_latest" => (2, 3), + "jl_f_invoke" => (2, 3), + "jl_invoke" => (1, 3), + "jl_apply_generic" => (1, 2), + "ijl_f_invoke" => (2, 3), + "ijl_invoke" => (1, 3), + "ijl_apply_generic" => (1, 2), +)) @inline function has_method(sig, world::UInt, mt::Union{Nothing,Core.MethodTable}) return ccall(:jl_gf_invoke_lookup, Any, (Any, Any, UInt), sig, mt, world) !== nothing @@ -330,16 +446,17 @@ end @inline function is_inactive(tys, world::UInt, mt) specTypes = Interpreter.simplify_kw(Tuple{tys...}) - if has_method(Tuple{typeof(EnzymeRules.inactive), tys...}, world, mt) + if has_method(Tuple{typeof(EnzymeRules.inactive),tys...}, world, mt) return true end - if has_method(Tuple{typeof(EnzymeRules.inactive_noinl), tys...}, world, mt) + if has_method(Tuple{typeof(EnzymeRules.inactive_noinl),tys...}, world, mt) return true end return false end -import GPUCompiler: DYNAMIC_CALL, DELAYED_BINDING, RUNTIME_FUNCTION, UNKNOWN_FUNCTION, POINTER_FUNCTION +import GPUCompiler: + DYNAMIC_CALL, DELAYED_BINDING, RUNTIME_FUNCTION, UNKNOWN_FUNCTION, POINTER_FUNCTION import GPUCompiler: backtrace, isintrinsic function 
check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) world = job.world @@ -363,12 +480,28 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) mfn = LLVM.API.LLVMGetNamedFunction(mod, "malloc") if mfn == C_NULL ptr8 = LLVM.PointerType(LLVM.IntType(8)) - mfn = LLVM.API.LLVMAddFunction(mod, "malloc", LLVM.FunctionType(ptr8, [value_type(LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(inst, 0)))])) + mfn = LLVM.API.LLVMAddFunction( + mod, + "malloc", + LLVM.FunctionType( + ptr8, + [value_type(LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(inst, 0)))], + ), + ) end mfn2 = LLVM.Function(mfn) - nval = ptrtoint!(b, call!(b, LLVM.function_type(mfn2), mfn2, [LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(inst, 0))]), value_type(inst)) + nval = ptrtoint!( + b, + call!( + b, + LLVM.function_type(mfn2), + mfn2, + [LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(inst, 0))], + ), + value_type(inst), + ) replace_uses!(inst, nval) - LLVM.API.LLVMInstructionEraseFromParent(inst) + LLVM.API.LLVMInstructionEraseFromParent(inst) elseif fn == "jl_load_and_lookup" || fn == "ijl_load_and_lookup" ofn = LLVM.parent(LLVM.parent(inst)) mod = LLVM.parent(ofn) @@ -376,7 +509,9 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) arg1 = operands(inst)[1] while isa(arg1, ConstantExpr) - if opcode(arg1) == LLVM.API.LLVMAddrSpaceCast || opcode(arg1) == LLVM.API.LLVMBitCast || opcode(arg1) == LLVM.API.LLVMIntToPtr + if opcode(arg1) == LLVM.API.LLVMAddrSpaceCast || + opcode(arg1) == LLVM.API.LLVMBitCast || + opcode(arg1) == LLVM.API.LLVMIntToPtr arg1 = operands(arg1)[1] else break @@ -389,71 +524,106 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) hnd = operands(inst)[3] if isa(hnd, LLVM.GlobalVariable) hnd = LLVM.name(hnd) - if fn == "jl_lazy_load_and_lookup" - res = ccall(:jl_load_and_lookup, Ptr{Cvoid}, (Ptr{Cvoid}, Cstring, Ptr{Cvoid}), arg1, fname, reinterpret(Ptr{Cvoid}, JIT.lookup(nothing, hnd).ptr)) - else - res = ccall(:ijl_load_and_lookup, 
Ptr{Cvoid}, (Ptr{Cvoid}, Cstring, Ptr{Cvoid}), arg1, fname, reinterpret(Ptr{Cvoid}, JIT.lookup(nothing, hnd).ptr)) - end - replaceWith = LLVM.ConstantInt(LLVM.IntType(8*sizeof(Int)), reinterpret(UInt, res)) - for u in LLVM.uses(inst) - st = LLVM.user(u) - if isa(st, LLVM.StoreInst) && LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 0)) == inst - ptr = LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 1)) - for u in LLVM.uses(ptr) - ld = LLVM.user(u) - if isa(ld, LLVM.LoadInst) - b = IRBuilder() - position!(b, ld) - for u in LLVM.uses(ld) - u = LLVM.user(u) - if isa(u, LLVM.CallInst) - push!(calls, u) + if fn == "jl_lazy_load_and_lookup" + res = ccall( + :jl_load_and_lookup, + Ptr{Cvoid}, + (Ptr{Cvoid}, Cstring, Ptr{Cvoid}), + arg1, + fname, + reinterpret(Ptr{Cvoid}, JIT.lookup(nothing, hnd).ptr), + ) + else + res = ccall( + :ijl_load_and_lookup, + Ptr{Cvoid}, + (Ptr{Cvoid}, Cstring, Ptr{Cvoid}), + arg1, + fname, + reinterpret(Ptr{Cvoid}, JIT.lookup(nothing, hnd).ptr), + ) + end + replaceWith = LLVM.ConstantInt( + LLVM.IntType(8 * sizeof(Int)), + reinterpret(UInt, res), + ) + for u in LLVM.uses(inst) + st = LLVM.user(u) + if isa(st, LLVM.StoreInst) && + LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 0)) == inst + ptr = LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 1)) + for u in LLVM.uses(ptr) + ld = LLVM.user(u) + if isa(ld, LLVM.LoadInst) + b = IRBuilder() + position!(b, ld) + for u in LLVM.uses(ld) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) + push!(calls, u) + end end + replace_uses!( + ld, + LLVM.inttoptr!( + b, + replaceWith, + value_type(inst), + ), + ) end - replace_uses!(ld, LLVM.inttoptr!(b, replaceWith, value_type(inst))) end end end - end - b = IRBuilder() - position!(b, inst) - replacement = LLVM.inttoptr!(b, replaceWith, value_type(inst)) - for u in LLVM.uses(inst) - u = LLVM.user(u) - if isa(u, LLVM.CallInst) - push!(calls, u) - end - if isa(u, LLVM.PHIInst) - if all(x->first(x) == inst || first(x) == replacement, LLVM.incoming(u)) + b = IRBuilder() + position!(b, 
inst) + replacement = LLVM.inttoptr!(b, replaceWith, value_type(inst)) + for u in LLVM.uses(inst) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) + push!(calls, u) + end + if isa(u, LLVM.PHIInst) + if all( + x -> first(x) == inst || first(x) == replacement, + LLVM.incoming(u), + ) - for u in LLVM.uses(u) - u = LLVM.user(u) - if isa(u, LLVM.CallInst) - push!(calls, u) - end - if isa(u, LLVM.BitCastInst) - for u1 in LLVM.uses(u) - u1 = LLVM.user(u1) - if isa(u1, LLVM.CallInst) - push!(calls, u1) - end - end - replace_uses!(u, LLVM.inttoptr!(b, replaceWith, value_type(u))) + for u in LLVM.uses(u) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) + push!(calls, u) + end + if isa(u, LLVM.BitCastInst) + for u1 in LLVM.uses(u) + u1 = LLVM.user(u1) + if isa(u1, LLVM.CallInst) + push!(calls, u1) end end + replace_uses!( + u, + LLVM.inttoptr!( + b, + replaceWith, + value_type(u), + ), + ) end end end - replace_uses!(inst, replacement) - LLVM.API.LLVMInstructionEraseFromParent(inst) + end + end + replace_uses!(inst, replacement) + LLVM.API.LLVMInstructionEraseFromParent(inst) end end end - + elseif fn == "jl_lazy_load_and_lookup" || fn == "ijl_lazy_load_and_lookup" ofn = LLVM.parent(LLVM.parent(inst)) mod = LLVM.parent(ofn) @@ -469,7 +639,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) op = LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(op, 0)) end if isa(op, ConstantInt) - rep = reinterpret(Ptr{Cvoid}, convert(Csize_t, op)+8) + rep = reinterpret(Ptr{Cvoid}, convert(Csize_t, op) + 8) ld = unsafe_load(convert(Ptr{Ptr{Cvoid}}, rep)) flib = Base.unsafe_pointer_to_objref(ld) end @@ -485,8 +655,9 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) if isa(fname, LLVM.GlobalVariable) fname = LLVM.initializer(fname) end - if (isa(fname, LLVM.ConstantArray) || isa(fname, LLVM.ConstantDataArray)) && eltype(value_type(fname)) == LLVM.IntType(8) - fname = String(map((x)->convert(UInt8, x), collect(fname)[1:(end-1)])) + if (isa(fname, LLVM.ConstantArray) 
|| isa(fname, LLVM.ConstantDataArray)) && + eltype(value_type(fname)) == LLVM.IntType(8) + fname = String(map((x) -> convert(UInt8, x), collect(fname)[1:(end-1)])) end if !isa(fname, String) || !isa(flib, String) @@ -494,7 +665,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) end found = false - + try data = open(flib, "r") do io lib = readmeta(io) @@ -537,14 +708,18 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) for u in LLVM.uses(inst) st = LLVM.user(u) - if isa(st, LLVM.StoreInst) && LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 0)) == inst + if isa(st, LLVM.StoreInst) && + LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 0)) == inst ptr = LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 1)) for u in LLVM.uses(ptr) ld = LLVM.user(u) if isa(ld, LLVM.LoadInst) b = IRBuilder() position!(b, ld) - replace_uses!(ld, LLVM.pointercast!(b, replaceWith, value_type(inst))) + replace_uses!( + ld, + LLVM.pointercast!(b, replaceWith, value_type(inst)), + ) end end end @@ -558,14 +733,28 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) else if fn == "jl_lazy_load_and_lookup" - res = ccall(:jl_lazy_load_and_lookup, Ptr{Cvoid}, (Any, Cstring), flib, fname) + res = ccall( + :jl_lazy_load_and_lookup, + Ptr{Cvoid}, + (Any, Cstring), + flib, + fname, + ) else - res = ccall(:ijl_lazy_load_and_lookup, Ptr{Cvoid}, (Any, Cstring), flib, fname) + res = ccall( + :ijl_lazy_load_and_lookup, + Ptr{Cvoid}, + (Any, Cstring), + flib, + fname, + ) end - replaceWith = LLVM.ConstantInt(LLVM.IntType(8*sizeof(Int)), reinterpret(UInt, res)) + replaceWith = + LLVM.ConstantInt(LLVM.IntType(8 * sizeof(Int)), reinterpret(UInt, res)) for u in LLVM.uses(inst) st = LLVM.user(u) - if isa(st, LLVM.StoreInst) && LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 0)) == inst + if isa(st, LLVM.StoreInst) && + LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 0)) == inst ptr = LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 1)) for u in LLVM.uses(ptr) ld = 
LLVM.user(u) @@ -578,7 +767,10 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) push!(calls, u) end end - replace_uses!(ld, LLVM.inttoptr!(b, replaceWith, value_type(inst))) + replace_uses!( + ld, + LLVM.inttoptr!(b, replaceWith, value_type(inst)), + ) end end end @@ -587,32 +779,38 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) b = IRBuilder() position!(b, inst) replacement = LLVM.inttoptr!(b, replaceWith, value_type(inst)) - for u in LLVM.uses(inst) + for u in LLVM.uses(inst) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) + push!(calls, u) + end + if isa(u, LLVM.PHIInst) + if all( + x -> first(x) == inst || first(x) == replacement, + LLVM.incoming(u), + ) + + for u in LLVM.uses(u) u = LLVM.user(u) if isa(u, LLVM.CallInst) push!(calls, u) end - if isa(u, LLVM.PHIInst) - if all(x->first(x) == inst || first(x) == replacement, LLVM.incoming(u)) - - for u in LLVM.uses(u) - u = LLVM.user(u) - if isa(u, LLVM.CallInst) - push!(calls, u) - end - if isa(u, LLVM.BitCastInst) - for u1 in LLVM.uses(u) - u1 = LLVM.user(u1) - if isa(u1, LLVM.CallInst) - push!(calls, u1) - end - end - replace_uses!(u, LLVM.inttoptr!(b, replaceWith, value_type(u))) - end + if isa(u, LLVM.BitCastInst) + for u1 in LLVM.uses(u) + u1 = LLVM.user(u1) + if isa(u1, LLVM.CallInst) + push!(calls, u1) end end + replace_uses!( + u, + LLVM.inttoptr!(b, replaceWith, value_type(u)), + ) end end + end + end + end replace_uses!(inst, replacement) LLVM.API.LLVMInstructionEraseFromParent(inst) end @@ -622,29 +820,77 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) if isa(dest, LLVM.Function) && LLVM.name(dest) == "jl_f__apply_iterate" # Add 1 to account for function being first arg iteroff = 2 - + legal, iterlib = absint(operands(inst)[iteroff+1]) if legal && iterlib == Base.iterate - legal, GT = abs_typeof(operands(inst)[4+1], true) + legal, GT, byref = abs_typeof(operands(inst)[4+1], true) funcoff = 3 - legal2, funclib = 
abs_typeof(operands(inst)[funcoff+1]) + legal2, funclib, byref2 = abs_typeof(operands(inst)[funcoff+1]) if legal && (GT <: Vector || GT <: Tuple) if legal2 tys = [funclib, Vararg{Any}] - if funclib == typeof(Core.apply_type) || is_inactive(tys, world, method_table) + if funclib == typeof(Core.apply_type) || + is_inactive(tys, world, method_table) inactive = LLVM.StringAttribute("enzyme_inactive", "") - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + inactive, + ) nofree = LLVM.EnumAttribute("nofree") - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), nofree) - no_escaping_alloc = LLVM.StringAttribute("enzyme_no_escaping_allocation") - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc) - elseif funclib == typeof(Base.tuple) && length(operands(inst)) == 4+1+1 && Base.isconcretetype(GT) && Enzyme.Compiler.guaranteed_const_nongen(GT, world) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + nofree, + ) + no_escaping_alloc = + LLVM.StringAttribute("enzyme_no_escaping_allocation") + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + no_escaping_alloc, + ) + elseif funclib == typeof(Base.tuple) && + length(operands(inst)) == 4 + 1 + 1 && + Base.isconcretetype(GT) && + Enzyme.Compiler.guaranteed_const_nongen(GT, world) inactive = LLVM.StringAttribute("enzyme_inactive", "") - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive) + LLVM.API.LLVMAddCallSiteAttribute( 
+ inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + inactive, + ) nofree = LLVM.EnumAttribute("nofree") - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), nofree) - no_escaping_alloc = LLVM.StringAttribute("enzyme_no_escaping_allocation") - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + nofree, + ) + no_escaping_alloc = + LLVM.StringAttribute("enzyme_no_escaping_allocation") + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + no_escaping_alloc, + ) end end end @@ -654,11 +900,11 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) if isa(dest, LLVM.Function) && in(LLVM.name(dest), keys(generic_method_offsets)) offset, start = generic_method_offsets[LLVM.name(dest)] # Add 1 to account for function being first arg - legal, flibty = abs_typeof(operands(inst)[offset+1]) + legal, flibty, byref = abs_typeof(operands(inst)[offset+1]) if legal tys = Type[flibty] for op in collect(operands(inst))[start+1:end-1] - legal, typ = abs_typeof(op, true) + legal, typ, byref2 = abs_typeof(op, true) if !legal typ = Any end @@ -673,11 +919,33 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) end if is_inactive(tys, world, method_table) inactive = LLVM.StringAttribute("enzyme_inactive", "") - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + inactive, + ) nofree = LLVM.EnumAttribute("nofree") - 
LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), nofree) - no_escaping_alloc = LLVM.StringAttribute("enzyme_no_escaping_allocation") - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + nofree, + ) + no_escaping_alloc = + LLVM.StringAttribute("enzyme_no_escaping_allocation") + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + no_escaping_alloc, + ) end end end @@ -697,7 +965,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) ptr = Ptr{Cvoid}(ptr_val) # look it up in the Julia JIT cache - frames = ccall(:jl_lookup_code_address, Any, (Ptr{Cvoid}, Cint,), ptr, 0) + frames = ccall(:jl_lookup_code_address, Any, (Ptr{Cvoid}, Cint), ptr, 0) if length(frames) >= 1 fn, file, line, linfo, fromC, inlined = last(frames) @@ -709,11 +977,24 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) mod = LLVM.parent(LLVM.parent(LLVM.parent(inst))) lfn = LLVM.API.LLVMGetNamedFunction(mod, fn) if lfn == C_NULL - lfn = LLVM.API.LLVMAddFunction(mod, fn, LLVM.API.LLVMGetCalledFunctionType(inst)) + lfn = LLVM.API.LLVMAddFunction( + mod, + fn, + LLVM.API.LLVMGetCalledFunctionType(inst), + ) else - lfn = LLVM.API.LLVMConstBitCast(lfn, LLVM.PointerType(LLVM.FunctionType(LLVM.API.LLVMGetCalledFunctionType(inst)))) + lfn = LLVM.API.LLVMConstBitCast( + lfn, + LLVM.PointerType( + LLVM.FunctionType(LLVM.API.LLVMGetCalledFunctionType(inst)), + ), + ) end - LLVM.API.LLVMSetOperand(inst, LLVM.API.LLVMGetNumOperands(inst)-1, lfn) + LLVM.API.LLVMSetOperand( + inst, + LLVM.API.LLVMGetNumOperands(inst) - 1, + lfn, + ) end end end @@ -721,11 +1002,11 @@ function check_ir!(job, errors, 
imported, inst::LLVM.CallInst, calls) if isa(dest, LLVM.Function) && in(LLVM.name(dest), keys(generic_method_offsets)) offset, start = generic_method_offsets[LLVM.name(dest)] - legal, flibty = abs_typeof(operands(inst)[offset]) + legal, flibty, byref = abs_typeof(operands(inst)[offset]) if legal tys = Type[flibty] for op in collect(operands(inst))[start:end-1] - legal, typ = abs_typeof(op, true) + legal, typ, byref2 = abs_typeof(op, true) if !legal typ = Any end @@ -735,15 +1016,18 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) if legal && isa(flib, Core.MethodInstance) if !Base.isvarargtype(flib.specTypes.parameters[end]) if length(tys) != length(flib.specTypes.parameters) - msg = sprint() do io::IO - println(io, "Enzyme internal error (length(tys) != length(flib.specTypes.parameters))") - println(io, "tys=", tys) - println(io, "flib=", flib) - println(io, "inst=", inst) - println(io, "offset=", offset) - println(io, "start=", start) - end - throw(AssertionError(msg)) + msg = sprint() do io::IO + println( + io, + "Enzyme internal error (length(tys) != length(flib.specTypes.parameters))", + ) + println(io, "tys=", tys) + println(io, "flib=", flib) + println(io, "inst=", inst) + println(io, "offset=", offset) + println(io, "start=", start) + end + throw(AssertionError(msg)) end end tys = flib.specTypes.parameters @@ -752,11 +1036,33 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) ofn = LLVM.parent(LLVM.parent(inst)) mod = LLVM.parent(ofn) inactive = LLVM.StringAttribute("enzyme_inactive", "") - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + inactive, + ) nofree = LLVM.EnumAttribute("nofree") - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, 
LLVM.API.LLVMAttributeFunctionIndex), nofree) - no_escaping_alloc = LLVM.StringAttribute("enzyme_no_escaping_allocation") - LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc) + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + nofree, + ) + no_escaping_alloc = + LLVM.StringAttribute("enzyme_no_escaping_allocation") + LLVM.API.LLVMAddCallSiteAttribute( + inst, + reinterpret( + LLVM.API.LLVMAttributeIndex, + LLVM.API.LLVMAttributeFunctionIndex, + ), + no_escaping_alloc, + ) end end end @@ -767,15 +1073,15 @@ end function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width) - todo = Tuple{LLVM.Value, Tuple}[] + todo = Tuple{LLVM.Value,Tuple}[] for b in blocks(enzymefn) term = terminator(b) if LLVM.API.LLVMIsAReturnInst(term) != C_NULL if width == 1 push!(todo, (operands(term)[1], off == -1 ? () : (off,))) else - for i in 1:width - push!(todo, (operands(term)[1], off == -1 ? (i,) : (off,i))) + for i = 1:width + push!(todo, (operands(term)[1], off == -1 ? 
(i,) : (off, i))) end end end @@ -803,7 +1109,7 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width if isa(cur, LLVM.ExtractValueInst) noff = off - for i in 1:LLVM.API.LLVMGetNumIndices(cur) + for i = 1:LLVM.API.LLVMGetNumIndices(cur) noff = (noff..., convert(Int, unsafe_load(LLVM.API.LLVMGetIndices(cur), i))) end push!(todo, (operands(cur)[1], noff)) @@ -819,7 +1125,7 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width # if inserting at the current desired offset, we have found the value we need if ind == off[1] push!(todo, (operands(cur)[2], off[2:end])) - # otherwise it must be inserted at a different point + # otherwise it must be inserted at a different point else push!(todo, (operands(cur)[1], off)) end @@ -833,15 +1139,18 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width nm = LLVM.name(fn) end - # Type tag is arg 3 if nm == "julia.gc_alloc_obj" - legal, Ty = abs_typeof(cur) + legal, Ty, byref = abs_typeof(cur) @assert legal reg = active_reg_inner(Ty, (), world) if reg == ActiveState || reg == MixedState NTy = Base.RefValue{Ty} @assert sizeof(Ty) == sizeof(NTy) - LLVM.API.LLVMSetOperand(cur, 2, unsafe_to_llvm(LLVM.IRBuilder(cur), NTy)) + LLVM.API.LLVMSetOperand( + cur, + 2, + unsafe_to_llvm(LLVM.IRBuilder(cur), NTy), + ) end continue end @@ -858,7 +1167,7 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width if isa(cur, LLVM.LoadInst) al = operands(cur)[1] if isa(al, LLVM.AllocaInst) - atodo = Tuple{LLVM.Value, Tuple, LLVM.Value}[] + atodo = Tuple{LLVM.Value,Tuple,LLVM.Value}[] for u in LLVM.uses(al) push!(atodo, (LLVM.user(u), off, al)) end @@ -893,22 +1202,23 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width continue end - msg = sprint() do io::IO - println(io, "Enzyme Internal Error (rewrite_union_returns_as_ref[1])") - println(io, string(enzymefn)) - println(io, "BAD") - println(io, "acur=", 
acur) - println(io, "aoff=", aoff) - println(io, "prev=", prev) - end - throw(AssertionError(msg)) + msg = sprint() do io::IO + println(io, "Enzyme Internal Error (rewrite_union_returns_as_ref[1])") + println(io, string(enzymefn)) + println(io, "BAD") + println(io, "acur=", acur) + println(io, "aoff=", aoff) + println(io, "prev=", prev) + end + throw(AssertionError(msg)) end continue end end - if length(off) == 0 && value_type(cur) == LLVM.PointerType(LLVM.StructType(LLVMType[]), Tracked) - legal, typ = abs_typeof(cur) + if length(off) == 0 && + value_type(cur) == LLVM.PointerType(LLVM.StructType(LLVMType[]), Tracked) + legal, typ, byref = abs_typeof(cur) if legal reg = active_reg_inner(typ, (), world) if !(reg == ActiveState || reg == MixedState) @@ -921,7 +1231,7 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width push!(todo, (cur[off[1]], off[2:end])) continue end - + if isa(cur, LLVM.CallInst) dest = called_operand(cur) if isa(dest, LLVM.Function) @@ -932,12 +1242,12 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width end end - msg = sprint() do io::IO - println(io, "Enzyme Internal Error (rewrite_union_returns_as_ref[2])") - println(io, string(enzymefn)) - println(io, "cur=", string(cur)) - println(io, "off=", off) - end - throw(AssertionError(msg)) + msg = sprint() do io::IO + println(io, "Enzyme Internal Error (rewrite_union_returns_as_ref[2])") + println(io, string(enzymefn)) + println(io, "cur=", string(cur)) + println(io, "off=", off) + end + throw(AssertionError(msg)) end end diff --git a/src/gradientutils.jl b/src/gradientutils.jl index ff83f60c1d7..b0a0bff26b8 100644 --- a/src/gradientutils.jl +++ b/src/gradientutils.jl @@ -6,14 +6,35 @@ end Base.unsafe_convert(::Type{API.EnzymeGradientUtilsRef}, gutils::GradientUtils) = gutils.ref LLVM.dispose(gutils::GradientUtils) = throw("Cannot free gutils") -function call_samefunc_with_inverted_bundles!(B::LLVM.IRBuilder, gutils::GradientUtils, 
orig::LLVM.CallInst, args::Vector{<:LLVM.Value}, valTys::Vector{API.CValueType}, lookup::Bool) +function call_samefunc_with_inverted_bundles!( + B::LLVM.IRBuilder, + gutils::GradientUtils, + orig::LLVM.CallInst, + args::Vector{<:LLVM.Value}, + valTys::Vector{API.CValueType}, + lookup::Bool, +) @assert length(args) == length(valTys) - return LLVM.Value(API.EnzymeGradientUtilsCallWithInvertedBundles(gutils, LLVM.called_operand(orig), LLVM.called_type(orig), args, length(args), orig, valTys, length(valTys), B, #=lookup=#false)) + return LLVM.Value( + API.EnzymeGradientUtilsCallWithInvertedBundles( + gutils, + LLVM.called_operand(orig), + LLVM.called_type(orig), + args, + length(args), + orig, + valTys, + length(valTys), + B, + false, + ), + ) #=lookup=# end get_width(gutils::GradientUtils) = API.EnzymeGradientUtilsGetWidth(gutils) get_mode(gutils::GradientUtils) = API.EnzymeGradientUtilsGetMode(gutils) -get_runtime_activity(gutils::GradientUtils) = API.EnzymeGradientUtilsGetRuntimeActivity(gutils) +get_runtime_activity(gutils::GradientUtils) = + API.EnzymeGradientUtilsGetRuntimeActivity(gutils) function get_shadow_type(gutils::GradientUtils, T::LLVM.LLVMType) w = get_width(gutils) @@ -23,26 +44,45 @@ function get_shadow_type(gutils::GradientUtils, T::LLVM.LLVMType) return LLVM.ArrayType(T, Int(w)) end end -function get_uncacheable(gutils::GradientUtils, orig::LLVM.CallInst) - uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1) - if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) != 1 +function get_uncacheable(gutils::GradientUtils, orig::LLVM.CallInst) + uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig))) - 1) + if API.EnzymeGradientUtilsGetUncacheableArgs( + gutils, + orig, + uncacheable, + length(uncacheable), + ) != 1 uncacheable .= 1 end return uncacheable end -erase_with_placeholder(gutils::GradientUtils, inst::LLVM.Instruction, orig::LLVM.Instruction, erase::Bool=true) = 
API.EnzymeGradientUtilsEraseWithPlaceholder(gutils, inst, orig, erase) -is_constant_value(gutils::GradientUtils, val::LLVM.Value) = API.EnzymeGradientUtilsIsConstantValue(gutils, val) != 0 +erase_with_placeholder( + gutils::GradientUtils, + inst::LLVM.Instruction, + orig::LLVM.Instruction, + erase::Bool = true, +) = API.EnzymeGradientUtilsEraseWithPlaceholder(gutils, inst, orig, erase) +is_constant_value(gutils::GradientUtils, val::LLVM.Value) = + API.EnzymeGradientUtilsIsConstantValue(gutils, val) != 0 -is_constant_inst(gutils::GradientUtils, inst::LLVM.Instruction) = API.EnzymeGradientUtilsIsConstantInstruction(gutils, inst) != 0 +is_constant_inst(gutils::GradientUtils, inst::LLVM.Instruction) = + API.EnzymeGradientUtilsIsConstantInstruction(gutils, inst) != 0 -new_from_original(gutils::GradientUtils, val::LLVM.Value) = LLVM.Value(API.EnzymeGradientUtilsNewFromOriginal(gutils, val)) +new_from_original(gutils::GradientUtils, val::LLVM.Value) = + LLVM.Value(API.EnzymeGradientUtilsNewFromOriginal(gutils, val)) -lookup_value(gutils::GradientUtils, val::LLVM.Value, B::LLVM.IRBuilder) = LLVM.Value(API.EnzymeGradientUtilsLookup(gutils, val, B)) +lookup_value(gutils::GradientUtils, val::LLVM.Value, B::LLVM.IRBuilder) = + LLVM.Value(API.EnzymeGradientUtilsLookup(gutils, val, B)) -invert_pointer(gutils::GradientUtils, val::LLVM.Value, B::LLVM.IRBuilder) = LLVM.Value(API.EnzymeGradientUtilsInvertPointer(gutils, val, B)) +invert_pointer(gutils::GradientUtils, val::LLVM.Value, B::LLVM.IRBuilder) = + LLVM.Value(API.EnzymeGradientUtilsInvertPointer(gutils, val, B)) -function debug_from_orig!(gutils::GradientUtils, nval::LLVM.Instruction, oval::LLVM.Instruction) +function debug_from_orig!( + gutils::GradientUtils, + nval::LLVM.Instruction, + oval::LLVM.Instruction, +) API.EnzymeGradientUtilsSetDebugLocFromOriginal(gutils, nval, oval) nothing end diff --git a/src/internal_rules.jl b/src/internal_rules.jl index f29ed0d9776..f8c6e730bb0 100644 --- a/src/internal_rules.jl +++ 
b/src/internal_rules.jl @@ -66,7 +66,12 @@ end function EnzymeRules.inactive(::typeof(Core.kwfunc), args...) return nothing end -function EnzymeRules.inactive(::typeof(Random.rand!), ::Random.AbstractRNG, ::Random.Sampler, ::AbstractArray) +function EnzymeRules.inactive( + ::typeof(Random.rand!), + ::Random.AbstractRNG, + ::Random.Sampler, + ::AbstractArray, +) return nothing end function EnzymeRules.inactive(::typeof(Random.randn!), args...) @@ -96,7 +101,12 @@ end function EnzymeRules.inactive_noinl(::typeof(Base.size), args...) return nothing end -function EnzymeRules.inactive_noinl(::typeof(Base.setindex!), ::IdDict{K, V}, ::K, ::V) where {K, V <:Integer} +function EnzymeRules.inactive_noinl( + ::typeof(Base.setindex!), + ::IdDict{K,V}, + ::K, + ::V, +) where {K,V<:Integer} return nothing end @@ -123,24 +133,49 @@ Enzyme.EnzymeRules.inactive_noinl(::typeof(Core._compute_sparams), args...) = no # Note all of these forward mode definitions do not support runtime activity as # the do not keep the primal if shadow(x.y) == primal(x.y) -function EnzymeRules.forward(config::EnzymeRules.FwdConfig, ::Const{typeof(Base.deepcopy)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(Base.deepcopy)}, + ::Type{<:DuplicatedNoNeed}, + x::Duplicated, +) return deepcopy(x.dval) end -function EnzymeRules.forward(config::EnzymeRules.FwdConfig, ::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicatedNoNeed}, x::BatchDuplicated{T, N}) where {T, N} +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(Base.deepcopy)}, + ::Type{<:BatchDuplicatedNoNeed}, + x::BatchDuplicated{T,N}, +) where {T,N} ntuple(Val(N)) do _ deepcopy(x.dval) end end # Deepcopy preserving the primal if runtime inactive -@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: Union{Integer, Char}} +@inline function deepcopy_rtact( + copied::RT, + primal::RT, + seen::IdDict, + 
shadow::RT, +) where {RT<:Union{Integer,Char}} return Base.deepcopy_internal(shadow, seen) end -@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: AbstractFloat} +@inline function deepcopy_rtact( + copied::RT, + primal::RT, + seen::IdDict, + shadow::RT, +) where {RT<:AbstractFloat} return Base.deepcopy_internal(shadow, seen) end -@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: Array} +@inline function deepcopy_rtact( + copied::RT, + primal::RT, + seen::IdDict, + shadow::RT, +) where {RT<:Array} if !haskey(seen, shadow) if primal === shadow return seen[shadow] = copied @@ -154,19 +189,34 @@ end return seen[shadow] end -function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{typeof(Base.deepcopy)}, ::Type{<:Duplicated}, x::Duplicated) +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + func::Const{typeof(Base.deepcopy)}, + ::Type{<:Duplicated}, + x::Duplicated, +) primal = func.val(x.val) return Duplicated(primal, deepcopy_rtact(primal, x.val, IdDict(), x.dval)) end -function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicated}, x::BatchDuplicated{T, N}) where {T,N} +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + func::Const{typeof(Base.deepcopy)}, + ::Type{<:BatchDuplicated}, + x::BatchDuplicated{T,N}, +) where {T,N} primal = func.val(x.val) return BatchDuplicated(primal, ntuple(Val(N)) do i deepcopy_rtact(primal, x.val, IdDict(), x.dval[i]) end) end -function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, x::Annotation{Ty}) where {RT, Ty} +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfig, + func::Const{typeof(Base.deepcopy)}, + ::Type{RT}, + x::Annotation{Ty}, +) where {RT,Ty} primal = if EnzymeRules.needs_primal(config) func.val(x.val) else @@ -183,14 +233,16 @@ function 
EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const shadow = if EnzymeRules.needs_shadow(config) if EnzymeRules.width(config) == 1 - Enzyme.make_zero(source, - #=copy_if_inactive=#Val(!EnzymeRules.needs_primal(config)) + Enzyme.make_zero( + source, + Val(!EnzymeRules.needs_primal(config)), #=copy_if_inactive=# ) else ntuple(Val(EnzymeRules.width(config))) do _ Base.@_inline_meta - Enzyme.make_zero(source, - #=copy_if_inactive=#Val(!EnzymeRules.needs_primal(config)) + Enzyme.make_zero( + source, + Val(!EnzymeRules.needs_primal(config)), #=copy_if_inactive=# ) end end @@ -202,7 +254,11 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const end -@inline function accumulate_into(into::RT, seen::IdDict, from::RT)::Tuple{RT,RT} where {RT<:Array} +@inline function accumulate_into( + into::RT, + seen::IdDict, + from::RT, +)::Tuple{RT,RT} where {RT<:Array} if Enzyme.Compiler.guaranteed_const(RT) return (into, from) end @@ -217,9 +273,13 @@ end return seen[into] end -@inline function accumulate_into(into::RT, seen::IdDict, from::RT)::Tuple{RT,RT} where {RT<:AbstractFloat} +@inline function accumulate_into( + into::RT, + seen::IdDict, + from::RT, +)::Tuple{RT,RT} where {RT<:AbstractFloat} if !haskey(seen, into) - seen[into] = (into+from, RT(0)) + seen[into] = (into + from, RT(0)) end return seen[into] end @@ -234,12 +294,18 @@ end return seen[into] end -function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, shadow, x::Annotation{Ty}) where {RT, Ty} +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + func::Const{typeof(Base.deepcopy)}, + ::Type{RT}, + shadow, + x::Annotation{Ty}, +) where {RT,Ty} if EnzymeRules.needs_shadow(config) if EnzymeRules.width(config) == 1 accumulate_into(x.dval, IdDict(), shadow) else - for i in 1:EnzymeRules.width(config) + for i = 1:EnzymeRules.width(config) accumulate_into(x.dval[i], IdDict(), shadow[i]) end end @@ -248,43 +314,100 
@@ function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(B return (nothing,) end -@inline function pmap_fwd(idx, tapes::Vector, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} +@inline function pmap_fwd( + idx, + tapes::Vector, + thunk::ThunkTy, + f::F, + fargs::Vararg{Annotation,N}, +) where {ThunkTy,F,N} @inbounds tapes[idx] = thunk(f, Const(idx), fargs...)[1] end -@inline function pmap_fwd(idx, tapes::Ptr, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} +@inline function pmap_fwd( + idx, + tapes::Ptr, + thunk::ThunkTy, + f::F, + fargs::Vararg{Annotation,N}, +) where {ThunkTy,F,N} unsafe_store!(tapes, thunk(f, Const(idx), fargs...)[1], idx) end -function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, body::BodyTy, count, args::Vararg{Annotation, N}) where {BodyTy, N} - - config2 = ReverseModeSplit{false, false, EnzymeRules.runtime_activity(config), EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI, false}() - fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfig, + func::Const{typeof(Enzyme.pmap)}, + ::Type{Const{Nothing}}, + body::BodyTy, + count, + args::Vararg{Annotation,N}, +) where {BodyTy,N} + + config2 = ReverseModeSplit{ + false, + false, + EnzymeRules.runtime_activity(config), + EnzymeRules.width(config), + EnzymeRules.overwritten(config)[2:end], + InlineABI, + false, + }() + fwd_thunk, rev_thunk = + autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) 
TapeType = EnzymeRules.tape_type(fwd_thunk) tapes = if Enzyme.Compiler.any_jltypes(TapeType) Vector{TapeType}(undef, count.val) else - Base.unsafe_convert(Ptr{TapeType}, Libc.malloc(sizeof(TapeType)*count.val)) + Base.unsafe_convert(Ptr{TapeType}, Libc.malloc(sizeof(TapeType) * count.val)) end Enzyme.pmap(pmap_fwd, count.val, tapes, fwd_thunk, body, args...) return EnzymeRules.AugmentedReturn(nothing, nothing, tapes) end -@inline function pmap_rev(idx, tapes::Vector, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} +@inline function pmap_rev( + idx, + tapes::Vector, + thunk::ThunkTy, + f::F, + fargs::Vararg{Annotation,N}, +) where {ThunkTy,F,N} thunk(f, Const(idx), fargs..., @inbounds tapes[idx]) end -@inline function pmap_rev(idx, tapes::Ptr, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} +@inline function pmap_rev( + idx, + tapes::Ptr, + thunk::ThunkTy, + f::F, + fargs::Vararg{Annotation,N}, +) where {ThunkTy,F,N} thunk(f, Const(idx), fargs..., unsafe_load(tapes, idx)) end -function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, tapes, body::BodyTy, count, args::Vararg{Annotation, N}) where {BodyTy, N} - - config2 = ReverseModeSplit{false, false, EnzymeRules.runtime_activity(config), EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI, false}() - fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) 
+function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + func::Const{typeof(Enzyme.pmap)}, + ::Type{Const{Nothing}}, + tapes, + body::BodyTy, + count, + args::Vararg{Annotation,N}, +) where {BodyTy,N} + + config2 = ReverseModeSplit{ + false, + false, + EnzymeRules.runtime_activity(config), + EnzymeRules.width(config), + EnzymeRules.overwritten(config)[2:end], + InlineABI, + false, + }() + fwd_thunk, rev_thunk = + autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) Enzyme.pmap(pmap_rev, count.val, tapes, rev_thunk, body, args...) @@ -294,7 +417,7 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(E Libc.free(tapes) end - return ntuple(Val(2+length(args))) do _ + return ntuple(Val(2 + length(args))) do _ Base.@_inline_meta nothing end @@ -303,7 +426,7 @@ end # From LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta3+0.x64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/generic.jl:1110 -@inline function compute_lu_cache(cache_A::AT, b::BT) where {AT, BT} +@inline function compute_lu_cache(cache_A::AT, b::BT) where {AT,BT} LinearAlgebra.require_one_based_indexing(cache_A, b) m, n = size(cache_A) @@ -323,12 +446,18 @@ end return LinearAlgebra.qr(cache_A, ColumnNorm()) end -@inline onedimensionalize(::Type{T}) where T <: Array = Vector{eltype(T)} +@inline onedimensionalize(::Type{T}) where {T<:Array} = Vector{eltype(T)} # y=inv(A) B # dA −= z y^T # dB += z, where z = inv(A^T) dy -function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const{typeof(\)}, ::Type{RT}, A::Annotation{AT}, b::Annotation{BT}) where {RT, AT <: Array, BT <: Array} +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfig, + func::Const{typeof(\)}, + ::Type{RT}, + A::Annotation{AT}, + b::Annotation{BT}, +) where {RT,AT<:Array,BT<:Array} cache_A = if EnzymeRules.overwritten(config)[2] copy(A.val) @@ -368,30 +497,46 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, 
func::Const end UT = Union{ - LinearAlgebra.Diagonal{eltype(AT), onedimensionalize(BT)}, - LinearAlgebra.LowerTriangular{eltype(AT), AT}, - LinearAlgebra.UpperTriangular{eltype(AT), AT}, - LinearAlgebra.LU{eltype(AT), AT, Vector{Int}}, - LinearAlgebra.QRPivoted{eltype(AT), AT, onedimensionalize(BT), Vector{Int}} + LinearAlgebra.Diagonal{eltype(AT),onedimensionalize(BT)}, + LinearAlgebra.LowerTriangular{eltype(AT),AT}, + LinearAlgebra.UpperTriangular{eltype(AT),AT}, + LinearAlgebra.LU{eltype(AT),AT,Vector{Int}}, + LinearAlgebra.QRPivoted{eltype(AT),AT,onedimensionalize(BT),Vector{Int}}, } - cache = NamedTuple{(Symbol("1"),Symbol("2"), Symbol("3"), Symbol("4")), Tuple{ - eltype(RT), - EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT)}) : Nothing, - UT, - typeof(cache_b) - }}( - (cache_res, dres, cache_A, cache_b) - ) + cache = NamedTuple{ + (Symbol("1"), Symbol("2"), Symbol("3"), Symbol("4")), + Tuple{ + eltype(RT), + EnzymeRules.needs_shadow(config) ? + ( + EnzymeRules.width(config) == 1 ? 
eltype(RT) : + NTuple{EnzymeRules.width(config),eltype(RT)} + ) : Nothing, + UT, + typeof(cache_b), + }, + }((cache_res, dres, cache_A, cache_b)) return EnzymeRules.AugmentedReturn{ EnzymeRules.primal_type(config, RT), EnzymeRules.shadow_type(config, RT), - typeof(cache) - }(retres, dres, cache) + typeof(cache), + }( + retres, + dres, + cache, + ) end -function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(\)}, ::Type{RT}, cache, A::Annotation{<:Array}, b::Annotation{<:Array}) where RT +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + func::Const{typeof(\)}, + ::Type{RT}, + cache, + A::Annotation{<:Array}, + b::Annotation{<:Array}, +) where {RT} y, dys, cache_A, cache_b = cache @@ -448,14 +593,14 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(\ dy .= eltype(dy)(0) end - return (nothing,nothing) + return (nothing, nothing) end const EnzymeTriangulars = Union{ UpperTriangular{<:Complex}, LowerTriangular{<:Complex}, UnitUpperTriangular{<:Complex}, - UnitLowerTriangular{<:Complex} + UnitLowerTriangular{<:Complex}, } function EnzymeRules.augmented_primal( @@ -464,8 +609,8 @@ function EnzymeRules.augmented_primal( ::Type{RT}, Y::Annotation{YT}, A::Annotation{AT}, - B::Annotation{BT} -) where {RT, YT <: Array, AT <: EnzymeTriangulars, BT <: Array} + B::Annotation{BT}, +) where {RT,YT<:Array,AT<:EnzymeTriangulars,BT<:Array} cache_Y = EnzymeRules.overwritten(config)[1] ? copy(Y.val) : Y.val cache_A = EnzymeRules.overwritten(config)[2] ? 
copy(A.val) : A.val cache_A = compute_lu_cache(cache_A, B.val) @@ -476,9 +621,11 @@ function EnzymeRules.augmented_primal( return EnzymeRules.AugmentedReturn{ EnzymeRules.primal_type(config, RT), EnzymeRules.shadow_type(config, RT), - Tuple{typeof(cache_Y), typeof(cache_A), typeof(cache_B)} + Tuple{typeof(cache_Y),typeof(cache_A),typeof(cache_B)}, }( - primal, shadow, (cache_Y, cache_A, cache_B) + primal, + shadow, + (cache_Y, cache_A, cache_B), ) end @@ -489,11 +636,11 @@ function EnzymeRules.reverse( cache, Y::Annotation{YT}, A::Annotation{AT}, - B::Annotation{BT} -) where {YT <: Array, RT, AT <: EnzymeTriangulars, BT <: Array} + B::Annotation{BT}, +) where {YT<:Array,RT,AT<:EnzymeTriangulars,BT<:Array} if !isa(Y, Const) (cache_Yout, cache_A, cache_B) = cache - for b in 1:EnzymeRules.width(config) + for b = 1:EnzymeRules.width(config) dY = EnzymeRules.width(config) == 1 ? Y.dval : Y.dval[b] z = adjoint(cache_A) \ dY if !isa(B, Const) @@ -516,7 +663,13 @@ _zero_unused_elements!(X, ::UnitUpperTriangular) = triu!(X, 1) _zero_unused_elements!(X, ::UnitLowerTriangular) = tril!(X, -1) # Force a rule around hvcat_fill as it is type unstable if the tuple is not of the same type (e.g., int, float, int, float) -function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfig, + func::Const{typeof(Base.hvcat_fill!)}, + ::Type{RT}, + out::Annotation{AT}, + inp::Annotation{BT}, +) where {RT,AT<:Array,BT<:Tuple} primal = if EnzymeRules.needs_primal(config) out.val else @@ -531,9 +684,16 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const return EnzymeRules.AugmentedReturn(primal, shadow, nothing) end -function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, _, out::Annotation{AT}, 
inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} - nr, nc = size(out.val,1), size(out.val,2) - for b in 1:EnzymeRules.width(config) +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + func::Const{typeof(Base.hvcat_fill!)}, + ::Type{RT}, + _, + out::Annotation{AT}, + inp::Annotation{BT}, +) where {RT,AT<:Array,BT<:Tuple} + nr, nc = size(out.val, 1), size(out.val, 2) + for b = 1:EnzymeRules.width(config) da = if EnzymeRules.width(config) == 1 out.dval else @@ -547,7 +707,7 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(B res = da[i, j] da[i, j] = 0 j += 1 - if j == nc+1 + if j == nc + 1 i += 1 j = 1 end @@ -558,18 +718,19 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(B T(0) end end - return (nothing, dinp)::Tuple{Nothing, BT} + return (nothing, dinp)::Tuple{Nothing,BT} end end return (nothing, nothing) end -function EnzymeRules.forward(config::EnzymeRules.FwdConfig, - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, - xs::Duplicated{T}; - kwargs... - ) where {T <: AbstractArray{<:AbstractFloat}} +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, + xs::Duplicated{T}; + kwargs..., +) where {T<:AbstractArray{<:AbstractFloat}} inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] xs.dval .= xs.dval[inds] @@ -584,15 +745,16 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, end end -function EnzymeRules.forward(config::EnzymeRules.FwdConfig, - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}}, - xs::BatchDuplicated{T, N}; - kwargs... 
- ) where {T <: AbstractArray{<:AbstractFloat}, N} +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const,BatchDuplicatedNoNeed,BatchDuplicated}}, + xs::BatchDuplicated{T,N}; + kwargs..., +) where {T<:AbstractArray{<:AbstractFloat},N} inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] - for i in 1:N + for i = 1:N xs.dval[i] .= xs.dval[i][inds] end if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) @@ -608,12 +770,12 @@ end function EnzymeRules.augmented_primal( - config::EnzymeRules.RevConfigWidth{1}, - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, - xs::Duplicated{T}; - kwargs... - ) where {T <: AbstractArray{<:AbstractFloat}} + config::EnzymeRules.RevConfigWidth{1}, + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, + xs::Duplicated{T}; + kwargs..., +) where {T<:AbstractArray{<:AbstractFloat}} inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] xs.dval .= xs.dval[inds] @@ -631,26 +793,27 @@ function EnzymeRules.augmented_primal( end function EnzymeRules.reverse( - config::EnzymeRules.RevConfigWidth{1}, - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, - tape, - xs::Duplicated{T}; - kwargs..., - ) where {T <: AbstractArray{<:AbstractFloat}} + config::EnzymeRules.RevConfigWidth{1}, + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, + tape, + xs::Duplicated{T}; + kwargs..., +) where {T<:AbstractArray{<:AbstractFloat}} inds = tape back_inds = sortperm(inds) xs.dval .= xs.dval[back_inds] return (nothing,) end -function EnzymeRules.forward(config::EnzymeRules.FwdConfig, - ::Const{typeof(partialsort!)}, - RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, - xs::Duplicated{T}, - k::Const{<:Union{Integer, OrdinalRange}}; - kwargs... 
- ) where {T <: AbstractArray{<:AbstractFloat}} +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(partialsort!)}, + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, + xs::Duplicated{T}, + k::Const{<:Union{Integer,OrdinalRange}}; + kwargs..., +) where {T<:AbstractArray{<:AbstractFloat}} kv = k.val inds = collect(eachindex(xs.val)) partialsortperm!(inds, xs.val, kv; kwargs...) @@ -672,18 +835,19 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, end end -function EnzymeRules.forward(config::EnzymeRules.FwdConfig, - ::Const{typeof(partialsort!)}, - RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}}, - xs::BatchDuplicated{T, N}, - k::Const{<:Union{Integer, OrdinalRange}}; - kwargs... - ) where {T <: AbstractArray{<:AbstractFloat}, N} +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(partialsort!)}, + RT::Type{<:Union{Const,BatchDuplicatedNoNeed,BatchDuplicated}}, + xs::BatchDuplicated{T,N}, + k::Const{<:Union{Integer,OrdinalRange}}; + kwargs..., +) where {T<:AbstractArray{<:AbstractFloat},N} kv = k.val inds = collect(eachindex(xs.val)) partialsortperm!(inds, xs.val, kv; kwargs...) xs.val .= xs.val[inds] - for i in 1:N + for i = 1:N xs.dval[i] .= xs.dval[i][inds] end @@ -707,13 +871,13 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, end function EnzymeRules.augmented_primal( - config::EnzymeRules.RevConfigWidth{1}, - ::Const{typeof(partialsort!)}, - RT::Type{<:Union{Const, Active, DuplicatedNoNeed, Duplicated}}, - xs::Duplicated{T}, - k::Const{<:Union{Integer, OrdinalRange}}; - kwargs... 
- ) where {T <: AbstractArray{<:AbstractFloat}} + config::EnzymeRules.RevConfigWidth{1}, + ::Const{typeof(partialsort!)}, + RT::Type{<:Union{Const,Active,DuplicatedNoNeed,Duplicated}}, + xs::Duplicated{T}, + k::Const{<:Union{Integer,OrdinalRange}}; + kwargs..., +) where {T<:AbstractArray{<:AbstractFloat}} kv = k.val inds = collect(eachindex(xs.val)) partialsortperm!(inds, xs.val, kv; kwargs...) @@ -733,14 +897,14 @@ function EnzymeRules.augmented_primal( end function EnzymeRules.reverse( - config::EnzymeRules.RevConfigWidth{1}, - ::Const{typeof(partialsort!)}, - dret::Union{Active, Type{<:Union{Const, Active, DuplicatedNoNeed, Duplicated}}}, - tape, - xs::Duplicated{T}, - k::Const{<:Union{Integer, OrdinalRange}}; - kwargs..., - ) where {T <: AbstractArray{<:AbstractFloat}} + config::EnzymeRules.RevConfigWidth{1}, + ::Const{typeof(partialsort!)}, + dret::Union{Active,Type{<:Union{Const,Active,DuplicatedNoNeed,Duplicated}}}, + tape, + xs::Duplicated{T}, + k::Const{<:Union{Integer,OrdinalRange}}; + kwargs..., +) where {T<:AbstractArray{<:AbstractFloat}} inds = tape kv = k.val if dret isa Active @@ -760,11 +924,14 @@ end # -> # B(out) = inv(A) B(in) # dB(out) = inv(A) [ dB(in) - dA B(out) ] -function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{typeof(ldiv!)}, - RT::Type{<:Union{Const,Duplicated,BatchDuplicated}}, - fact::Annotation{<:Cholesky}, - B::Annotation{<:AbstractVecOrMat}; - kwargs...) +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + func::Const{typeof(ldiv!)}, + RT::Type{<:Union{Const,Duplicated,BatchDuplicated}}, + fact::Annotation{<:Cholesky}, + B::Annotation{<:AbstractVecOrMat}; + kwargs..., +) if B isa Const retval = func.val(fact.val, B.val; kwargs...) 
if EnzymeRules.needs_primal(config) @@ -827,10 +994,16 @@ end # Float64 ranges in Julia use bitwise `&` with higher precision # to correct for numerical error, thus we put rules over the # operations as this is not directly differentiable -function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{Colon}, - RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated, - BatchDuplicated,BatchDuplicatedNoNeed}}, - start::Annotation{<:AbstractFloat}, step::Annotation{<:AbstractFloat}, stop::Annotation{<:AbstractFloat}) +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + func::Const{Colon}, + RT::Type{ + <:Union{Const,DuplicatedNoNeed,Duplicated,BatchDuplicated,BatchDuplicatedNoNeed}, + }, + start::Annotation{<:AbstractFloat}, + step::Annotation{<:AbstractFloat}, + stop::Annotation{<:AbstractFloat}, +) ret = func.val(start.val, step.val, stop.val) dstart = if start isa Const zero(eltype(ret)) @@ -839,7 +1012,9 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{Colon}, elseif start isa BatchDuplicated || start isa BatchDuplicatedNoNeed ntuple(i -> start.dval[i], Val(EnzymeRules.width(config))) else - error("Annotation type $(typeof(start)) not supported for range start. Please open an issue") + error( + "Annotation type $(typeof(start)) not supported for range start. Please open an issue", + ) end dstep = if step isa Const @@ -849,25 +1024,39 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{Colon}, elseif step isa BatchDuplicated || step isa BatchDuplicatedNoNeed ntuple(i -> step.dval[i], Val(EnzymeRules.width(config))) else - error("Annotation type $(typeof(start)) not supported for range step. Please open an issue") + error( + "Annotation type $(typeof(start)) not supported for range step. 
Please open an issue", + ) end if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) if EnzymeRules.width(config) == 1 - return Duplicated(ret, range(dstart; step=dstep, length=length(ret))) + return Duplicated(ret, range(dstart; step = dstep, length = length(ret))) else - return BatchDuplicated(ret, - ntuple(i -> range(dstart isa Number ? dstart : dstart[i]; - step=dstep isa Number ? dstep : dstep[i], - length=length(ret)), Val(EnzymeRules.width(config)))) + return BatchDuplicated( + ret, + ntuple( + i -> range( + dstart isa Number ? dstart : dstart[i]; + step = dstep isa Number ? dstep : dstep[i], + length = length(ret), + ), + Val(EnzymeRules.width(config)), + ), + ) end elseif EnzymeRules.needs_shadow(config) if EnzymeRules.width(config) == 1 - return range(dstart; step=dstep, length=length(ret)) + return range(dstart; step = dstep, length = length(ret)) else - return ntuple(i -> range(dstart isa Number ? dstart : dstart[i]; - step=dstep isa Number ? dstep : dstep[i], - length=length(ret)), Val(EnzymeRules.width(config))) + return ntuple( + i -> range( + dstart isa Number ? dstart : dstart[i]; + step = dstep isa Number ? 
dstep : dstep[i], + length = length(ret), + ), + Val(EnzymeRules.width(config)), + ) end elseif EnzymeRules.needs_primal(config) return ret @@ -878,8 +1067,14 @@ end -function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const{Colon}, ::Type{<:Active}, - start::Annotation{<:AbstractFloat}, step::Annotation{<:AbstractFloat}, stop::Annotation{<:AbstractFloat}) +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfig, + func::Const{Colon}, + ::Type{<:Active}, + start::Annotation{<:AbstractFloat}, + step::Annotation{<:AbstractFloat}, + stop::Annotation{<:AbstractFloat}, +) if EnzymeRules.needs_primal(config) primal = func.val(start.val, step.val, stop.val) @@ -889,8 +1084,15 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const return EnzymeRules.AugmentedReturn(primal, nothing, nothing) end -function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{Colon}, dret, tape::Nothing, - start::Annotation{T1}, step::Annotation{T2}, stop::Annotation{T3}) where {T1<:AbstractFloat, T2<:AbstractFloat, T3<:AbstractFloat} +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + func::Const{Colon}, + dret, + tape::Nothing, + start::Annotation{T1}, + step::Annotation{T2}, + stop::Annotation{T3}, +) where {T1<:AbstractFloat,T2<:AbstractFloat,T3<:AbstractFloat} dstart = if start isa Const nothing @@ -929,11 +1131,12 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{Colon}, end -function EnzymeRules.forward(config::EnzymeRules.FwdConfig, - Ty::Const{Type{BigFloat}}, - RT::Type{<:Union{DuplicatedNoNeed, Duplicated, BatchDuplicated, BatchDuplicatedNoNeed}}; - kwargs... 
- ) +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + Ty::Const{Type{BigFloat}}, + RT::Type{<:Union{DuplicatedNoNeed,Duplicated,BatchDuplicated,BatchDuplicatedNoNeed}}; + kwargs..., +) if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) if EnzymeRules.width(config) == 1 @@ -950,9 +1153,9 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, return Ty.val(; kwargs...) else return ntuple(Val(EnzymeRules.width(config))) do i - Base.@_inline_meta - Ty.val(; kwargs...) - end + Base.@_inline_meta + Ty.val(; kwargs...) + end end elseif EnzymeRules.needs_primal(config) return Ty.val(; kwargs...) @@ -962,11 +1165,11 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, end function EnzymeRules.augmented_primal( - config::EnzymeRules.RevConfig, - Ty::Const{Type{BigFloat}}, - RT::Type{<:Union{DuplicatedNoNeed, Duplicated, BatchDuplicated, BatchDuplicatedNoNeed}}, - kwargs... - ) + config::EnzymeRules.RevConfig, + Ty::Const{Type{BigFloat}}, + RT::Type{<:Union{DuplicatedNoNeed,Duplicated,BatchDuplicated,BatchDuplicatedNoNeed}}, + kwargs..., +) primal = if EnzymeRules.needs_primal(config) Ty.val(; kwargs...) 
else @@ -988,22 +1191,23 @@ function EnzymeRules.augmented_primal( end function EnzymeRules.reverse( - config::EnzymeRules.RevConfig, - Ty::Const{Type{BigFloat}}, - RT::Type{<:Union{DuplicatedNoNeed, Duplicated, BatchDuplicated, BatchDuplicatedNoNeed}}, - tape, - kwargs..., - ) + config::EnzymeRules.RevConfig, + Ty::Const{Type{BigFloat}}, + RT::Type{<:Union{DuplicatedNoNeed,Duplicated,BatchDuplicated,BatchDuplicatedNoNeed}}, + tape, + kwargs..., +) return () end -function EnzymeRules.forward(config::EnzymeRules.FwdConfig, - Ty::Const{typeof(Random.rand!)}, - RT::Type, - rng::Annotation{rngty}, - dst::Annotation{<:Array{FT}}, - smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, - ) where {rngty <: Union{TaskLocalRNG, Xoshiro}, FT <: Union{Float32, Float64}} +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + Ty::Const{typeof(Random.rand!)}, + RT::Type, + rng::Annotation{rngty}, + dst::Annotation{<:Array{FT}}, + smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, +) where {rngty<:Union{TaskLocalRNG,Xoshiro},FT<:Union{Float32,Float64}} Ty.val(rng.val, dst.val, smpl.val) if !(dst isa Const) @@ -1017,7 +1221,7 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, end end end - + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) dst elseif EnzymeRules.needs_shadow(config) @@ -1029,13 +1233,14 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, end end -function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, - Ty::Const{typeof(Random.rand!)}, - RT::Type, - rng::Annotation{rngty}, - dst::Annotation{<:Array{FT}}, - smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, - ) where {rngty <: Union{TaskLocalRNG, Xoshiro}, FT <: Union{Float32, Float64}} +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfig, + Ty::Const{typeof(Random.rand!)}, + RT::Type, + rng::Annotation{rngty}, + dst::Annotation{<:Array{FT}}, + 
smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, +) where {rngty<:Union{TaskLocalRNG,Xoshiro},FT<:Union{Float32,Float64}} Ty.val(rng.val, dst.val, smpl.val) if RT <: Duplicated || RT <: DuplicatedNoNeed fill!(dst.dval, 0) @@ -1047,16 +1252,21 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, nothing end end - return EnzymeRules.AugmentedReturn(EnzymeRules.needs_primal(config) ? dst.val : nothing, EnzymeRules.needs_shadow(config) ? dst.dval : nothing, nothing) + return EnzymeRules.AugmentedReturn( + EnzymeRules.needs_primal(config) ? dst.val : nothing, + EnzymeRules.needs_shadow(config) ? dst.dval : nothing, + nothing, + ) end -function EnzymeRules.reverse(config::EnzymeRules.RevConfig, - Ty::Const{typeof(Random.rand!)}, - RT::Type, - tape, - rng::Annotation{rngty}, - dst::Annotation{<:Array{FT}}, - smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, - ) where {rngty <: Union{TaskLocalRNG, Xoshiro}, FT <: Union{Float32, Float64}} +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + Ty::Const{typeof(Random.rand!)}, + RT::Type, + tape, + rng::Annotation{rngty}, + dst::Annotation{<:Array{FT}}, + smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, +) where {rngty<:Union{TaskLocalRNG,Xoshiro},FT<:Union{Float32,Float64}} return (nothing, nothing, nothing) end diff --git a/src/pmap.jl b/src/pmap.jl index f5160e0b62c..a46d1c777ab 100644 --- a/src/pmap.jl +++ b/src/pmap.jl @@ -5,17 +5,17 @@ function pmap(body::Body, count, args::Vararg{Any,N}) where {Body,N} tasks = Vector{Task}(undef, n_gen) cnt = (count + n_gen - 1) ÷ n_gen for i = 0:(n_gen-1) - let start = i * cnt, endv = min(count, (i+1) * cnt)-1 - t = Task() do - for j in start:endv - body(j+1, args...) 
- end - nothing - end - t.sticky = true - ccall(:jl_set_task_tid, Cint, (Any, Cint), t, i) - @inbounds tasks[i+1] = t - schedule(t) + let start = i * cnt, endv = min(count, (i + 1) * cnt) - 1 + t = Task() do + for j = start:endv + body(j + 1, args...) + end + nothing + end + t.sticky = true + ccall(:jl_set_task_tid, Cint, (Any, Cint), t, i) + @inbounds tasks[i+1] = t + schedule(t) end end try @@ -28,27 +28,27 @@ function pmap(body::Body, count, args::Vararg{Any,N}) where {Body,N} end macro parallel(args...) - captured = args[1:end-1] - ex = args[end] - if !(isa(ex, Expr) && ex.head === :for) - throw(ArgumentError("@parallel requires a `for` loop expression")) - end - if !(ex.args[1] isa Expr && ex.args[1].head === :(=)) + captured = args[1:end-1] + ex = args[end] + if !(isa(ex, Expr) && ex.head === :for) + throw(ArgumentError("@parallel requires a `for` loop expression")) + end + if !(ex.args[1] isa Expr && ex.args[1].head === :(=)) throw(ArgumentError("nested outer loops are not currently supported by @parallel")) - end - iter = ex.args[1] - lidx = iter.args[1] # index - range = iter.args[2] - body = ex.args[2] - esc(quote - let range = $(range) - function bodyf(idx, iter, $(captured...)) - local $(lidx) = @inbounds iter[idx] - $(body) - nothing - end - lenr = length(range) - $pmap(bodyf, lenr, range, $(captured...)) - end - end) + end + iter = ex.args[1] + lidx = iter.args[1] # index + range = iter.args[2] + body = ex.args[2] + esc(quote + let range = $(range) + function bodyf(idx, iter, $(captured...)) + local $(lidx) = @inbounds iter[idx] + $(body) + nothing + end + lenr = length(range) + $pmap(bodyf, lenr, range, $(captured...)) + end + end) end diff --git a/src/rules/activityrules.jl b/src/rules/activityrules.jl index 9e320239571..7a940259fab 100644 --- a/src/rules/activityrules.jl +++ b/src/rules/activityrules.jl @@ -1,10 +1,13 @@ function julia_activity_rule(f::LLVM.Function) + if startswith(LLVM.name(f)) == "japi3" + return + end mi, RT = 
enzyme_custom_extract_mi(f) - llRT, sret, returnRoots = get_return_info(RT) + llRT, sret, returnRoots = get_return_info(RT) retRemoved, parmsRemoved = removed_ret_parms(f) - + dl = string(LLVM.datalayout(LLVM.parent(f))) expectLen = (sret !== nothing) + (returnRoots !== nothing) @@ -12,11 +15,18 @@ function julia_activity_rule(f::LLVM.Function) if isghostty(source_typ) || Core.Compiler.isconstType(source_typ) continue end - expectLen+=1 + expectLen += 1 end expectLen -= length(parmsRemoved) - swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(f, i)))) for i in 1:length(collect(parameters(f)))) + swiftself = any( + any( + map( + k -> kind(k) == kind(EnumAttribute("swiftself")), + collect(parameter_attributes(f, i)), + ), + ) for i = 1:length(collect(parameters(f))) + ) if swiftself expectLen += 1 @@ -31,21 +41,28 @@ function julia_activity_rule(f::LLVM.Function) # TODO fix the attributor inlining such that this can assert always true if expectLen != length(parameters(f)) - msg = sprint() do io::IO - println(io, "Enzyme Internal Error (expectLen != length(parameters(f)))") - println(io, string(f)) - println(io, "expectLen=", string(expectLen)) - println(io, "swiftself=", string(swiftself)) - println(io, "sret=", string(sret)) - println(io, "returnRoots=", string(returnRoots)) - println(io, "mi.specTypes.parameters=", string(mi.specTypes.parameters)) - println(io, "retRemoved=", string(retRemoved)) - println(io, "parmsRemoved=", string(parmsRemoved)) - end - throw(AssertionError(msg)) + msg = sprint() do io::IO + println(io, "Enzyme Internal Error (expectLen != length(parameters(f)))") + println(io, string(f)) + println(io, "expectLen=", string(expectLen)) + println(io, "swiftself=", string(swiftself)) + println(io, "sret=", string(sret)) + println(io, "returnRoots=", string(returnRoots)) + println(io, "mi.specTypes.parameters=", string(mi.specTypes.parameters)) + println(io, "retRemoved=", string(retRemoved)) + 
println(io, "parmsRemoved=", string(parmsRemoved)) + end + throw(AssertionError(msg)) end - jlargs = classify_arguments(mi.specTypes, function_type(f), sret !== nothing, returnRoots !== nothing, swiftself, parmsRemoved) + jlargs = classify_arguments( + mi.specTypes, + function_type(f), + sret !== nothing, + returnRoots !== nothing, + swiftself, + parmsRemoved, + ) if !Enzyme.Compiler.no_type_setting(mi.specTypes; world)[1] for arg in jlargs @@ -54,12 +71,15 @@ function julia_activity_rule(f::LLVM.Function) end op_idx = arg.codegen.i - + typ, _ = enzyme_extract_parm_type(f, arg.codegen.i) @assert typ == arg.typ if guaranteed_const_nongen(arg.typ, world) - push!(parameter_attributes(f, arg.codegen.i), StringAttribute("enzyme_inactive")) + push!( + parameter_attributes(f, arg.codegen.i), + StringAttribute("enzyme_inactive"), + ) end end end @@ -69,13 +89,19 @@ function julia_activity_rule(f::LLVM.Function) idx = 0 if !in(0, parmsRemoved) if guaranteed_const_nongen(RT, world) - push!(parameter_attributes(f, idx+1), StringAttribute("enzyme_inactive")) + push!( + parameter_attributes(f, idx + 1), + StringAttribute("enzyme_inactive"), + ) end - idx+=1 + idx += 1 end if returnRoots !== nothing if !in(idx, parmsRemoved) - push!(parameter_attributes(f, idx+1), StringAttribute("enzyme_inactive")) + push!( + parameter_attributes(f, idx + 1), + StringAttribute("enzyme_inactive"), + ) end end end diff --git a/src/rules/allocrules.jl b/src/rules/allocrules.jl index 8e626d185f0..1c1447dd658 100644 --- a/src/rules/allocrules.jl +++ b/src/rules/allocrules.jl @@ -1,16 +1,27 @@ -function array_inner(::Type{<:Array{T}}) where T +function array_inner(::Type{<:Array{T}}) where {T} return T end -function array_shadow_handler(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, numArgs::Csize_t, Args::Ptr{LLVM.API.LLVMValueRef}, gutils::API.EnzymeGradientUtilsRef)::LLVM.API.LLVMValueRef +function array_shadow_handler( + B::LLVM.API.LLVMBuilderRef, + OrigCI::LLVM.API.LLVMValueRef, + 
numArgs::Csize_t, + Args::Ptr{LLVM.API.LLVMValueRef}, + gutils::API.EnzymeGradientUtilsRef, +)::LLVM.API.LLVMValueRef inst = LLVM.Instruction(OrigCI) mod = LLVM.parent(LLVM.parent(LLVM.parent(inst))) ctx = LLVM.context(LLVM.Value(OrigCI)) gutils = GradientUtils(gutils) - legal, typ = abs_typeof(inst) + legal, typ, byref = abs_typeof(inst) if !legal - throw(AssertionError("Could not statically ahead-of-time determine allocation element type of "*string(inst))) + throw( + AssertionError( + "Could not statically ahead-of-time determine allocation element type of " * + string(inst), + ), + ) end typ = eltype(typ) @@ -25,7 +36,7 @@ function array_shadow_handler(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMV push!(valTys, API.VT_Primal) end - anti = call_samefunc_with_inverted_bundles!(b, gutils, orig, vals, valTys, #=lookup=#false) + anti = call_samefunc_with_inverted_bundles!(b, gutils, orig, vals, valTys, false) #=lookup=# prod = get_array_len(b, anti) @@ -33,11 +44,11 @@ function array_shadow_handler(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMV isunion = typ isa Union - LLT_ALIGN(x, sz) = (((x) + (sz)-1) & ~((sz)-1)) + LLT_ALIGN(x, sz) = (((x) + (sz) - 1) & ~((sz) - 1)) if !isunboxed elsz = sizeof(Ptr{Cvoid}) - al = elsz; + al = elsz else elsz = LLT_ALIGN(elsz, al) end @@ -63,7 +74,11 @@ function array_shadow_handler(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMV return ref end -function null_free_handler(B::LLVM.API.LLVMBuilderRef, ToFree::LLVM.API.LLVMValueRef, Fn::LLVM.API.LLVMValueRef)::LLVM.API.LLVMValueRef +function null_free_handler( + B::LLVM.API.LLVMBuilderRef, + ToFree::LLVM.API.LLVMValueRef, + Fn::LLVM.API.LLVMValueRef, +)::LLVM.API.LLVMValueRef return C_NULL end @@ -76,22 +91,78 @@ end @inline function register_alloc_rules() register_alloc_handler!( ("jl_alloc_array_1d", "ijl_alloc_array_1d"), - @cfunction(array_shadow_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Csize_t, Ptr{LLVM.API.LLVMValueRef}, 
API.EnzymeGradientUtilsRef)), - @cfunction(null_free_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef)) + @cfunction( + array_shadow_handler, + LLVM.API.LLVMValueRef, + ( + LLVM.API.LLVMBuilderRef, + LLVM.API.LLVMValueRef, + Csize_t, + Ptr{LLVM.API.LLVMValueRef}, + API.EnzymeGradientUtilsRef, + ) + ), + @cfunction( + null_free_handler, + LLVM.API.LLVMValueRef, + (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef) + ) ) register_alloc_handler!( ("jl_alloc_array_2d", "ijl_alloc_array_2d"), - @cfunction(array_shadow_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Csize_t, Ptr{LLVM.API.LLVMValueRef}, API.EnzymeGradientUtilsRef)), - @cfunction(null_free_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef)) + @cfunction( + array_shadow_handler, + LLVM.API.LLVMValueRef, + ( + LLVM.API.LLVMBuilderRef, + LLVM.API.LLVMValueRef, + Csize_t, + Ptr{LLVM.API.LLVMValueRef}, + API.EnzymeGradientUtilsRef, + ) + ), + @cfunction( + null_free_handler, + LLVM.API.LLVMValueRef, + (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef) + ) ) register_alloc_handler!( ("jl_alloc_array_3d", "ijl_alloc_array_3d"), - @cfunction(array_shadow_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Csize_t, Ptr{LLVM.API.LLVMValueRef}, API.EnzymeGradientUtilsRef)), - @cfunction(null_free_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef)) + @cfunction( + array_shadow_handler, + LLVM.API.LLVMValueRef, + ( + LLVM.API.LLVMBuilderRef, + LLVM.API.LLVMValueRef, + Csize_t, + Ptr{LLVM.API.LLVMValueRef}, + API.EnzymeGradientUtilsRef, + ) + ), + @cfunction( + null_free_handler, + LLVM.API.LLVMValueRef, + (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef) + ) ) register_alloc_handler!( ("jl_new_array", "ijl_new_array"), - 
@cfunction(array_shadow_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Csize_t, Ptr{LLVM.API.LLVMValueRef}, API.EnzymeGradientUtilsRef)), - @cfunction(null_free_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef)) + @cfunction( + array_shadow_handler, + LLVM.API.LLVMValueRef, + ( + LLVM.API.LLVMBuilderRef, + LLVM.API.LLVMValueRef, + Csize_t, + Ptr{LLVM.API.LLVMValueRef}, + API.EnzymeGradientUtilsRef, + ) + ), + @cfunction( + null_free_handler, + LLVM.API.LLVMValueRef, + (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef) + ) ) -end \ No newline at end of file +end diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 08cd15facbc..1985283da3e 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -1,5 +1,13 @@ -function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, mi, @nospecialize(RT), reverse::Bool, isKWCall::Bool) +function enzyme_custom_setup_args( + B, + orig::LLVM.CallInst, + gutils::GradientUtils, + mi, + @nospecialize(RT), + reverse::Bool, + isKWCall::Bool, +) ops = collect(operands(orig)) called = ops[end] ops = ops[1:end-1] @@ -12,10 +20,10 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, actives = LLVM.Value[] - mixeds = Tuple{LLVM.Value, Type, LLVM.Value}[] + mixeds = Tuple{LLVM.Value,Type,LLVM.Value}[] uncacheable = get_uncacheable(gutils, orig) mode = get_mode(gutils) - + retRemoved, parmsRemoved = removed_ret_parms(orig) @assert length(parmsRemoved) == 0 @@ -25,8 +33,22 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, returnRoots = returnRoots !== nothing cv = LLVM.called_operand(orig) - swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(cv, i)))) for i in 1:length(collect(parameters(cv)))) - jlargs = classify_arguments(mi.specTypes, called_type(orig), sret, returnRoots, 
swiftself, parmsRemoved) + swiftself = any( + any( + map( + k -> kind(k) == kind(EnumAttribute("swiftself")), + collect(parameter_attributes(cv, i)), + ), + ) for i = 1:length(collect(parameters(cv))) + ) + jlargs = classify_arguments( + mi.specTypes, + called_type(orig), + sret, + returnRoots, + swiftself, + parmsRemoved, + ) alloctx = LLVM.IRBuilder() position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) @@ -49,23 +71,33 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, push!(overwritten, false) end if B !== nothing - if Core.Compiler.isconstType(arg.typ) && !Core.Compiler.isconstType(Const{arg.typ}) - llty = convert(LLVMType, Const{arg.typ}) - al0 = al = emit_allocobj!(B, Const{arg.typ}) - al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) - al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived)) - - ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)]) - val = unsafe_to_llvm(B, arg.typ.parameters[1]) - store!(B, val, ptr) + if Core.Compiler.isconstType(arg.typ) && + !Core.Compiler.isconstType(Const{arg.typ}) + llty = convert(LLVMType, Const{arg.typ}) + al0 = al = emit_allocobj!(B, Const{arg.typ}) + al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) + al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived)) + + ptr = inbounds_gep!( + B, + llty, + al, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), 0), + ], + ) + val = unsafe_to_llvm(B, arg.typ.parameters[1]) + store!(B, val, ptr) - if any_jltypes(llty) - emit_writebarrier!(B, get_julia_inner_types(B, al0, val)) + if any_jltypes(llty) + emit_writebarrier!(B, get_julia_inner_types(B, al0, val)) + end + push!(args, al) + else + @assert isghostty(Const{arg.typ}) || + Core.Compiler.isconstType(Const{arg.typ}) end - push!(args, al) - else - @assert isghostty(Const{arg.typ}) || 
Core.Compiler.isconstType(Const{arg.typ}) - end end continue end @@ -82,7 +114,7 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, val = lookup_value(gutils, val, B) end - activep = API.EnzymeGradientUtilsGetDiffeType(gutils, op, #=isforeign=#false) + activep = API.EnzymeGradientUtilsGetDiffeType(gutils, op, false) #=isforeign=# if isKWCall && arg.arg_i == 2 Ty = arg.typ @@ -103,13 +135,21 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, if activep == API.DFT_CONSTANT Ty = Const{arg.typ} llty = convert(LLVMType, Ty) - arty = convert(LLVMType, arg.typ; allow_boxed=true) + arty = convert(LLVMType, arg.typ; allow_boxed = true) if B !== nothing al0 = al = emit_allocobj!(B, Ty) al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived)) - ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)]) + ptr = inbounds_gep!( + B, + llty, + al, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), 0), + ], + ) if value_type(val) != eltype(value_type(ptr)) val = load!(B, arty, val) end @@ -124,30 +164,45 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, push!(activity, Ty) - elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg_inner(arg.typ, (), world) == ActiveState) + elseif activep == API.DFT_OUT_DIFF || ( + mode != API.DEM_ForwardMode && + active_reg_inner(arg.typ, (), world) == ActiveState + ) Ty = Active{arg.typ} llty = convert(LLVMType, Ty) - arty = convert(LLVMType, arg.typ; allow_boxed=true) + arty = convert(LLVMType, arg.typ; allow_boxed = true) if B !== nothing al0 = al = emit_allocobj!(B, Ty) al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived)) - ptr = inbounds_gep!(B, llty, al, 
[LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)]) + ptr = inbounds_gep!( + B, + llty, + al, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), 0), + ], + ) if value_type(val) != eltype(value_type(ptr)) if overwritten[end] emit_error( B, orig, - "Enzyme: active by ref type $Ty is overwritten in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr)). " - * "As a workaround until support for this is added, try passing values as separate arguments rather than as an aggregate of type $Ty.", + "Enzyme: active by ref type $Ty is overwritten in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr)). " * + "As a workaround until support for this is added, try passing values as separate arguments rather than as an aggregate of type $Ty.", ) end if arty == eltype(value_type(val)) val = load!(B, arty, val) else val = LLVM.UndefValue(arty) - emit_error(B, orig, "Enzyme: active by ref type $Ty is wrong type in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))") + emit_error( + B, + orig, + "Enzyme: active by ref type $Ty is wrong type in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))", + ) end end @@ -157,7 +212,11 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, emit_writebarrier!(B, get_julia_inner_types(B, al0, val)) end else - emit_error(B, orig, "Enzyme: active by ref type $Ty is wrong store type in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))") + emit_error( + B, + orig, + "Enzyme: active by ref type $Ty is wrong store type in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))", + ) end push!(args, al) @@ -193,21 +252,21 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, if active_reg_inner(arg.typ, (), world) == MixedState # TODO batchmixedupnoneed shadowty = Base.RefValue{shadowty} - Ty = 
BatchMixedDuplicated{arg.typ, Int(width)} + Ty = BatchMixedDuplicated{arg.typ,Int(width)} mixed = true else if activep == API.DFT_DUP_ARG - Ty = BatchDuplicated{arg.typ, Int(width)} + Ty = BatchDuplicated{arg.typ,Int(width)} else @assert activep == API.DFT_DUP_NONEED - Ty = BatchDuplicatedNoNeed{arg.typ, Int(width)} + Ty = BatchDuplicatedNoNeed{arg.typ,Int(width)} end end end llty = convert(LLVMType, Ty) - arty = convert(LLVMType, arg.typ; allow_boxed=true) - iarty = convert(LLVMType, shadowty; allow_boxed=true) + arty = convert(LLVMType, arg.typ; allow_boxed = true) + iarty = convert(LLVMType, shadowty; allow_boxed = true) sarty = LLVM.LLVMType(API.EnzymeGetShadowType(width, arty)) siarty = LLVM.LLVMType(API.EnzymeGetShadowType(width, iarty)) if B !== nothing @@ -215,42 +274,63 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived)) - ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)]) + ptr = inbounds_gep!( + B, + llty, + al, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), 0), + ], + ) needsload = false if value_type(val) != eltype(value_type(ptr)) val = load!(B, arty, val) if !mixed ptr_val = ival ival = UndefValue(siarty) - for idx in 1:width - ev = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx-1) + for idx = 1:width + ev = + (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx - 1) ld = load!(B, iarty, ev) - ival = (width == 1 ) ? ld : insert_value!(B, ival, ld, idx-1) + ival = (width == 1) ? 
ld : insert_value!(B, ival, ld, idx - 1) end end needsload = true end store!(B, val, ptr) - iptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 1)]) - + iptr = inbounds_gep!( + B, + llty, + al, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), 1), + ], + ) + if mixed RefTy = arg.typ if width != 1 - RefTy = NTuple{Int(width), RefTy} + RefTy = NTuple{Int(width),RefTy} end llrty = convert(LLVMType, RefTy) RefTy = Base.RefValue{RefTy} refal0 = refal = emit_allocobj!(B, RefTy) - refal = bitcast!(B, refal, LLVM.PointerType(llrty, addrspace(value_type(refal)))) + refal = bitcast!( + B, + refal, + LLVM.PointerType(llrty, addrspace(value_type(refal))), + ) @assert needsload ptr_val = ival ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, llrty))) - for idx in 1:width - ev = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx-1) + for idx = 1:width + ev = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx - 1) ld = load!(B, llrty, ev) - ival = (width == 1 ) ? ld : insert_value!(B, ival, ld, idx-1) + ival = (width == 1) ? ld : insert_value!(B, ival, ld, idx - 1) end store!(B, ival, refal) emit_writebarrier!(B, get_julia_inner_types(B, refal0, ival)) @@ -273,10 +353,16 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, return args, activity, (overwritten...,), actives, kwtup, mixeds end -function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, @nospecialize(RealRt), B) +function enzyme_custom_setup_ret( + gutils::GradientUtils, + orig::LLVM.CallInst, + mi, + @nospecialize(RealRt), + B, +) width = get_width(gutils) mode = get_mode(gutils) - + world = enzyme_extract_world(LLVM.parent(LLVM.parent(orig))) needsShadowP = Ref{UInt8}(0) @@ -286,28 +372,41 @@ function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, # calls differential use analysis to determine needsprimal/shadow. 
However, since now this function # is used as part of differential use analysis, we need to avoid an ininite recursion. Thus use # the version without differential use if actual unreachable results are not available anyways. - uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1) + uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig))) - 1) cmode = mode if cmode == API.DEM_ReverseModeGradient cmode = API.DEM_ReverseModePrimal end - activep = if mode == API.DEM_ForwardMode || API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) == 1 - API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, cmode) - else - actv = API.EnzymeGradientUtilsGetDiffeType(gutils, orig, false) - if !isghostty(RealRt) - needsPrimalP[] = 1 - if actv == API.DFT_DUP_ARG || actv == API.DFT_DUP_NONEED - needsShadowP[] = 1 + activep = + if mode == API.DEM_ForwardMode || + API.EnzymeGradientUtilsGetUncacheableArgs( + gutils, + orig, + uncacheable, + length(uncacheable), + ) == 1 + API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + cmode, + ) + else + actv = API.EnzymeGradientUtilsGetDiffeType(gutils, orig, false) + if !isghostty(RealRt) + needsPrimalP[] = 1 + if actv == API.DFT_DUP_ARG || actv == API.DFT_DUP_NONEED + needsShadowP[] = 1 + end end + actv end - actv - end needsPrimal = needsPrimalP[] != 0 origNeedsPrimal = needsPrimal _, sret, _ = get_return_info(RealRt) if sret !== nothing - activep = API.EnzymeGradientUtilsGetDiffeType(gutils, operands(orig)[1], #=isforeign=#false) + activep = API.EnzymeGradientUtilsGetDiffeType(gutils, operands(orig)[1], false) #=isforeign=# needsPrimal = activep == API.DFT_DUP_ARG || activep == API.DFT_CONSTANT needsShadowP[] = activep == API.DFT_DUP_ARG || activep == API.DFT_DUP_NONEED end @@ -315,13 +414,20 @@ function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, if !needsPrimal && activep == 
API.DFT_DUP_ARG activep = API.DFT_DUP_NONEED end - + if activep == API.DFT_CONSTANT RT = Const{RealRt} - elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg_inner(RealRt, (), world, #=justActive=#Val(true)) == ActiveState) - if active_reg_inner(RealRt, (), world, #=justActive=#Val(false)) == MixedState && B !== nothing - emit_error(B, orig, "Enzyme: Return type $RealRt has mixed internal activity types in evaluation of custom rule for $mi. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information") + elseif activep == API.DFT_OUT_DIFF || ( + mode != API.DEM_ForwardMode && + active_reg_inner(RealRt, (), world, Val(true)) == ActiveState + ) #=justActive=# + if active_reg_inner(RealRt, (), world, Val(false)) == MixedState && B !== nothing #=justActive=# + emit_error( + B, + orig, + "Enzyme: Return type $RealRt has mixed internal activity types in evaluation of custom rule for $mi. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information", + ) end RT = Active{RealRt} @@ -329,20 +435,20 @@ function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, if width == 1 RT = Duplicated{RealRt} else - RT = BatchDuplicated{RealRt, Int(width)} + RT = BatchDuplicated{RealRt,Int(width)} end else @assert activep == API.DFT_DUP_NONEED if width == 1 RT = DuplicatedNoNeed{RealRt} else - RT = BatchDuplicatedNoNeed{RealRt, Int(width)} + RT = BatchDuplicatedNoNeed{RealRt,Int(width)} end end return RT, needsPrimal, needsShadowP[] != 0, origNeedsPrimal end -function custom_rule_method_error(world, fn, args...) +function custom_rule_method_error(world, fn, args...) 
throw(MethodError(fn, (args...,), world)) end @@ -354,7 +460,10 @@ end width = get_width(gutils) if shadowR != C_NULL - unsafe_store!(shadowR,UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))).ref) + unsafe_store!( + shadowR, + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))).ref, + ) end # TODO: don't inject the code multiple times for multiple calls @@ -370,10 +479,17 @@ end end # 2) Create activity, and annotate function spec - args, activity, overwritten, actives, kwtup, _ = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#false, isKWCall) - RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt, B) - - C = EnzymeRules.FwdConfig{Bool(needsPrimal), Bool(needsShadow), Int(width), get_runtime_activity(gutils)} + args, activity, overwritten, actives, kwtup, _ = + enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, false, isKWCall) #=reverse=# + RT, needsPrimal, needsShadow, origNeedsPrimal = + enzyme_custom_setup_ret(gutils, orig, mi, RealRt, B) + + C = EnzymeRules.FwdConfig{ + Bool(needsPrimal), + Bool(needsShadow), + Int(width), + get_runtime_activity(gutils), + } alloctx = LLVM.IRBuilder() position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) @@ -413,7 +529,7 @@ end llvmf = nested_codegen!(mode, mod, kwfunc, TT, world) fwd_RT = Core.Compiler.return_type(kwfunc, TT, world) else - TT = Tuple{typeof(world), typeof(kwfunc), TT.parameters...} + TT = Tuple{typeof(world),typeof(kwfunc),TT.parameters...} llvmf = nested_codegen!(mode, mod, custom_rule_method_error, TT, world) pushfirst!(args, LLVM.ConstantInt(world)) fwd_RT = Union{} @@ -424,16 +540,23 @@ end llvmf = nested_codegen!(mode, mod, EnzymeRules.forward, TT, world) fwd_RT = Core.Compiler.return_type(EnzymeRules.forward, TT, world) else - TT = Tuple{typeof(world), typeof(EnzymeRules.forward), TT.parameters...} + TT = 
Tuple{typeof(world),typeof(EnzymeRules.forward),TT.parameters...} llvmf = nested_codegen!(mode, mod, custom_rule_method_error, TT, world) pushfirst!(args, LLVM.ConstantInt(world)) fwd_RT = Union{} end end - + push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) - swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(llvmf, i)))) for i in 1:length(collect(parameters(llvmf)))) + swiftself = any( + any( + map( + k -> kind(k) == kind(EnumAttribute("swiftself")), + collect(parameter_attributes(llvmf, i)), + ), + ) for i = 1:length(collect(parameters(llvmf))) + ) if swiftself pushfirst!(reinsert_gcmarker!(fn, B)) end @@ -452,7 +575,16 @@ end end if length(args) != length(parameters(llvmf)) - GPUCompiler.@safe_error "Calling convention mismatch", args, llvmf, string(value_type(llvmf)), orig, isKWCall, kwtup, TT, sret, returnRoots + GPUCompiler.@safe_error "Calling convention mismatch", + args, + llvmf, + string(value_type(llvmf)), + orig, + isKWCall, + kwtup, + TT, + sret, + returnRoots return false end @@ -471,7 +603,12 @@ end debug_from_orig!(gutils, res, orig) callconv!(res, callconv(llvmf)) - hasNoRet = any(map(k->kind(k)==kind(EnumAttribute("noreturn")), collect(function_attributes(llvmf)))) + hasNoRet = any( + map( + k -> kind(k) == kind(EnumAttribute("noreturn")), + collect(function_attributes(llvmf)), + ), + ) if hasNoRet return false @@ -488,7 +625,11 @@ end end if swiftself attr = EnumAttribute("swiftself") - LLVM.API.LLVMAddCallSiteAttribute(res, LLVM.API.LLVMAttributeIndex(1+(sret !== nothing)), attr) + LLVM.API.LLVMAddCallSiteAttribute( + res, + LLVM.API.LLVMAttributeIndex(1 + (sret !== nothing)), + attr, + ) end shadowV = C_NULL @@ -497,7 +638,18 @@ end if RT <: Const if needsPrimal if RealRt != fwd_RT - emit_error(B, orig, "Enzyme: incorrect return type of const primal-only forward custom rule - $C "*(string(RT))*" "*string(activity)*" want just return type "*string(RealRt)*" found 
"*string(fwd_RT)) + emit_error( + B, + orig, + "Enzyme: incorrect return type of const primal-only forward custom rule - $C " * + (string(RT)) * + " " * + string(activity) * + " want just return type " * + string(RealRt) * + " found " * + string(fwd_RT), + ) return false end if get_return_info(RealRt)[2] !== nothing @@ -508,7 +660,16 @@ end end else if Nothing != fwd_RT - emit_error(B, orig, "Enzyme: incorrect return type of const no-primal forward custom rule - $C "*(string(RT))*" "*string(activity)*" want just return type Nothing found "*string(fwd_RT)) + emit_error( + B, + orig, + "Enzyme: incorrect return type of const no-primal forward custom rule - $C " * + (string(RT)) * + " " * + string(activity) * + " want just return type Nothing found " * + string(fwd_RT), + ) return false end end @@ -516,17 +677,28 @@ end if !needsPrimal ST = RealRt if width != 1 - ST = NTuple{Int(width), ST} + ST = NTuple{Int(width),ST} end if ST != fwd_RT - emit_error(B, orig, "Enzyme: incorrect return type of shadow-only forward custom rule - $C "*(string(RT))*" "*string(activity)*" want just shadow type "*string(ST)*" found "*string(fwd_RT)) + emit_error( + B, + orig, + "Enzyme: incorrect return type of shadow-only forward custom rule - $C " * + (string(RT)) * + " " * + string(activity) * + " want just shadow type " * + string(ST) * + " found " * + string(fwd_RT), + ) return false end if get_return_info(RealRt)[2] !== nothing dval_ptr = invert_pointer(gutils, operands(orig)[1], B) - for idx in 1:width - ev = (width == 1) ? dval : extract_value!(B, dval, idx-1) - pev = (width == 1) ? dval_ptr : extract_value!(B, dval_ptr, idx-1) + for idx = 1:width + ev = (width == 1) ? dval : extract_value!(B, dval, idx - 1) + pev = (width == 1) ? 
dval_ptr : extract_value!(B, dval_ptr, idx - 1) store!(B, res, pev) end else @@ -536,21 +708,32 @@ end ST = if width == 1 Duplicated{RealRt} else - BatchDuplicated{RealRt, Int(width)} + BatchDuplicated{RealRt,Int(width)} end if ST != fwd_RT - emit_error(B, orig, "Enzyme: incorrect return type of prima/shadow forward custom rule - $C "*(string(RT))*" "*string(activity)*" want just shadow type "*string(ST)*" found "*string(fwd_RT)) + emit_error( + B, + orig, + "Enzyme: incorrect return type of prima/shadow forward custom rule - $C " * + (string(RT)) * + " " * + string(activity) * + " want just shadow type " * + string(ST) * + " found " * + string(fwd_RT), + ) return false end if get_return_info(RealRt)[2] !== nothing val = new_from_original(gutils, operands(orig)[1]) store!(B, extract_value!(B, res, 0), val) - + dval_ptr = invert_pointer(gutils, operands(orig)[1], B) dval = extract_value!(B, res, 1) - for idx in 1:width - ev = (width == 1) ? dval : extract_value!(B, dval, idx-1) - pev = (width == 1) ? dval_ptr : extract_value!(B, dval_ptr, idx-1) + for idx = 1:width + ev = (width == 1) ? dval : extract_value!(B, dval, idx - 1) + pev = (width == 1) ? 
dval_ptr : extract_value!(B, dval_ptr, idx - 1) store!(B, ev, pev) end else @@ -570,7 +753,11 @@ end else ni = new_from_original(gutils, orig) if value_type(ni) != LLVM.VoidType() - API.EnzymeGradientUtilsReplaceAWithB(gutils, ni, LLVM.UndefValue(value_type(ni))) + API.EnzymeGradientUtilsReplaceAWithB( + gutils, + ni, + LLVM.UndefValue(value_type(ni)), + ) end API.EnzymeGradientUtilsErase(gutils, ni) end @@ -578,7 +765,12 @@ end return false end -@inline function aug_fwd_mi(orig::LLVM.CallInst, gutils::GradientUtils, forward=false, B=nothing) +@inline function aug_fwd_mi( + orig::LLVM.CallInst, + gutils::GradientUtils, + forward = false, + B = nothing, +) width = get_width(gutils) # 1) extract out the MI from attributes @@ -586,8 +778,10 @@ end isKWCall = isKWCallSignature(mi.specTypes) # 2) Create activity, and annotate function spec - args, activity, overwritten, actives, kwtup, mixeds = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#!forward, isKWCall) - RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt, B) + args, activity, overwritten, actives, kwtup, mixeds = + enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, !forward, isKWCall) #=reverse=# + RT, needsPrimal, needsShadow, origNeedsPrimal = + enzyme_custom_setup_ret(gutils, orig, mi, RealRt, B) needsShadowJL = if RT <: Active false @@ -598,8 +792,14 @@ end fn = LLVM.parent(LLVM.parent(orig)) world = enzyme_extract_world(fn) - C = EnzymeRules.RevConfig{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten, get_runtime_activity(gutils)} - + C = EnzymeRules.RevConfig{ + Bool(needsPrimal), + Bool(needsShadowJL), + Int(width), + overwritten, + get_runtime_activity(gutils), + } + mode = get_mode(gutils) ami = nothing @@ -617,10 +817,14 @@ end kwfunc = Core.kwfunc(EnzymeRules.augmented_primal) try ami = GPUCompiler.methodinstance(Core.Typeof(kwfunc), augprimal_TT, world) - @safe_debug "Applying custom augmented_primal rule (kwcall)" 
TT=augprimal_TT + @safe_debug "Applying custom augmented_primal rule (kwcall)" TT = augprimal_TT catch e - augprimal_TT = Tuple{typeof(world), typeof(kwfunc), augprimal_TT.parameters...} - ami = GPUCompiler.methodinstance(typeof(custom_rule_method_error), augprimal_TT, world) + augprimal_TT = Tuple{typeof(world),typeof(kwfunc),augprimal_TT.parameters...} + ami = GPUCompiler.methodinstance( + typeof(custom_rule_method_error), + augprimal_TT, + world, + ) if forward pushfirst!(args, LLVM.ConstantInt(world)) end @@ -632,24 +836,57 @@ end augprimal_TT = Tuple{augprimal_tt...} try - ami = GPUCompiler.methodinstance(Core.Typeof(EnzymeRules.augmented_primal), augprimal_TT, world) - @safe_debug "Applying custom augmented_primal rule" TT=augprimal_TT + ami = GPUCompiler.methodinstance( + Core.Typeof(EnzymeRules.augmented_primal), + augprimal_TT, + world, + ) + @safe_debug "Applying custom augmented_primal rule" TT = augprimal_TT catch e - augprimal_TT = Tuple{typeof(world), typeof(EnzymeRules.augmented_primal), augprimal_TT.parameters...} - ami = GPUCompiler.methodinstance(typeof(custom_rule_method_error), augprimal_TT, world) + augprimal_TT = Tuple{ + typeof(world), + typeof(EnzymeRules.augmented_primal), + augprimal_TT.parameters..., + } + ami = GPUCompiler.methodinstance( + typeof(custom_rule_method_error), + augprimal_TT, + world, + ) if forward pushfirst!(args, LLVM.ConstantInt(world)) end end end - return ami, augprimal_TT, (args, activity, overwritten, actives, kwtup, RT, needsPrimal, needsShadow, origNeedsPrimal, mixeds) + return ami, + augprimal_TT, + ( + args, + activity, + overwritten, + actives, + kwtup, + RT, + needsPrimal, + needsShadow, + origNeedsPrimal, + mixeds, + ) end @inline function has_aug_fwd_rule(orig, gutils) return aug_fwd_mi(orig, gutils)[1] !== nothing end -@register_rev function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, normalR, shadowR, tape)::LLVM.API.LLVMValueRef +@register_rev function enzyme_custom_common_rev( 
+ forward::Bool, + B, + orig::LLVM.CallInst, + gutils, + normalR, + shadowR, + tape, +)::LLVM.API.LLVMValueRef ctx = LLVM.context(orig) @@ -657,7 +894,7 @@ end shadowType = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) if shadowR != C_NULL - unsafe_store!(shadowR,UndefValue(shadowType).ref) + unsafe_store!(shadowR, UndefValue(shadowType).ref) end # TODO: don't inject the code multiple times for multiple calls @@ -668,7 +905,16 @@ end # 2) Create activity, and annotate function spec ami, augprimal_TT, setup = aug_fwd_mi(orig, gutils, forward, B) - args, activity, overwritten, actives, kwtup, RT, needsPrimal, needsShadow, origNeedsPrimal, mixeds = setup + args, + activity, + overwritten, + actives, + kwtup, + RT, + needsPrimal, + needsShadow, + origNeedsPrimal, + mixeds = setup needsShadowJL = if RT <: Active false @@ -676,7 +922,13 @@ end needsShadow end - C = EnzymeRules.RevConfig{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten, get_runtime_activity(gutils)} + C = EnzymeRules.RevConfig{ + Bool(needsPrimal), + Bool(needsShadowJL), + Int(width), + overwritten, + get_runtime_activity(gutils), + } alloctx = LLVM.IRBuilder() position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) @@ -690,10 +942,24 @@ end @assert ami !== nothing target = DefaultCompilerTarget() params = PrimalCompilerParams(mode) - aug_RT = something(Core.Compiler.typeinf_type(GPUCompiler.get_interpreter(CompilerJob(ami, CompilerConfig(target, params; kernel=false), world)), ami.def, ami.specTypes, ami.sparam_vals), Any) + aug_RT = something( + Core.Compiler.typeinf_type( + GPUCompiler.get_interpreter( + CompilerJob(ami, CompilerConfig(target, params; kernel = false), world), + ), + ami.def, + ami.specTypes, + ami.sparam_vals, + ), + Any, + ) if kwtup !== nothing && kwtup <: Duplicated @safe_debug "Non-constant keyword argument found for " augprimal_TT - emit_error(B, orig, "Enzyme: Non-constant keyword argument found for " * 
string(augprimal_TT)) + emit_error( + B, + orig, + "Enzyme: Non-constant keyword argument found for " * string(augprimal_TT), + ) return C_NULL end @@ -702,15 +968,26 @@ end TapeT = Nothing - if (aug_RT <: EnzymeRules.AugmentedReturn || aug_RT <: EnzymeRules.AugmentedReturnFlexShadow) && !(aug_RT isa UnionAll) && !(aug_RT isa Union) && !(aug_RT === Union{}) + if ( + aug_RT <: EnzymeRules.AugmentedReturn || + aug_RT <: EnzymeRules.AugmentedReturnFlexShadow + ) && + !(aug_RT isa UnionAll) && + !(aug_RT isa Union) && + !(aug_RT === Union{}) TapeT = EnzymeRules.tape_type(aug_RT) - elseif (aug_RT isa UnionAll) && (aug_RT <: EnzymeRules.AugmentedReturn) && aug_RT.body.name == EnzymeCore.EnzymeRules.AugmentedReturn.body.body.body.name + elseif (aug_RT isa UnionAll) && + (aug_RT <: EnzymeRules.AugmentedReturn) && + aug_RT.body.name == EnzymeCore.EnzymeRules.AugmentedReturn.body.body.body.name if aug_RT.body.parameters[3] isa TypeVar TapeT = aug_RT.body.parameters[3].ub else TapeT = Any end - elseif (aug_RT isa UnionAll) && (aug_RT <: EnzymeRules.AugmentedReturnFlexShadow) && aug_RT.body.name == EnzymeCore.EnzymeRules.AugmentedReturnFlexShadow.body.body.body.name + elseif (aug_RT isa UnionAll) && + (aug_RT <: EnzymeRules.AugmentedReturnFlexShadow) && + aug_RT.body.name == + EnzymeCore.EnzymeRules.AugmentedReturnFlexShadow.body.body.body.name if aug_RT.body.parameters[3] isa TypeVar TapeT = aug_RT.body.parameters[3].ub else @@ -749,11 +1026,11 @@ end if isKWCall rkwfunc = Core.kwfunc(EnzymeRules.reverse) if EnzymeRules.isapplicable(rkwfunc, rev_TT; world) - @safe_debug "Applying custom reverse rule (kwcall)" TT=rev_TT + @safe_debug "Applying custom reverse rule (kwcall)" TT = rev_TT llvmf = nested_codegen!(mode, mod, rkwfunc, rev_TT, world) rev_RT = Core.Compiler.return_type(rkwfunc, rev_TT, world) else - rev_TT = Tuple{typeof(world), typeof(rkwfunc), rev_TT.parameters...} + rev_TT = Tuple{typeof(world),typeof(rkwfunc),rev_TT.parameters...} llvmf = nested_codegen!(mode, mod, 
custom_rule_method_error, rev_TT, world) pushfirst!(args, LLVM.ConstantInt(world)) rev_RT = Union{} @@ -761,11 +1038,12 @@ end end else if EnzymeRules.isapplicable(EnzymeRules.reverse, rev_TT; world) - @safe_debug "Applying custom reverse rule" TT=rev_TT + @safe_debug "Applying custom reverse rule" TT = rev_TT llvmf = nested_codegen!(mode, mod, EnzymeRules.reverse, rev_TT, world) rev_RT = Core.Compiler.return_type(EnzymeRules.reverse, rev_TT, world) else - rev_TT = Tuple{typeof(world), typeof(EnzymeRules.reverse), rev_TT.parameters...} + rev_TT = + Tuple{typeof(world),typeof(EnzymeRules.reverse),rev_TT.parameters...} llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world) pushfirst!(args, LLVM.ConstantInt(world)) rev_RT = Union{} @@ -780,7 +1058,7 @@ end tapeV = C_NULL if forward && needsTape - tapeV = LLVM.UndefValue(convert(LLVMType, TapeT; allow_boxed=true)).ref + tapeV = LLVM.UndefValue(convert(LLVMType, TapeT; allow_boxed = true)).ref end # if !forward @@ -796,28 +1074,59 @@ end # llvmf = nested_codegen!(mode, mod, rev_func, Tuple{argTys...}, world) # end - swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(llvmf, i)))) for i in 1:length(collect(parameters(llvmf)))) + swiftself = any( + any( + map( + k -> kind(k) == kind(EnumAttribute("swiftself")), + collect(parameter_attributes(llvmf, i)), + ), + ) for i = 1:length(collect(parameters(llvmf))) + ) miRT = enzyme_custom_extract_mi(llvmf)[2] _, sret, returnRoots = get_return_info(miRT) sret_union = is_sret_union(miRT) - if sret_union - emit_error(B, orig, "Enzyme: Augmented forward pass custom rule " * string(augprimal_TT) * " had a union sret of type "*string(miRT)*" which is not currently supported") + if sret_union + emit_error( + B, + orig, + "Enzyme: Augmented forward pass custom rule " * + string(augprimal_TT) * + " had a union sret of type " * + string(miRT) * + " which is not currently supported", + ) return tapeV end if !forward - 
funcTy = rev_TT.parameters[isKWCall ? 4 : 2] + funcTy = rev_TT.parameters[isKWCall ? 4 : 2] if needsTape @assert tape != C_NULL - tape_idx = 1+(kwtup!==nothing && !isghostty(kwtup)) + !isghostty(funcTy) + (!applicablefn) - trueidx = tape_idx+(sret !== nothing)+(returnRoots !== nothing)+swiftself + (RT <: Active) + tape_idx = + 1 + + (kwtup !== nothing && !isghostty(kwtup)) + + !isghostty(funcTy) + + (!applicablefn) + trueidx = + tape_idx + + (sret !== nothing) + + (returnRoots !== nothing) + + swiftself + + (RT <: Active) innerTy = value_type(parameters(llvmf)[trueidx]) if innerTy != value_type(tape) - if isabstracttype(TapeT) || TapeT isa UnionAll || TapeT == Tuple || TapeT.layout == C_NULL || TapeT == Array + if isabstracttype(TapeT) || + TapeT isa UnionAll || + TapeT == Tuple || + TapeT.layout == C_NULL || + TapeT == Array msg = sprint() do io - println(io, "Enzyme : mismatch between innerTy $innerTy and tape type $(value_type(tape))") + println( + io, + "Enzyme : mismatch between innerTy $innerTy and tape type $(value_type(tape))", + ) println(io, "tape_idx=", tape_idx) println(io, "true_idx=", trueidx) println(io, "isKWCall=", isKWCall) @@ -840,7 +1149,7 @@ end end throw(AssertionError(msg)) end - llty = convert(LLVMType, TapeT; allow_boxed=true) + llty = convert(LLVMType, TapeT; allow_boxed = true) al0 = al = emit_allocobj!(B, TapeT) al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) store!(B, tape, al) @@ -855,18 +1164,18 @@ end llty = convert(LLVMType, RT) - if API.EnzymeGradientUtilsGetDiffeType(gutils, orig, #=isforeign=#false) == API.DFT_OUT_DIFF + if API.EnzymeGradientUtilsGetDiffeType(gutils, orig, false) == API.DFT_OUT_DIFF #=isforeign=# val = LLVM.Value(API.EnzymeGradientUtilsDiffe(gutils, orig, B)) API.EnzymeGradientUtilsSetDiffe(gutils, orig, LLVM.null(value_type(val)), B) else - llety = convert(LLVMType, eltype(RT); allow_boxed=true) - ptr_val = invert_pointer(gutils, operands(orig)[1 + !isghostty(funcTy)], B) + llety = 
convert(LLVMType, eltype(RT); allow_boxed = true) + ptr_val = invert_pointer(gutils, operands(orig)[1+!isghostty(funcTy)], B) val = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, llety))) - for idx in 1:width - ev = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx-1) + for idx = 1:width + ev = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx - 1) ld = load!(B, llety, ev) store!(B, LLVM.null(llety), ev) - val = (width == 1 ) ? ld : insert_value!(B, val, ld, idx-1) + val = (width == 1) ? ld : insert_value!(B, val, ld, idx - 1) end end @@ -874,13 +1183,28 @@ end al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived)) - ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)]) + ptr = inbounds_gep!( + B, + llty, + al, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), 0), + ], + ) store!(B, val, ptr) if any_jltypes(llty) emit_writebarrier!(B, get_julia_inner_types(B, al0, val)) end - insert!(args, 1+(!isghostty(funcTy))+(kwtup!==nothing && !isghostty(kwtup)) + (!applicablefn), al) + insert!( + args, + 1 + + (!isghostty(funcTy)) + + (kwtup !== nothing && !isghostty(kwtup)) + + (!applicablefn), + al, + ) end end @@ -902,16 +1226,26 @@ end end if length(args) != length(parameters(llvmf)) - GPUCompiler.@safe_error "Calling convention mismatch", args, llvmf, orig, isKWCall, kwtup, augprimal_TT, rev_TT, fn, sret, returnRoots + GPUCompiler.@safe_error "Calling convention mismatch", + args, + llvmf, + orig, + isKWCall, + kwtup, + augprimal_TT, + rev_TT, + fn, + sret, + returnRoots return tapeV end - + T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - for i in 1:length(args) - party = value_type(parameters(llvmf)[i]) + for i = 1:length(args) + party = value_type(parameters(llvmf)[i]) if value_type(args[i]) != party if party == T_prjlvalue while true @@ 
-939,7 +1273,15 @@ end println(io, "args[i] = ", args[i]) println(io, "party = ", party) end - args[i] = calling_conv_fixup(B, args[i], party, LLVM.UndefValue(party), Cuint[], Cuint[], msg) + args[i] = calling_conv_fixup( + B, + args[i], + party, + LLVM.UndefValue(party), + Cuint[], + Cuint[], + msg, + ) end res = LLVM.call!(B, LLVM.function_type(llvmf), llvmf, args) @@ -947,7 +1289,12 @@ end debug_from_orig!(gutils, res, orig) callconv!(res, callconv(llvmf)) - hasNoRet = any(map(k->kind(k)==kind(EnumAttribute("noreturn")), collect(function_attributes(llvmf)))) + hasNoRet = any( + map( + k -> kind(k) == kind(EnumAttribute("noreturn")), + collect(function_attributes(llvmf)), + ), + ) if hasNoRet return tapeV @@ -959,13 +1306,21 @@ end else attr = EnumAttribute("sret") end - LLVM.API.LLVMAddCallSiteAttribute(res, LLVM.API.LLVMAttributeIndex(1+swiftself), attr) + LLVM.API.LLVMAddCallSiteAttribute( + res, + LLVM.API.LLVMAttributeIndex(1 + swiftself), + attr, + ) res = load!(B, eltype(value_type(parameters(llvmf)[1+swiftself])), sret) API.SetMustCache!(res) end if swiftself attr = EnumAttribute("swiftself") - LLVM.API.LLVMAddCallSiteAttribute(res, LLVM.API.LLVMAttributeIndex(1+(sret !== nothing)+(returnRoots !== nothing)), attr) + LLVM.API.LLVMAddCallSiteAttribute( + res, + LLVM.API.LLVMAttributeIndex(1 + (sret !== nothing) + (returnRoots !== nothing)), + attr, + ) end shadowV = C_NULL @@ -975,35 +1330,86 @@ end if forward ShadT = RealRt if width != 1 - ShadT = NTuple{Int(width), RealRt} + ShadT = NTuple{Int(width),RealRt} end - ST = EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadowJL ? ShadT : Nothing, TapeT} + ST = EnzymeRules.AugmentedReturn{ + needsPrimal ? RealRt : Nothing, + needsShadowJL ? 
ShadT : Nothing, + TapeT, + } if aug_RT != ST if aug_RT <: EnzymeRules.AugmentedReturnFlexShadow - if convert(LLVMType, EnzymeRules.shadow_type(aug_RT); allow_boxed=true) != - convert(LLVMType, EnzymeRules.shadow_type(ST) ; allow_boxed=true) - emit_error(B, orig, "Enzyme: Augmented forward pass custom rule " * string(augprimal_TT) * " flex shadow ABI return type mismatch, expected "*string(ST)*" found "* string(aug_RT)) + if convert(LLVMType, EnzymeRules.shadow_type(aug_RT); allow_boxed = true) != + convert(LLVMType, EnzymeRules.shadow_type(ST); allow_boxed = true) + emit_error( + B, + orig, + "Enzyme: Augmented forward pass custom rule " * + string(augprimal_TT) * + " flex shadow ABI return type mismatch, expected " * + string(ST) * + " found " * + string(aug_RT), + ) return tapeV end - ST = EnzymeRules.AugmentedReturnFlexShadow{needsPrimal ? RealRt : Nothing, needsShadowJL ? EnzymeRules.shadow_type(aug_RT) : Nothing, TapeT} + ST = EnzymeRules.AugmentedReturnFlexShadow{ + needsPrimal ? RealRt : Nothing, + needsShadowJL ? EnzymeRules.shadow_type(aug_RT) : Nothing, + TapeT, + } end end abstract = false if aug_RT != ST - abs = (EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadowJL ? ShadT : Nothing, T} where T) + abs = ( + EnzymeRules.AugmentedReturn{ + needsPrimal ? RealRt : Nothing, + needsShadowJL ? ShadT : Nothing, + T, + } where {T} + ) if aug_RT <: abs abstract = true else - ST = EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadowJL ? ShadT : Nothing, Any} - emit_error(B, orig, "Enzyme: Augmented forward pass custom rule " * string(augprimal_TT) * " return type mismatch, expected "*string(ST)*" found "* string(aug_RT)) + ST = EnzymeRules.AugmentedReturn{ + needsPrimal ? RealRt : Nothing, + needsShadowJL ? 
ShadT : Nothing, + Any, + } + emit_error( + B, + orig, + "Enzyme: Augmented forward pass custom rule " * + string(augprimal_TT) * + " return type mismatch, expected " * + string(ST) * + " found " * + string(aug_RT), + ) return tapeV end end resV = if abstract - StructTy = convert(LLVMType, EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadowJL ? ShadT : Nothing, Nothing}) + StructTy = convert( + LLVMType, + EnzymeRules.AugmentedReturn{ + needsPrimal ? RealRt : Nothing, + needsShadowJL ? ShadT : Nothing, + Nothing, + }, + ) if StructTy != LLVM.VoidType() - load!(B, StructTy, bitcast!(B, res, LLVM.PointerType(StructTy, addrspace(value_type(res))))) + load!( + B, + StructTy, + bitcast!( + B, + res, + LLVM.PointerType(StructTy, addrspace(value_type(res))), + ), + ) else res end @@ -1022,7 +1428,7 @@ end @assert value_type(normalV) == value_type(orig) normalV = normalV.ref end - idx+=1 + idx += 1 end if needsShadow if needsShadowJL @@ -1031,10 +1437,11 @@ end if get_return_info(RealRt)[2] !== nothing dval = invert_pointer(gutils, operands(orig)[1], B) - for idx in 1:width - to_store = (width == 1) ? shadowV : extract_value!(B, shadowV, idx-1) + for idx = 1:width + to_store = + (width == 1) ? shadowV : extract_value!(B, shadowV, idx - 1) - store_ptr = (width == 1) ? dval : extract_value!(B, dval, idx-1) + store_ptr = (width == 1) ? dval : extract_value!(B, dval, idx - 1) store!(B, to_store, store_ptr) end @@ -1043,7 +1450,7 @@ end @assert value_type(shadowV) == shadowType shadowV = shadowV.ref end - idx+=1 + idx += 1 end end if needsTape @@ -1052,23 +1459,37 @@ end else extract_value!(B, res, idx).ref end - idx+=1 + idx += 1 end else - Tys = (A <: Active ? (width == 1 ? eltype(A) : NTuple{Int(width), eltype(A)}) : Nothing for A in activity[2+isKWCall:end]) + Tys = ( + A <: Active ? (width == 1 ? 
eltype(A) : NTuple{Int(width),eltype(A)}) : Nothing for A in activity[2+isKWCall:end] + ) ST = Tuple{Tys...} if rev_RT != ST - emit_error(B, orig, "Enzyme: Reverse pass custom rule " * string(rev_TT) * " return type mismatch, expected "*string(ST)*" found "* string(rev_RT)) + emit_error( + B, + orig, + "Enzyme: Reverse pass custom rule " * + string(rev_TT) * + " return type mismatch, expected " * + string(ST) * + " found " * + string(rev_RT), + ) return tapeV end - if length(actives) >= 1 && !isa(value_type(res), LLVM.StructType) && !isa(value_type(res), LLVM.ArrayType) - GPUCompiler.@safe_error "Shadow arg calling convention mismatch found return ", res + if length(actives) >= 1 && + !isa(value_type(res), LLVM.StructType) && + !isa(value_type(res), LLVM.ArrayType) + GPUCompiler.@safe_error "Shadow arg calling convention mismatch found return ", + res return tapeV end idx = 0 dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(orig))))) - Tys2 = (eltype(A) for A in activity[(2 + isKWCall):end] if A <: Active) + Tys2 = (eltype(A) for A in activity[(2+isKWCall):end] if A <: Active) seen = TypeTreeTable() for (v, Ty) in zip(actives, Tys2) TT = typetree(Ty, ctx, dl, seen) @@ -1079,24 +1500,35 @@ end size = sizeof(Ty) align = 0 premask = C_NULL - API.EnzymeGradientUtilsAddToInvertedPointerDiffeTT(gutils, orig, C_NULL, TT, size, v, ext, B, align, premask) + API.EnzymeGradientUtilsAddToInvertedPointerDiffeTT( + gutils, + orig, + C_NULL, + TT, + size, + v, + ext, + B, + align, + premask, + ) else @assert value_type(ext) == shadowVType API.EnzymeGradientUtilsAddToDiffe(gutils, v, ext, B, Typ) end - idx+=1 + idx += 1 end for (ptr_val, argTyp, refal) in mixeds RefTy = argTyp if width != 1 - RefTy = NTuple{Int(width), RefTy} + RefTy = NTuple{Int(width),RefTy} end curs = load!(B, convert(LLVMType, RefTy), refal) - for idx in 1:width - evp = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx-1) - evcur = (width == 1) ? 
curs : extract_value!(B, curs, idx-1) + for idx = 1:width + evp = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx - 1) + evcur = (width == 1) ? curs : extract_value!(B, curs, idx - 1) store_nonjl_types!(B, evcur, evp) end end @@ -1121,10 +1553,12 @@ end @register_aug function enzyme_custom_augfwd(B, orig, gutils, normalR, shadowR, tapeR) - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && !has_aug_fwd_rule(orig, gutils) + if is_constant_value(gutils, orig) && + is_constant_inst(gutils, orig) && + !has_aug_fwd_rule(orig, gutils) return true end - tape = enzyme_custom_common_rev(#=forward=#true, B, orig, gutils, normalR, shadowR, #=tape=#nothing) + tape = enzyme_custom_common_rev(true, B, orig, gutils, normalR, shadowR, nothing) #=tape=# if tape != C_NULL unsafe_store!(tapeR, tape) end @@ -1132,34 +1566,41 @@ end end @register_rev function enzyme_custom_rev(B, orig, gutils, tape) - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && !has_aug_fwd_rule(orig, gutils) + if is_constant_value(gutils, orig) && + is_constant_inst(gutils, orig) && + !has_aug_fwd_rule(orig, gutils) return end - enzyme_custom_common_rev(#=forward=#false, B, orig, gutils, #=normalR=#C_NULL, #=shadowR=#C_NULL, #=tape=#tape) + enzyme_custom_common_rev(false, B, orig, gutils, C_NULL, C_NULL, tape) #=tape=# return nothing end @register_diffuse function enzyme_custom_diffuse(orig, gutils, val, isshadow, mode) # use default - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && !has_aug_fwd_rule(orig, gutils) + if is_constant_value(gutils, orig) && + is_constant_inst(gutils, orig) && + !has_aug_fwd_rule(orig, gutils) return (false, true) end non_rooting_use = false fop = called_operand(orig)::LLVM.Function for (i, v) in enumerate(operands(orig)[1:end-1]) - if v == val - if !any(a->kind(a) == kind(StringAttribute("enzymejl_returnRoots")), collect(parameter_attributes(fop, i))) - non_rooting_use = true - break - end - end + if v == val 
+ if !any( + a -> kind(a) == kind(StringAttribute("enzymejl_returnRoots")), + collect(parameter_attributes(fop, i)), + ) + non_rooting_use = true + break + end + end end - + # If the operand is just rooting, we don't need it and should override defaults if !non_rooting_use - return (false, false) + return (false, false) end - + # don't use default and always require the arg return (true, false) end diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 01edec7118d..75bc4156544 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -1,4 +1,13 @@ -function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, iterate=false; func=true, mixed_or_active = false, reverse=false) +function setup_macro_wraps( + forwardMode::Bool, + N::Int, + Width::Int, + base = nothing, + iterate = false; + func = true, + mixed_or_active = false, + reverse = false, +) primargs = Union{Symbol,Expr}[] shadowargs = Union{Symbol,Expr}[] batchshadowargs = Vector{Union{Symbol,Expr}}[] @@ -8,7 +17,7 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, dfns = Union{Symbol,Expr}[:df] base_idx = 1 if func - for w in 2:Width + for w = 2:Width if base === nothing shad = Symbol("df_$w") t = Symbol("DF__$w*") @@ -22,7 +31,7 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, push!(dfns, shad) end end - for i in 1:N + for i = 1:N if base === nothing prim = Symbol("primal_$i") t = Symbol("PT_$i") @@ -37,7 +46,7 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, push!(primargs, prim) push!(primtypes, t) shadows = Union{Symbol,Expr}[] - for w in 1:Width + for w = 1:Width if base === nothing shad = Symbol("shadow_$(i)_$w") t = Symbol("ST_$(i)_$w") @@ -62,7 +71,7 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, wrapped = Expr[] modbetween = Expr[:(MB[1])] active_refs = Expr[] - for i in 1:N + for i = 1:N if iterate push!(modbetween, quote 
ntuple(Val(length($(primargs[i])))) do _ @@ -73,7 +82,7 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, end aref = Symbol("active_ref_$i") push!(active_refs, quote - $aref = active_reg_nothrow($(primtypes[i]), Val(nothing)); + $aref = active_reg_nothrow($(primtypes[i]), Val(nothing)) end) expr = if iterate if forwardMode @@ -83,34 +92,60 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, end else quote - iterate_unwrap_fwd_batchdup(Val($Width), $(primargs[i]), $(shadowargs[i])) + iterate_unwrap_fwd_batchdup( + Val($Width), + $(primargs[i]), + $(shadowargs[i]), + ) end end :( - if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) - @assert $(primtypes[i]) !== DataType - $dupexpr - else - map(Const, $(primargs[i])) - end + if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) + @assert $(primtypes[i]) !== DataType + $dupexpr + else + map(Const, $(primargs[i])) + end ) else mixexpr = if Width == 1 quote - iterate_unwrap_augfwd_mix(Val($reverse), refs, $(primargs[i]), $(shadowargs[i])) + iterate_unwrap_augfwd_mix( + Val($reverse), + refs, + $(primargs[i]), + $(shadowargs[i]), + ) end else quote - iterate_unwrap_augfwd_batchmix(Val($reverse), refs, Val($Width), $(primargs[i]), $(shadowargs[i])) + iterate_unwrap_augfwd_batchmix( + Val($reverse), + refs, + Val($Width), + $(primargs[i]), + $(shadowargs[i]), + ) end end dupexpr = if Width == 1 quote - iterate_unwrap_augfwd_dup(Val($reverse), refs, $(primargs[i]), $(shadowargs[i])) + iterate_unwrap_augfwd_dup( + Val($reverse), + refs, + $(primargs[i]), + $(shadowargs[i]), + ) end else quote - iterate_unwrap_augfwd_batchdup(Val($reverse), refs, Val($Width), $(primargs[i]), $(shadowargs[i])) + iterate_unwrap_augfwd_batchdup( + Val($reverse), + refs, + Val($Width), + $(primargs[i]), + $(shadowargs[i]), + ) end end :( @@ -132,7 +167,10 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, if forwardMode quote if 
ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) - $((Width == 1) ? :Duplicated : :BatchDuplicated)($(primargs[i]), $(shadowargs[i])) + $((Width == 1) ? :Duplicated : :BatchDuplicated)( + $(primargs[i]), + $(shadowargs[i]), + ) else Const($(primargs[i])) end @@ -144,9 +182,15 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, if $aref == ActiveState Active($(primargs[i])) elseif $aref == MixedState - $((Width == 1) ? :MixedDuplicated : :BatchMixedDuplicated)($(primargs[i]), $(shadowargs[i])) + $((Width == 1) ? :MixedDuplicated : :BatchMixedDuplicated)( + $(primargs[i]), + $(shadowargs[i]), + ) else - $((Width == 1) ? :Duplicated : :BatchDuplicated)($(primargs[i]), $(shadowargs[i])) + $((Width == 1) ? :Duplicated : :BatchDuplicated)( + $(primargs[i]), + $(shadowargs[i]), + ) end else Const($(primargs[i])) @@ -157,8 +201,10 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, push!(wrapped, expr) end - any_mixed = quote false end - for i in 1:N + any_mixed = quote + false + end + for i = 1:N aref = Symbol("active_ref_$i") if mixed_or_active any_mixed = :($any_mixed || $aref == MixedState || $aref == ActiveState) @@ -169,19 +215,27 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, push!(active_refs, quote any_mixed = $any_mixed end) - return primargs, shadowargs, primtypes, allargs, typeargs, wrapped, batchshadowargs, modbetween, active_refs + return primargs, + shadowargs, + primtypes, + allargs, + typeargs, + wrapped, + batchshadowargs, + modbetween, + active_refs end function body_runtime_generic_fwd(N, Width, wrapped, primtypes) - nnothing = Vector{Nothing}(undef, Width+1) - nres = Vector{Expr}(undef, Width+1) + nnothing = Vector{Nothing}(undef, Width + 1) + nres = Vector{Expr}(undef, Width + 1) fill!(nnothing, nothing) fill!(nres, :(res[1])) - ModifiedBetween = Vector{Bool}(undef, N+1) + ModifiedBetween = Vector{Bool}(undef, N + 1) fill!(ModifiedBetween, false) 
ElTypes = Vector{Expr}(undef, N) Types = Vector{Expr}(undef, N) - for i in 1:N + for i = 1:N @inbounds ElTypes[i] = :(eltype(Core.Typeof(args[$i]))) @inbounds Types[i] = :(Core.Typeof(args[$i])) end @@ -195,7 +249,7 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) :(Duplicated(f, df)) else fargs = [:df] - for i in 2:Width + for i = 2:Width push!(fargs, Symbol("df_$i")) end :(BatchDuplicated(f, ($(fargs...),))) @@ -203,7 +257,7 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) dupty = if Width == 1 :(Duplicated{FT}) else - :(BatchDuplicated{FT, $Width}) + :(BatchDuplicated{FT,$Width}) end return quote @@ -221,7 +275,7 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) end if $Width != 1 if annotation <: Duplicated - annotation = BatchDuplicated{rt, $Width} + annotation = BatchDuplicated{rt,$Width} end end @@ -233,7 +287,20 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) world = codegen_world_age(FT, tt) opt_mi = Val(world) - forward = thunk(opt_mi, dupClosure ? $dupty : Const{FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val(($(ModifiedBetween...),)), #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) + forward = thunk( + opt_mi, + dupClosure ? $dupty : Const{FT}, + annotation, + tt′, + Val(API.DEM_ForwardMode), + width, + Val(($(ModifiedBetween...),)), + Val(true), + Val(false), + FFIABI, + Val(false), + runtimeActivity, + ) #=erriffuncwritten=# res = forward(dupClosure ? $dup : Const(f), args...) 
@@ -253,44 +320,69 @@ function func_runtime_generic_fwd(N, Width) body = body_runtime_generic_fwd(N, Width, wrapped, primtypes) quote - function runtime_generic_fwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{$Width}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, RuntimeActivity, ReturnType, F, DF, $(typeargs...)} + function runtime_generic_fwd( + activity::Type{Val{ActivityTup}}, + runtimeActivity::Val{RuntimeActivity}, + width::Val{$Width}, + RT::Val{ReturnType}, + f::F, + df::DF, + $(allargs...), + ) where {ActivityTup,RuntimeActivity,ReturnType,F,DF,$(typeargs...)} $body end end end -@generated function runtime_generic_fwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, RuntimeActivity, Width, ReturnType, F, DF} - N = div(length(allargs)+2, Width+1)-1 +@generated function runtime_generic_fwd( + activity::Type{Val{ActivityTup}}, + runtimeActivity::Val{RuntimeActivity}, + width::Val{Width}, + RT::Val{ReturnType}, + f::F, + df::DF, + allargs..., +) where {ActivityTup,RuntimeActivity,Width,ReturnType,F,DF} + N = div(length(allargs) + 2, Width + 1) - 1 _, _, primtypes, _, _, wrapped, _, _, _ = setup_macro_wraps(true, N, Width, :allargs) return body_runtime_generic_fwd(N, Width, wrapped, primtypes) end function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) - nres = Vector{Symbol}(undef, Width+1) + nres = Vector{Symbol}(undef, Width + 1) fill!(nres, :origRet) nzeros = Vector{Expr}(undef, Width) fill!(nzeros, :(Ref(make_zero(origRet)))) - + ElTypes = Vector{Expr}(undef, N) MakeTypes = Vector{Expr}(undef, N) Types = Vector{Symbol}(undef, N) MixedTypes = Vector{Expr}(undef, N) - for i in 1:N + for i = 1:N @inbounds ElTypes[i] = :(eltype($(Symbol("type_$i")))) @inbounds MakeTypes[i] = :($(Symbol("type_$i")) = Core.Typeof(args[$i])) @inbounds Types[i] = Symbol("type_$i") - 
@inbounds MixedTypes[i] = :($(Symbol("active_ref_$i") == MixedState) ? Ref($(Symbol("type_$i"))) : $(Symbol("type_$i"))) + @inbounds MixedTypes[i] = :( + $(Symbol("active_ref_$i") == MixedState) ? Ref($(Symbol("type_$i"))) : + $(Symbol("type_$i")) + ) end ending = if Width == 1 quote if annotation <: MixedDuplicated shadow_return = initShadow - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( + internal_tape, + shadow_return, + ) return ReturnType((origRet, shadow_return, tape)) else shadow_return = nothing - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( + internal_tape, + shadow_return, + ) return ReturnType((origRet, initShadow, tape)) end end @@ -298,33 +390,39 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) quote if annotation <: BatchMixedDuplicated shadow_return = (initShadow...,) - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( + internal_tape, + shadow_return, + ) return ReturnType((origRet, initShadow..., tape)) else shadow_return = nothing - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( + internal_tape, + shadow_return, + ) return ReturnType((origRet, initShadow..., tape)) end end end - + shadowretinit = if Width == 1 :(Ref(make_zero(origRet))) else :(($(nzeros...),)) end - + shadowretret = if Width == 1 :(return ReturnType((origRet, shadow_return, tape))) else :(return ReturnType((origRet, shadow_return..., tape))) end - + dup = if Width == 1 :(Duplicated(f, df)) else fargs = [:df] - for i in 2:Width + for i = 2:Width push!(fargs, Symbol("df_$i")) end :(BatchDuplicated(f, 
($(fargs...),))) @@ -332,14 +430,14 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) dupty = if Width == 1 :(Duplicated{FT}) else - :(BatchDuplicated{FT, $Width}) + :(BatchDuplicated{FT,$Width}) end return quote $(active_refs...) args = ($(wrapped...),) $(MakeTypes...) - + FT = Core.Typeof(f) dupClosure0 = if ActivityTup[1] !guaranteed_const(FT) @@ -352,18 +450,29 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) annotationA = if $Width != 1 && annotation0 <: Duplicated - BatchDuplicated{rt, $Width} + BatchDuplicated{rt,$Width} elseif $Width != 1 && annotation0 <: MixedDuplicated - BatchMixedDuplicated{rt, $Width} + BatchMixedDuplicated{rt,$Width} else annotation0 end world = codegen_world_age(FT, tt) opt_mi = Val(world) - forward, adjoint = thunk(opt_mi, dupClosure0 ? $dupty : Const{FT}, - annotationA, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) + forward, adjoint = thunk( + opt_mi, + dupClosure0 ? $dupty : Const{FT}, + annotationA, + Tuple{$(Types...)}, + Val(API.DEM_ReverseModePrimal), + width, + ModifiedBetween, + Val(true), + Val(false), + FFIABI, + Val(false), + runtimeActivity, + ) #=erriffuncwritten=# internal_tape, origRet, initShadow = forward(dupClosure0 ? $dup : Const(f), args...) 
annotation = annotationA @@ -371,11 +480,17 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) resT = typeof(origRet) if annotation <: Const shadow_return = nothing - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( + internal_tape, + shadow_return, + ) return ReturnType(($(nres...), tape)) elseif annotation <: Active shadow_return = $shadowretinit - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( + internal_tape, + shadow_return, + ) $shadowretret end @@ -384,31 +499,51 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) end function func_runtime_generic_augfwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _, _, active_refs = setup_macro_wraps(false, N, Width) + _, _, primtypes, allargs, typeargs, wrapped, _, _, active_refs = + setup_macro_wraps(false, N, Width) body = body_runtime_generic_augfwd(N, Width, wrapped, primtypes, active_refs) quote - function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, RuntimeActivity, F, DF, $(typeargs...)} + function runtime_generic_augfwd( + activity::Type{Val{ActivityTup}}, + runtimeActivity::Val{RuntimeActivity}, + width::Val{$Width}, + ModifiedBetween::Val{MB}, + RT::Val{ReturnType}, + f::F, + df::DF, + $(allargs...), + )::ReturnType where {ActivityTup,MB,ReturnType,RuntimeActivity,F,DF,$(typeargs...)} $body end end end -@generated function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...)::ReturnType where 
{ActivityTup, MB, RuntimeActivity, Width, ReturnType, F, DF} - N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _, _, active_refs = setup_macro_wraps(false, N, Width, :allargs) +@generated function runtime_generic_augfwd( + activity::Type{Val{ActivityTup}}, + runtimeActivity::Val{RuntimeActivity}, + width::Val{Width}, + ModifiedBetween::Val{MB}, + RT::Val{ReturnType}, + f::F, + df::DF, + allargs..., +)::ReturnType where {ActivityTup,MB,RuntimeActivity,Width,ReturnType,F,DF} + N = div(length(allargs) + 2, Width + 1) - 1 + _, _, primtypes, _, _, wrapped, _, _, active_refs = + setup_macro_wraps(false, N, Width, :allargs) return body_runtime_generic_augfwd(N, Width, wrapped, primtypes, active_refs) end -function nonzero_active_data(x::T) where T<: AbstractFloat +function nonzero_active_data(x::T) where {T<:AbstractFloat} return x != zero(T) end -nonzero_active_data(::T) where T<: Base.RefValue = false -nonzero_active_data(::T) where T<: Array = false -nonzero_active_data(::T) where T<: Ptr = false +nonzero_active_data(::T) where {T<:Base.RefValue} = false +nonzero_active_data(::T) where {T<:Array} = false +nonzero_active_data(::T) where {T<:Ptr} = false -function nonzero_active_data(x::T) where T +function nonzero_active_data(x::T) where {T} if guaranteed_const(T) return false end @@ -427,21 +562,33 @@ end function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, active_refs) outs = [] - for i in 1:N - for w in 1:Width + for i = 1:N + for w = 1:Width expr = if Width == 1 :(tup[$i]) else :(tup[$i][$w]) end shad = shadowargs[i][w] - out = :(if tup[$i] === nothing - elseif $shad isa Base.RefValue - $shad[] = recursive_add($shad[], $expr) + out = quote + if tup[$i] === nothing + elseif $shad isa Base.RefValue + $shad[] = recursive_add($shad[], $expr) else - error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string($shad)*" tup[i]="*string(tup[$i])*" i="*string($i)*" w="*string($w)*" tup="*string(tup)) + 
error( + "Enzyme Mutability Error: Cannot add one in place to immutable value " * + string($shad) * + " tup[i]=" * + string(tup[$i]) * + " i=" * + string($i) * + " w=" * + string($w) * + " tup=" * + string(tup), + ) end - ) + end push!(outs, out) end end @@ -450,7 +597,7 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act shadowret = :(tape.shadow_return[]) else shadowret = [] - for w in 1:Width + for w = 1:Width push!(shadowret, :(tape.shadow_return[$w][])) end shadowret = :(($(shadowret...),)) @@ -459,7 +606,7 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act ElTypes = Vector{Expr}(undef, N) MakeTypes = Vector{Expr}(undef, N) Types = Vector{Symbol}(undef, N) - for i in 1:N + for i = 1:N @inbounds ElTypes[i] = :(eltype($(Symbol("type_$i")))) @inbounds MakeTypes[i] = :($(Symbol("type_$i")) = Core.Typeof(args[$i])) @inbounds Types[i] = Symbol("type_$i") @@ -469,7 +616,7 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act :(Duplicated(f, df)) else fargs = [:df] - for i in 2:Width + for i = 2:Width push!(fargs, Symbol("df_$i")) end :(BatchDuplicated(f, ($(fargs...),))) @@ -477,14 +624,14 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act dupty = if Width == 1 :(Duplicated{FT}) else - :(BatchDuplicated{FT, $Width}) + :(BatchDuplicated{FT,$Width}) end quote $(active_refs...) args = ($(wrapped...),) $(MakeTypes...) 
- + FT = Core.Typeof(f) dupClosure0 = if ActivityTup[1] !guaranteed_const(FT) @@ -497,7 +644,7 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) annotation = if $Width != 1 && annotation0 <: Duplicated - BatchDuplicated{rt, $Width} + BatchDuplicated{rt,$Width} else annotation0 end @@ -505,36 +652,84 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act world = codegen_world_age(FT, tt) opt_mi = Val(world) - _, adjoint = thunk(opt_mi, dupClosure0 ? $dupty : Const{FT}, - annotation, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) - - tup = if annotation0 <: Active || annotation0 <: MixedDuplicated || annotation0 <: BatchMixedDuplicated - adjoint(dupClosure0 ? $dup : Const(f), args..., $shadowret, tape.internal_tape)[1] - else - adjoint(dupClosure0 ? $dup : Const(f), args..., tape.internal_tape)[1] - end + _, adjoint = thunk( + opt_mi, + dupClosure0 ? $dupty : Const{FT}, + annotation, + Tuple{$(Types...)}, + Val(API.DEM_ReverseModePrimal), + width, + ModifiedBetween, + Val(true), + Val(false), + FFIABI, + Val(false), + runtimeActivity, + ) #=erriffuncwritten=# + + tup = + if annotation0 <: Active || + annotation0 <: MixedDuplicated || + annotation0 <: BatchMixedDuplicated + adjoint( + dupClosure0 ? $dup : Const(f), + args..., + $shadowret, + tape.internal_tape, + )[1] + else + adjoint(dupClosure0 ? $dup : Const(f), args..., tape.internal_tape)[1] + end $(outs...) 
+ return nothing end end function func_runtime_generic_rev(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width) - body = body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs, active_refs) + _, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = + setup_macro_wraps(false, N, Width) + body = + body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs, active_refs) quote - function runtime_generic_rev(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, $(allargs...)) where {ActivityTup, RuntimeActivity, MB, TapeType, F, DF, $(typeargs...)} + function runtime_generic_rev( + activity::Type{Val{ActivityTup}}, + runtimeActivity::Val{RuntimeActivity}, + width::Val{$Width}, + ModifiedBetween::Val{MB}, + tape::TapeType, + f::F, + df::DF, + $(allargs...), + ) where {ActivityTup,RuntimeActivity,MB,TapeType,F,DF,$(typeargs...)} $body end end end -@generated function runtime_generic_rev(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) 
where {ActivityTup, MB, RuntimeActivity, Width, TapeType, F, DF} - N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs) - return body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs, active_refs) +@generated function runtime_generic_rev( + activity::Type{Val{ActivityTup}}, + runtimeActivity::Val{RuntimeActivity}, + width::Val{Width}, + ModifiedBetween::Val{MB}, + tape::TapeType, + f::F, + df::DF, + allargs..., +) where {ActivityTup,MB,RuntimeActivity,Width,TapeType,F,DF} + N = div(length(allargs) + 2, Width + 1) - 1 + _, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = + setup_macro_wraps(false, N, Width, :allargs) + return body_runtime_generic_rev( + N, + Width, + wrapped, + primtypes, + batchshadowargs, + active_refs, + ) end @inline concat() = () @@ -545,7 +740,8 @@ end @inline iterate_unwrap_inner_fwd(x::Const) = (map(Const, x.val)...,) @inline iterate_unwrap_inner_fwd(x::Duplicated) = (map(Duplicated, x.val, x.dval)...,) @inline batch_dup_tuple(x, vals...) = BatchDuplicated(x, (vals...,)) -@inline iterate_unwrap_inner_fwd(x::BatchDuplicated) = (map(batch_dup_tuple, x.val, x.dval...)...,) +@inline iterate_unwrap_inner_fwd(x::BatchDuplicated) = + (map(batch_dup_tuple, x.val, x.dval...)...,) @inline function iterate_unwrap_fwd(args...) 
ntuple(Val(length(args))) do i @@ -596,7 +792,7 @@ end end end -function push_if_not_ref(::Val{reverse}, vals, darg, ::Type{T2}) where {reverse, T2} +function push_if_not_ref(::Val{reverse}, vals, darg, ::Type{T2}) where {reverse,T2} if reverse return popfirst!(vals) else @@ -606,11 +802,21 @@ function push_if_not_ref(::Val{reverse}, vals, darg, ::Type{T2}) where {reverse, end end -function push_if_not_ref(::Val{reverse}, vals, darg::Base.RefValue{T2}, ::Type{T2}) where {reverse, T2} +function push_if_not_ref( + ::Val{reverse}, + vals, + darg::Base.RefValue{T2}, + ::Type{T2}, +) where {reverse,T2} return darg end -@inline function iterate_unwrap_augfwd_dup(::Val{reverse}, vals, args, dargs) where reverse +@inline function iterate_unwrap_augfwd_dup( + ::Val{reverse}, + vals, + args, + dargs, +) where {reverse} ntuple(Val(length(args))) do i Base.@_inline_meta arg = args[i] @@ -622,14 +828,23 @@ end Active(arg) elseif actreg == MixedState darg = Base.inferencebarrier(dargs[i]) - MixedDuplicated(arg, push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty}) + MixedDuplicated( + arg, + push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty}, + ) else Duplicated(arg, dargs[i]) end end end -@inline function iterate_unwrap_augfwd_batchdup(::Val{reverse}, vals, ::Val{Width}, args, dargs) where {reverse, Width} +@inline function iterate_unwrap_augfwd_batchdup( + ::Val{reverse}, + vals, + ::Val{Width}, + args, + dargs, +) where {reverse,Width} ntuple(Val(length(args))) do i Base.@_inline_meta arg = args[i] @@ -640,11 +855,14 @@ end elseif actreg == ActiveState Active(arg) elseif actreg == MixedState - BatchMixedDuplicated(arg, ntuple(Val(Width)) do j - Base.@_inline_meta - darg = Base.inferencebarrier(dargs[j][i]) - push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty} - end) + BatchMixedDuplicated( + arg, + ntuple(Val(Width)) do j + Base.@_inline_meta + darg = Base.inferencebarrier(dargs[j][i]) + push_if_not_ref(Val(reverse), vals, darg, 
ty)::Base.RefValue{ty} + end, + ) else BatchDuplicated(arg, ntuple(Val(Width)) do j Base.@_inline_meta @@ -654,7 +872,12 @@ end end end -@inline function iterate_unwrap_augfwd_mix(::Val{reverse}, vals, args, dargs0) where reverse +@inline function iterate_unwrap_augfwd_mix( + ::Val{reverse}, + vals, + args, + dargs0, +) where {reverse} dargs = dargs0[] ntuple(Val(length(args))) do i Base.@_inline_meta @@ -667,14 +890,23 @@ end Active(arg) elseif actreg == MixedState darg = Base.inferencebarrier(dargs[i]) - MixedDuplicated(arg, push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty}) + MixedDuplicated( + arg, + push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty}, + ) else Duplicated(arg, dargs[i]) end end end -@inline function iterate_unwrap_augfwd_batchmix(::Val{reverse}, vals, ::Val{Width}, args, dargs) where {reverse, Width} +@inline function iterate_unwrap_augfwd_batchmix( + ::Val{reverse}, + vals, + ::Val{Width}, + args, + dargs, +) where {reverse,Width} ntuple(Val(length(args))) do i Base.@_inline_meta arg = args[i] @@ -685,11 +917,14 @@ end elseif actreg == ActiveState Active(arg) elseif actreg == MixedState - BatchMixedDuplicated(arg, ntuple(Val(Width)) do j - Base.@_inline_meta - darg = Base.inferencebarrier(dargs[j][][i]) - push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty} - end) + BatchMixedDuplicated( + arg, + ntuple(Val(Width)) do j + Base.@_inline_meta + darg = Base.inferencebarrier(dargs[j][][i]) + push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty} + end, + ) else BatchDuplicated(arg, ntuple(Val(Width)) do j Base.@_inline_meta @@ -699,21 +934,21 @@ end end end -@inline function allFirst(::Val{Width}, res) where Width +@inline function allFirst(::Val{Width}, res) where {Width} ntuple(Val(Width)) do i Base.@_inline_meta res[1] end end -@inline function allSame(::Val{Width}, res) where Width +@inline function allSame(::Val{Width}, res) where {Width} ntuple(Val(Width)) do i Base.@_inline_meta res end end 
-@inline function allZero(::Val{Width}, res) where Width +@inline function allZero(::Val{Width}, res) where {Width} ntuple(Val(Width)) do i Base.@_inline_meta Ref(make_zero(res)) @@ -721,21 +956,31 @@ end end # This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] -function fwddiff_with_return(runtimeActivity::Val{RuntimeActivity}, ::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, args::Vararg{Annotation, Nargs})::ReturnType where {RuntimeActivity, width, dupClosure0, ReturnType, FT, tt′, DF, Nargs} +function fwddiff_with_return( + runtimeActivity::Val{RuntimeActivity}, + ::Val{width}, + ::Val{dupClosure0}, + ::Type{ReturnType}, + ::Type{FT}, + ::Type{tt′}, + f::FT, + df::DF, + args::Vararg{Annotation,Nargs}, +)::ReturnType where {RuntimeActivity,width,dupClosure0,ReturnType,FT,tt′,DF,Nargs} ReturnPrimal = Val(true) - ModifiedBetween = Val(Enzyme.falses_from_args(Nargs+1)) + ModifiedBetween = Val(Enzyme.falses_from_args(Nargs + 1)) dupClosure = dupClosure0 && !guaranteed_const(FT) FA = dupClosure ? Duplicated{FT} : Const{FT} - tt = Enzyme.vaEltypes(tt′) + tt = Enzyme.vaEltypes(tt′) rt = Core.Compiler.return_type(f, tt) annotation0 = guess_activity(rt, API.DEM_ForwardMode) annotation = if width != 1 if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated - BatchDuplicated{rt, width} + BatchDuplicated{rt,width} else Const{rt} end @@ -758,10 +1003,25 @@ function fwddiff_with_return(runtimeActivity::Val{RuntimeActivity}, ::Val{width} Const(f) end opt_mi = Val(world) - res = thunk(opt_mi, FA, annotation, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), - ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity)(fa, args...) 
+ res = thunk( + opt_mi, + FA, + annotation, + tt′, + Val(API.DEM_ForwardMode), + Val(width), #=Mode=# + ModifiedBetween, + ReturnPrimal, + Val(false), + FFIABI, + Val(false), + runtimeActivity, + )( + fa, + args..., + ) #=erriffuncwritten=# return if annotation <: Const - ReturnType(allFirst(Val(width+1), res)) + ReturnType(allFirst(Val(width + 1), res)) else if width == 1 ReturnType((res[2], res[1])) @@ -773,38 +1033,66 @@ end function body_runtime_iterate_fwd(N, Width, wrapped, primtypes, active_refs) wrappedexexpand = Vector{Expr}(undef, N) - for i in 1:N + for i = 1:N @inbounds wrappedexexpand[i] = :($(wrapped[i])...) end return quote $(active_refs...) args = ($(wrappedexexpand...),) - tt′ = Enzyme.vaTypeof(args...) + tt′ = Enzyme.vaTypeof(args...) FT = Core.Typeof(f) - fwddiff_with_return(runtimeActivity, Val($Width), Val(ActivityTup[1]), ReturnType, FT, tt′, f, df, args...)::ReturnType + fwddiff_with_return( + runtimeActivity, + Val($Width), + Val(ActivityTup[1]), + ReturnType, + FT, + tt′, + f, + df, + args..., + )::ReturnType end end function func_runtime_iterate_fwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _, _, active_refs = setup_macro_wraps(true, N, Width, #=base=#nothing, #=iterate=#true) + _, _, primtypes, allargs, typeargs, wrapped, _, _, active_refs = + setup_macro_wraps(true, N, Width, nothing, true) #=iterate=# body = body_runtime_iterate_fwd(N, Width, wrapped, primtypes, active_refs) quote - function runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{$Width}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, RuntimeActivity, ReturnType, F, DF, $(typeargs...)} + function runtime_iterate_fwd( + activity::Type{Val{ActivityTup}}, + runtimeActivity::Val{RuntimeActivity}, + width::Val{$Width}, + RT::Val{ReturnType}, + f::F, + df::DF, + $(allargs...), + ) where {ActivityTup,RuntimeActivity,ReturnType,F,DF,$(typeargs...)} $body end end end -@generated function 
runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, RuntimeActivity, Width, ReturnType, F, DF} - N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _, _, active_refs = setup_macro_wraps(true, N, Width, :allargs, #=iterate=#true) +@generated function runtime_iterate_fwd( + activity::Type{Val{ActivityTup}}, + runtimeActivity::Val{RuntimeActivity}, + width::Val{Width}, + RT::Val{ReturnType}, + f::F, + df::DF, + allargs..., +) where {ActivityTup,RuntimeActivity,Width,ReturnType,F,DF} + N = div(length(allargs) + 2, Width + 1) - 1 + _, _, primtypes, _, _, wrapped, _, _, active_refs = + setup_macro_wraps(true, N, Width, :allargs, true) #=iterate=# return body_runtime_iterate_fwd(N, Width, wrapped, primtypes, active_refs) end -@generated function primal_tuple(args::Vararg{Annotation, Nargs}) where Nargs +@generated function primal_tuple(args::Vararg{Annotation,Nargs}) where {Nargs} expr = Vector{Expr}(undef, Nargs) - for i in 1:Nargs + for i = 1:Nargs @inbounds expr[i] = :(args[$i].val) end return quote @@ -813,16 +1101,20 @@ end end end -@generated function shadow_tuple(::Type{Ann}, ::Val{1}, args::Vararg{Annotation, Nargs}) where {Ann, Nargs} +@generated function shadow_tuple( + ::Type{Ann}, + ::Val{1}, + args::Vararg{Annotation,Nargs}, +) where {Ann,Nargs} expr = Vector{Expr}(undef, Nargs) - for i in 1:Nargs + for i = 1:Nargs @inbounds expr[i] = quote @assert !(args[$i] isa Active) if args[$i] isa Const args[$i].val elseif args[$i] isa MixedDuplicated args[$i].dval[] - else + else args[$i].dval end end @@ -837,18 +1129,22 @@ end end end -@generated function shadow_tuple(::Type{Ann}, ::Val{width}, args::Vararg{Annotation, Nargs}) where {Ann, width, Nargs} +@generated function shadow_tuple( + ::Type{Ann}, + ::Val{width}, + args::Vararg{Annotation,Nargs}, +) where {Ann,width,Nargs} wexpr = Vector{Expr}(undef, width) - for w in 
1:width + for w = 1:width expr = Vector{Expr}(undef, Nargs) - for i in 1:Nargs + for i = 1:Nargs @inbounds expr[i] = quote @assert !(args[$i] isa Active) if args[$i] isa Const args[$i].val elseif args[$i] isa BatchMixedDuplicated args[$i].dval[$w][] - else + else args[$i].dval[$w] end end @@ -867,19 +1163,40 @@ end end # This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] -function augfwd_with_return(runtimeActivity::Val{RuntimeActivity}, ::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType}, ::Val{ModifiedBetween0}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, args::Vararg{Annotation, Nargs})::ReturnType where {RuntimeActivity, width, dupClosure0, ReturnType, ModifiedBetween0, FT, tt′, DF, Nargs} +function augfwd_with_return( + runtimeActivity::Val{RuntimeActivity}, + ::Val{width}, + ::Val{dupClosure0}, + ::Type{ReturnType}, + ::Val{ModifiedBetween0}, + ::Type{FT}, + ::Type{tt′}, + f::FT, + df::DF, + args::Vararg{Annotation,Nargs}, +)::ReturnType where { + RuntimeActivity, + width, + dupClosure0, + ReturnType, + ModifiedBetween0, + FT, + tt′, + DF, + Nargs, +} ReturnPrimal = Val(true) ModifiedBetween = Val(ModifiedBetween0) - tt = Enzyme.vaEltypes(tt′) + tt = Enzyme.vaEltypes(tt′) rt = Core.Compiler.return_type(f, tt) annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) annotation = if width != 1 if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated - BatchDuplicated{rt, width} + BatchDuplicated{rt,width} elseif annotation0 <: MixedDuplicated - BatchMixedDuplicated{rt, width} + BatchMixedDuplicated{rt,width} elseif annotation0 <: Active Active{rt} else @@ -912,27 +1229,46 @@ function augfwd_with_return(runtimeActivity::Val{RuntimeActivity}, ::Val{width}, end world = codegen_world_age(FT, tt) opt_mi = Val(world) - forward, adjoint = thunk(opt_mi, FA, - annotation, tt′, Val(API.DEM_ReverseModePrimal), Val(width), - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, 
#=erriffuncwritten=#Val(false), runtimeActivity) + forward, adjoint = thunk( + opt_mi, + FA, + annotation, + tt′, + Val(API.DEM_ReverseModePrimal), + Val(width), + ModifiedBetween, + Val(true), + Val(false), + FFIABI, + Val(false), + runtimeActivity, + ) #=erriffuncwritten=# forward(fa, args...) else - nothing, primal_tuple(args...), annotation <: Active ? nothing : shadow_tuple(annotation, Val(width), args...) + nothing, + primal_tuple(args...), + annotation <: Active ? nothing : shadow_tuple(annotation, Val(width), args...) end resT = typeof(origRet) if annotation <: Const shadow_return = nothing - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) - return ReturnType((allSame(Val(width+1), origRet)..., tape)) + tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( + internal_tape, + shadow_return, + ) + return ReturnType((allSame(Val(width + 1), origRet)..., tape)) elseif annotation <: Active shadow_return = if width == 1 Ref(make_zero(origRet)) else allZero(Val(width), origRet) end - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( + internal_tape, + shadow_return, + ) if width == 1 return ReturnType((origRet, shadow_return, tape)) else @@ -943,21 +1279,33 @@ function augfwd_with_return(runtimeActivity::Val{RuntimeActivity}, ::Val{width}, if width == 1 if annotation <: MixedDuplicated shadow_return = initShadow - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( + internal_tape, + shadow_return, + ) return ReturnType((origRet, initShadow, tape)) else shadow_return = nothing - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( + internal_tape, + shadow_return, + ) return 
ReturnType((origRet, initShadow, tape)) end else if annotation <: BatchMixedDuplicated shadow_return = initShadow - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( + internal_tape, + shadow_return, + ) return ReturnType((origRet, initShadow..., tape)) else shadow_return = nothing - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}( + internal_tape, + shadow_return, + ) return ReturnType((origRet, initShadow..., tape)) end end @@ -965,65 +1313,129 @@ end function body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes, active_refs) wrappedexexpand = Vector{Expr}(undef, N) - for i in 1:N + for i = 1:N @inbounds wrappedexexpand[i] = :($(wrapped[i])...) end - results = Vector{Expr}(undef, Width+1) - for i in 1:(Width+1) + results = Vector{Expr}(undef, Width + 1) + for i = 1:(Width+1) results[i] = :(tmpvals[$i]) end return quote refs = Base.RefValue[] $(active_refs...) args = ($(wrappedexexpand...),) - tt′ = Enzyme.vaTypeof(args...) + tt′ = Enzyme.vaTypeof(args...) 
FT = Core.Typeof(f) - tmpvals = augfwd_with_return(runtimeActivity, Val($Width), Val(ActivityTup[1]), ReturnType, Val(concat($(modbetween...))), FT, tt′, f, df, args...)::ReturnType - ReturnType(($(results...), (tmpvals[$(Width+2)], refs))) + tmpvals = augfwd_with_return( + runtimeActivity, + Val($Width), + Val(ActivityTup[1]), + ReturnType, + Val(concat($(modbetween...))), + FT, + tt′, + f, + df, + args..., + )::ReturnType + ReturnType(($(results...), (tmpvals[$(Width + 2)], refs))) end end function func_runtime_iterate_augfwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _, modbetween, active_refs = setup_macro_wraps(false, N, Width, #=base=#nothing, #=iterate=#true) - body = body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes, active_refs) + _, _, primtypes, allargs, typeargs, wrapped, _, modbetween, active_refs = + setup_macro_wraps(false, N, Width, nothing, true) #=iterate=# + body = + body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes, active_refs) quote - function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, RuntimeActivity, MB, ReturnType, F, DF, $(typeargs...)} + function runtime_iterate_augfwd( + activity::Type{Val{ActivityTup}}, + runtimeActivity::Val{RuntimeActivity}, + width::Val{$Width}, + ModifiedBetween::Val{MB}, + RT::Val{ReturnType}, + f::F, + df::DF, + $(allargs...), + ) where {ActivityTup,RuntimeActivity,MB,ReturnType,F,DF,$(typeargs...)} $body end end end -@generated function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...) 
where {ActivityTup, RuntimeActivity, MB, Width, ReturnType, F, DF} - N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _ , modbetween, active_refs = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true) - return body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes, active_refs) +@generated function runtime_iterate_augfwd( + activity::Type{Val{ActivityTup}}, + runtimeActivity::Val{RuntimeActivity}, + width::Val{Width}, + ModifiedBetween::Val{MB}, + RT::Val{ReturnType}, + f::F, + df::DF, + allargs..., +) where {ActivityTup,RuntimeActivity,MB,Width,ReturnType,F,DF} + N = div(length(allargs) + 2, Width + 1) - 1 + _, _, primtypes, _, _, wrapped, _, modbetween, active_refs = + setup_macro_wraps(false, N, Width, :allargs, true) #=iterate=# + return body_runtime_iterate_augfwd( + N, + Width, + modbetween, + wrapped, + primtypes, + active_refs, + ) end function add_into_vec!(val::Base.RefValue, expr, vec, idx_in_vec) - val[] = recursive_add(val[], expr, identity, guaranteed_nonactive) - nothing + val[] = recursive_add(val[], expr, identity, guaranteed_nonactive) + nothing end -function add_into_vec!(val::T, expr, vec, idx_in_vec) where T +function add_into_vec!(val::T, expr, vec, idx_in_vec) where {T} if ismutable(vec) @inbounds vec[idx_in_vec] = recursive_add(val, expr, identity, guaranteed_nonactive) else - error("Enzyme Mutability Error: Cannot in place to immutable value vec[$idx_in_vec] = $val, vec=$vec") + error( + "Enzyme Mutability Error: Cannot in place to immutable value vec[$idx_in_vec] = $val, vec=$vec", + ) end nothing end # This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] -@generated function rev_with_return(runtimeActivity::Val{RuntimeActivity}, ::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween0}, ::Val{lengths}, ::Type{FT}, ::Type{ttp}, f::FT, df::DF, tape, shadowargs, args::Vararg{Annotation, Nargs})::Nothing where {RuntimeActivity, width, 
dupClosure0, ModifiedBetween0, lengths, FT, ttp, DF, Nargs} +@generated function rev_with_return( + runtimeActivity::Val{RuntimeActivity}, + ::Val{width}, + ::Val{dupClosure0}, + ::Val{ModifiedBetween0}, + ::Val{lengths}, + ::Type{FT}, + ::Type{ttp}, + f::FT, + df::DF, + tape, + shadowargs, + args::Vararg{Annotation,Nargs}, +)::Nothing where { + RuntimeActivity, + width, + dupClosure0, + ModifiedBetween0, + lengths, + FT, + ttp, + DF, + Nargs, +} nontupexprs = Vector{Expr}(undef, Nargs) - for i in 1:Nargs + for i = 1:Nargs mid = if width == 1 :(tape.shadow_return[][$i]) else mexprs = Vector{Expr}(undef, width) - for w in 1:width + for w = 1:width @inbounds mexprs[w] = :(tape.shadow_return[$w][][$i]) end quote @@ -1032,7 +1444,9 @@ end end @inbounds nontupexprs[i] = quote - if args[$i] isa Active || args[$i] isa MixedDuplicated || args[$i] isa BatchMixedDuplicated + if args[$i] isa Active || + args[$i] isa MixedDuplicated || + args[$i] isa BatchMixedDuplicated $mid else nothing @@ -1041,10 +1455,12 @@ end end endexprs = Matrix{Expr}(undef, Nargs, width) - for i in 1:Nargs - for w in 1:width + for i = 1:Nargs + for w = 1:width @inbounds endexprs[i, w] = quote - if args[$i] isa Active || args[$i] isa MixedDuplicated || args[$i] isa BatchMixedDuplicated + if args[$i] isa Active || + args[$i] isa MixedDuplicated || + args[$i] isa BatchMixedDuplicated expr = if args[$i] isa Active || f == Base.tuple if $width == 1 tup[$i] @@ -1061,7 +1477,7 @@ end idx_of_vec, idx_in_vec = $(lengths[i]) vec = @inbounds shadowargs[idx_of_vec][$w] if vec isa Base.RefValue - vecld = vec[] + vecld = vec[] T = Core.Typeof(vecld) vec[] = recursive_index_add(T, vecld, Val(idx_in_vec), expr) else @@ -1079,9 +1495,9 @@ end annotation = if width != 1 quote if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated - BatchDuplicated{rt, $width} + BatchDuplicated{rt,$width} elseif annotation0 <: MixedDuplicated - BatchMixedDuplicated{rt, $width} + BatchMixedDuplicated{rt,$width} elseif 
annotation0 <: Active Active{rt} else @@ -1106,7 +1522,7 @@ end :(adjoint(fa, args..., tape.shadow_return[], tape.internal_tape)[1]) else margs = Vector{Expr}(undef, width) - for w in 1:width + for w = 1:width @inbounds margs[w] = :(tape.shadow_return[$w][]) end :(adjoint(fa, args..., ($(margs...),), tape.internal_tape)[1]) @@ -1121,7 +1537,7 @@ end dupClosure = $dupClosure0 && !guaranteed_const($FT) FA = dupClosure ? Duplicated{$FT} : Const{$FT} - tt = $tt + tt = $tt rt = Core.Compiler.return_type(f, tt) annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) @@ -1135,10 +1551,21 @@ end Const(f) end opt_mi = Val(world) - forward, adjoint = thunk(opt_mi, FA, - annotation, $ttp, Val(API.DEM_ReverseModePrimal), Val($width), - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) - + forward, adjoint = thunk( + opt_mi, + FA, + annotation, + $ttp, + Val(API.DEM_ReverseModePrimal), + Val($width), + ModifiedBetween, + Val(true), + Val(false), + FFIABI, + Val(false), + runtimeActivity, + ) #=erriffuncwritten=# + tup = if tape.shadow_return !== nothing $shadadj else @@ -1154,9 +1581,9 @@ end end end -@generated function ntuple_pair(::Val{Len}, ::Val{i}) where {Len, i} +@generated function ntuple_pair(::Val{Len}, ::Val{i}) where {Len,i} mexprs = Vector{Expr}(undef, Len) - for j in 1:Len + for j = 1:Len @inbounds mexprs[j] = quote ($i, $j) end @@ -1167,24 +1594,32 @@ end end end -function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shadowargs, active_refs) +function body_runtime_iterate_rev( + N, + Width, + modbetween, + wrapped, + primargs, + shadowargs, + active_refs, +) shadow_ret = nothing if Width == 1 shadowret = :(tape.shadow_return[]) else shadowret = Expr[] - for w in 1:Width + for w = 1:Width push!(shadowret, :(tape.shadow_return[$w][])) end shadowret = :(($(shadowret...),)) end wrappedexexpand = Vector{Expr}(undef, N) - for i in 1:N + for i = 1:N wrappedexexpand[i] = 
:($(wrapped[i])...) end lengths = Vector{Expr}(undef, N) - for i in 1:N + for i = 1:N lengths[i] = quote ntuple_pair(Val(length($(primargs[i]))), Val($i)) end @@ -1198,28 +1633,84 @@ function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shado (tape0, refs) = tape $(active_refs...) args = ($(wrappedexexpand...),) - tt′ = Enzyme.vaTypeof(args...) + tt′ = Enzyme.vaTypeof(args...) FT = Core.Typeof(f) - rev_with_return(runtimeActivity, Val($Width), Val(ActivityTup[1]), Val(concat($(modbetween...))), Val(concat($(lengths...))), FT, tt′, f, df, tape0, ($(shadowsplat...),), args...) + rev_with_return( + runtimeActivity, + Val($Width), + Val(ActivityTup[1]), + Val(concat($(modbetween...))), + Val(concat($(lengths...))), + FT, + tt′, + f, + df, + tape0, + ($(shadowsplat...),), + args..., + ) return nothing end end function func_runtime_iterate_rev(N, Width) - primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, modbetween, active_refs = setup_macro_wraps(false, N, Width, #=body=#nothing, #=iterate=#true; reverse=true) - body = body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs, active_refs) + primargs, + _, + primtypes, + allargs, + typeargs, + wrapped, + batchshadowargs, + modbetween, + active_refs = setup_macro_wraps(false, N, Width, nothing, true; reverse = true) #=iterate=# + body = body_runtime_iterate_rev( + N, + Width, + modbetween, + wrapped, + primargs, + batchshadowargs, + active_refs, + ) quote - function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, $(allargs...)) where {ActivityTup, RuntimeActivity, MB, TapeType, F, DF, $(typeargs...)} + function runtime_iterate_rev( + activity::Type{Val{ActivityTup}}, + runtimeActivity::Val{RuntimeActivity}, + width::Val{$Width}, + ModifiedBetween::Val{MB}, + tape::TapeType, + f::F, + df::DF, + $(allargs...), + ) where 
{ActivityTup,RuntimeActivity,MB,TapeType,F,DF,$(typeargs...)} $body end end end -@generated function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, runtimeActivity::Val{RuntimeActivity}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) where {ActivityTup, RuntimeActivity, MB, Width, TapeType, F, DF} - N = div(length(allargs)+2, Width+1)-1 - primargs, _, primtypes, _, _, wrapped, batchshadowargs, modbetween, active_refs = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true; reverse=true) - return body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs, active_refs) +@generated function runtime_iterate_rev( + activity::Type{Val{ActivityTup}}, + runtimeActivity::Val{RuntimeActivity}, + width::Val{Width}, + ModifiedBetween::Val{MB}, + tape::TapeType, + f::F, + df::DF, + allargs..., +) where {ActivityTup,RuntimeActivity,MB,Width,TapeType,F,DF} + N = div(length(allargs) + 2, Width + 1) - 1 + primargs, _, primtypes, _, _, wrapped, batchshadowargs, modbetween, active_refs = + setup_macro_wraps(false, N, Width, :allargs, true; reverse = true) #=iterate=# + return body_runtime_iterate_rev( + N, + Width, + modbetween, + wrapped, + primargs, + batchshadowargs, + active_refs, + ) end # Create specializations @@ -1232,7 +1723,21 @@ for (N, Width) in Iterators.product(0:30, 1:10) eval(func_runtime_iterate_rev(N, Width)) end -function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, lookup; sret=nothing, tape=nothing, firstconst=false, endcast=true, firstconst_after_tape=true, runtime_activity=true) +function generic_setup( + orig, + func, + ReturnType, + gutils, + start, + B::LLVM.IRBuilder, + lookup; + sret = nothing, + tape = nothing, + firstconst = false, + endcast = true, + firstconst_after_tape = true, + runtime_activity = true, +) width = get_width(gutils) mode = get_mode(gutils) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) @@ -1271,7 +1776,7 @@ function 
generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, inverted = nothing active = !is_constant_value(gutils, op) - + if !active push!(ActivityList, unsafe_to_llvm(B, false)) else @@ -1285,19 +1790,27 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, else extract_value!(B, inverted, 0) end - push!(ActivityList, select!(B, icmp!(B, LLVM.API.LLVMIntNE, val, inv_0), unsafe_to_llvm(B, true), unsafe_to_llvm(B, false))) + push!( + ActivityList, + select!( + B, + icmp!(B, LLVM.API.LLVMIntNE, val, inv_0), + unsafe_to_llvm(B, true), + unsafe_to_llvm(B, false), + ), + ) else push!(ActivityList, unsafe_to_llvm(B, true)) end end - for w in 1:width + for w = 1:width ev = fill_val if inverted !== nothing if width == 1 ev = inverted else - ev = extract_value!(B, inverted, w-1) + ev = extract_value!(B, inverted, w - 1) end end @@ -1317,7 +1830,7 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, else pushfirst!(vals, unsafe_to_llvm(B, Val(ReturnType))) end - + if firstconst && firstconst_after_tape val = new_from_original(gutils, operands(orig)[start]) if lookup @@ -1333,7 +1846,7 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, ModifiedBetween = Bool[] - for idx in 1:(length(ops)+firstconst) + for idx = 1:(length(ops)+firstconst) push!(ModifiedBetween, uncacheable[(start-1)+idx] != 0) end pushfirst!(vals, unsafe_to_llvm(B, Val((ModifiedBetween...,)))) @@ -1344,7 +1857,7 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, pushfirst!(vals, unsafe_to_llvm(B, Val(get_runtime_activity(gutils)))) end etup0 = emit_tuple!(B, ActivityList) - etup = emit_apply_type!(B, Base.Val, [etup0]) + etup = emit_apply_type!(B, Base.Val, [etup0]) if isa(etup, LLVM.Instruction) @assert length(collect(LLVM.uses(etup0))) == 1 end @@ -1355,7 +1868,7 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, cal = 
emit_apply_generic!(B, vals) debug_from_orig!(gutils, cal, orig) - + if tape === nothing && endcast llty = convert(LLVMType, ReturnType) cal = LLVM.addrspacecast!(B, cal, LLVM.PointerType(T_jlvalue, Derived)) @@ -1366,42 +1879,69 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, end function common_generic_fwd(offset, B, orig, gutils, normalR, shadowR) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing - shadow = (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + shadow = + (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return true end width = get_width(gutils) - sret = generic_setup(orig, runtime_generic_fwd, AnyArray(1+Int(width)), gutils, #=start=#offset, B, false) - AT = LLVM.ArrayType(T_prjlvalue, 1+Int(width)) + sret = generic_setup( + orig, + runtime_generic_fwd, + AnyArray(1 + Int(width)), + gutils, + offset, + B, + false, + ) #=start=# + AT = LLVM.ArrayType(T_prjlvalue, 1 + Int(width)) if unsafe_load(shadowR) != C_NULL if width == 1 - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) + gep = + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) shadow = LLVM.load!(B, T_prjlvalue, 
gep) else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for i in 1:width - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(i)]) + for i = 1:width + gep = LLVM.inbounds_gep!( + B, + AT, + sret, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) - shadow = insert_value!(B, shadow, ld, i-1) + shadow = insert_value!(B, shadow, ld, i - 1) end end unsafe_store!(shadowR, shadow.ref) end if unsafe_load(normalR) != C_NULL - normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) + normal = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)]), + ) unsafe_store!(normalR, normal.ref) else # Delete the primal code @@ -1419,45 +1959,76 @@ end end function common_generic_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing - shadow = (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + shadow = + (unsafe_load(shadowR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(shadowR)) : nothing T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return true end width = get_width(gutils) - sret = generic_setup(orig, runtime_generic_augfwd, AnyArray(2+Int(width)), gutils, #=start=#offset, B, false) - AT = LLVM.ArrayType(T_prjlvalue, 2+Int(width)) - - if unsafe_load(shadowR) != C_NULL + sret = generic_setup( + orig, + runtime_generic_augfwd, + AnyArray(2 + Int(width)), + gutils, + offset, + B, + false, + ) #=start=# + AT = LLVM.ArrayType(T_prjlvalue, 2 + Int(width)) + + if unsafe_load(shadowR) != C_NULL if width == 1 - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) + gep = + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) shadow = LLVM.load!(B, T_prjlvalue, gep) else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for i in 1:width - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(i)]) + for i = 1:width + gep = LLVM.inbounds_gep!( + B, + AT, + sret, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) - shadow = insert_value!(B, shadow, ld, i-1) + shadow = insert_value!(B, shadow, ld, i - 1) end end unsafe_store!(shadowR, shadow.ref) end - tape = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1+width)])) + tape = LLVM.load!( + B, + T_prjlvalue, + 
LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1 + width)]), + ) unsafe_store!(tapeR, tape.ref) if normalR != C_NULL - normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) + normal = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)]), + ) unsafe_store!(normalR, normal.ref) else # Delete the primal code @@ -1479,15 +2050,22 @@ end function common_generic_rev(offset, B, orig, gutils, tape)::Cvoid needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + API.DEM_ReverseModePrimal, + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return nothing end @assert tape !== C_NULL width = get_width(gutils) - generic_setup(orig, runtime_generic_rev, Nothing, gutils, #=start=#offset, B, true; tape) + generic_setup(orig, runtime_generic_rev, Nothing, gutils, offset, B, true; tape) #=start=# return nothing end @@ -1504,9 +2082,16 @@ end function common_apply_latest_fwd(offset, B, orig, gutils, normalR, shadowR) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return true end mod = 
LLVM.parent(LLVM.parent(LLVM.parent(orig))) @@ -1515,27 +2100,45 @@ function common_apply_latest_fwd(offset, B, orig, gutils, normalR, shadowR) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) width = get_width(gutils) - AT = LLVM.ArrayType(T_prjlvalue, 1+Int(width)) - sret = generic_setup(orig, runtime_generic_fwd, AnyArray(1+Int(width)), gutils, #=start=#offset+1, B, false) + AT = LLVM.ArrayType(T_prjlvalue, 1 + Int(width)) + sret = generic_setup( + orig, + runtime_generic_fwd, + AnyArray(1 + Int(width)), + gutils, + offset + 1, + B, + false, + ) #=start=# if unsafe_load(shadowR) != C_NULL if width == 1 - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) + gep = + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) shadow = LLVM.load!(B, T_prjlvalue, gep) else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for i in 1:width - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(i)]) + for i = 1:width + gep = LLVM.inbounds_gep!( + B, + AT, + sret, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) - shadow = insert_value!(B, shadow, ld, i-1) + shadow = insert_value!(B, shadow, ld, i - 1) end end unsafe_store!(shadowR, shadow.ref) end if unsafe_load(normalR) != C_NULL - normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) + normal = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)]), + ) unsafe_store!(normalR, normal.ref) else # Delete the primal code @@ -1549,9 +2152,16 @@ end function common_apply_latest_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) - - if (is_constant_value(gutils, orig) || 
needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return true end @@ -1559,31 +2169,53 @@ function common_apply_latest_augfwd(offset, B, orig, gutils, normalR, shadowR, t T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) width = get_width(gutils) - AT = LLVM.ArrayType(T_prjlvalue, 2+Int(width)) + AT = LLVM.ArrayType(T_prjlvalue, 2 + Int(width)) # sret = generic_setup(orig, runtime_apply_latest_augfwd, AnyArray(2+Int(width)), gutils, #=start=#offset+1, ctx, B, false) - sret = generic_setup(orig, runtime_generic_augfwd, AnyArray(2+Int(width)), gutils, #=start=#offset+1, B, false) + sret = generic_setup( + orig, + runtime_generic_augfwd, + AnyArray(2 + Int(width)), + gutils, + offset + 1, + B, + false, + ) #=start=# if unsafe_load(shadowR) != C_NULL if width == 1 - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) + gep = + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) shadow = LLVM.load!(B, T_prjlvalue, gep) else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for i in 1:width - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(i)]) + for i = 1:width + gep = LLVM.inbounds_gep!( + B, + AT, + sret, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) - shadow = insert_value!(B, shadow, ld, i-1) + shadow = insert_value!(B, shadow, ld, i - 1) end end unsafe_store!(shadowR, shadow.ref) end - tape = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1+width)])) + tape = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1 + width)]), + ) unsafe_store!(tapeR, tape.ref) if 
unsafe_load(normalR) != C_NULL - normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) + normal = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)]), + ) unsafe_store!(normalR, normal.ref) else # Delete the primal code @@ -1596,14 +2228,21 @@ end function common_apply_latest_rev(offset, B, orig, gutils, tape)::Cvoid needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + API.DEM_ReverseModePrimal, + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return nothing end if !is_constant_value(gutils, orig) || !is_constant_inst(gutils, orig) width = get_width(gutils) - generic_setup(orig, runtime_generic_rev, Nothing, gutils, #=start=#offset+1, B, true; tape) + generic_setup(orig, runtime_generic_rev, Nothing, gutils, offset + 1, B, true; tape) #=start=# end return nothing @@ -1637,9 +2276,16 @@ end function common_apply_iterate_fwd(offset, B, orig, gutils, normalR, shadowR) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return true end @@ -1648,9 +2294,14 @@ function common_apply_iterate_fwd(offset, B, orig, gutils, normalR, shadowR) 
width = get_width(gutils) - if v && v2 && isiter == Base.iterate && istup == Base.tuple && length(operands(orig)) >= offset+4 + if v && + v2 && + isiter == Base.iterate && + istup == Base.tuple && + length(operands(orig)) >= offset + 4 origops = collect(operands(orig)[1:end-1]) - shadowins = [ invert_pointer(gutils, origops[i], B) for i in (offset+3):length(origops) ] + shadowins = + [invert_pointer(gutils, origops[i], B) for i = (offset+3):length(origops)] shadowres = if width == 1 newops = LLVM.Value[] newvals = API.CValueType[] @@ -1664,18 +2315,25 @@ function common_apply_iterate_fwd(offset, B, orig, gutils, normalR, shadowR) push!(newvals, API.VT_Primal) end end - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + cal = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + newops, + newvals, + false, + ) #=lookup=# callconv!(cal, callconv(orig)) cal else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for j in 1:width + for j = 1:width newops = LLVM.Value[] newvals = API.CValueType[] for (i, v) in enumerate(origops) if i >= offset + 3 - shadowin2 = extract_value!(B, shadowins[i-offset-3+1], j-1) + shadowin2 = extract_value!(B, shadowins[i-offset-3+1], j - 1) push!(newops, shadowin2) push!(newvals, API.VT_Shadow) else @@ -1683,9 +2341,16 @@ function common_apply_iterate_fwd(offset, B, orig, gutils, normalR, shadowR) push!(newvals, API.VT_Primal) end end - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + cal = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + newops, + newvals, + false, + ) #=lookup=# callconv!(cal, callconv(orig)) - shadow = insert_value!(B, shadow, cal, j-1) + shadow = insert_value!(B, shadow, cal, j - 1) end shadow end @@ -1698,26 +2363,48 @@ function common_apply_iterate_fwd(offset, B, orig, gutils, normalR, shadowR) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = 
LLVM.PointerType(T_jlvalue, Tracked) - sret = generic_setup(orig, runtime_iterate_fwd, AnyArray(1+Int(width)), gutils, #=start=#offset+2, B, false) - AT = LLVM.ArrayType(T_prjlvalue, 1+Int(width)) + sret = generic_setup( + orig, + runtime_iterate_fwd, + AnyArray(1 + Int(width)), + gutils, + offset + 2, + B, + false, + ) #=start=# + AT = LLVM.ArrayType(T_prjlvalue, 1 + Int(width)) if unsafe_load(shadowR) != C_NULL if width == 1 - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) + gep = LLVM.inbounds_gep!( + B, + AT, + sret, + [LLVM.ConstantInt(0), LLVM.ConstantInt(1)], + ) shadow = LLVM.load!(B, T_prjlvalue, gep) else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for i in 1:width - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(i)]) + for i = 1:width + gep = LLVM.inbounds_gep!( + B, + AT, + sret, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) - shadow = insert_value!(B, shadow, ld, i-1) + shadow = insert_value!(B, shadow, ld, i - 1) end end unsafe_store!(shadowR, shadow.ref) end if unsafe_load(normalR) != C_NULL - normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) + normal = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)]), + ) unsafe_store!(normalR, normal.ref) else # Delete the primal code @@ -1727,7 +2414,12 @@ function common_apply_iterate_fwd(offset, B, orig, gutils, normalR, shadowR) return false end - emit_error(B, orig, "Enzyme: Not yet implemented augmented forward for jl_f__apply_iterate "*string((v, v2, isiter, istup, length(operands(orig)), offset+4))) + emit_error( + B, + orig, + "Enzyme: Not yet implemented augmented forward for jl_f__apply_iterate " * + string((v, v2, isiter, istup, length(operands(orig)), offset + 4)), + ) return false end @@ -1735,9 +2427,16 @@ end function 
common_apply_iterate_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return true end @@ -1750,30 +2449,61 @@ function common_apply_iterate_augfwd(offset, B, orig, gutils, normalR, shadowR, T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - sret = generic_setup(orig, runtime_iterate_augfwd, AnyArray(2+Int(width)), gutils, #=start=#offset+2, B, false) - AT = LLVM.ArrayType(T_prjlvalue, 2+Int(width)) + sret = generic_setup( + orig, + runtime_iterate_augfwd, + AnyArray(2 + Int(width)), + gutils, + offset + 2, + B, + false, + ) #=start=# + AT = LLVM.ArrayType(T_prjlvalue, 2 + Int(width)) - if unsafe_load(shadowR) != C_NULL + if unsafe_load(shadowR) != C_NULL if width == 1 - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) + gep = LLVM.inbounds_gep!( + B, + AT, + sret, + [LLVM.ConstantInt(0), LLVM.ConstantInt(1)], + ) shadow = LLVM.load!(B, T_prjlvalue, gep) else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for i in 1:width - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(i)]) + for i = 1:width + gep = LLVM.inbounds_gep!( + B, + AT, + sret, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) - shadow = insert_value!(B, shadow, ld, i-1) + shadow = insert_value!(B, shadow, ld, i - 1) end end unsafe_store!(shadowR, shadow.ref) end - tape = LLVM.load!(B, T_prjlvalue, 
LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1+width)])) + tape = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!( + B, + AT, + sret, + [LLVM.ConstantInt(0), LLVM.ConstantInt(1 + width)], + ), + ) unsafe_store!(tapeR, tape.ref) if normalR != C_NULL - normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) + normal = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)]), + ) unsafe_store!(normalR, normal.ref) else # Delete the primal code @@ -1784,24 +2514,39 @@ function common_apply_iterate_augfwd(offset, B, orig, gutils, normalR, shadowR, return false end - emit_error(B, orig, "Enzyme: Not yet implemented augmented forward for jl_f__apply_iterate "*string((v, v2, isiter, istup, length(operands(orig)), offset+4))) + emit_error( + B, + orig, + "Enzyme: Not yet implemented augmented forward for jl_f__apply_iterate " * + string((v, v2, isiter, istup, length(operands(orig)), offset + 4)), + ) - unsafe_store!(shadowR,UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))).ref) + unsafe_store!( + shadowR, + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))).ref, + ) return false end function common_apply_iterate_rev(offset, B, orig, gutils, tape) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + API.DEM_ReverseModePrimal, + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return nothing end @assert tape !== C_NULL width = get_width(gutils) - generic_setup(orig, runtime_iterate_rev, Nothing, gutils, 
#=start=#offset+2, B, true; tape) + generic_setup(orig, runtime_iterate_rev, Nothing, gutils, offset + 2, B, true; tape) #=start=# return nothing end @@ -1821,37 +2566,62 @@ end function common_invoke_fwd(offset, B, orig, gutils, normalR, shadowR) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return true end - + T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) width = get_width(gutils) - sret = generic_setup(orig, runtime_generic_fwd, AnyArray(1+Int(width)), gutils, #=start=#offset+1, B, false) - AT = LLVM.ArrayType(T_prjlvalue, 1+Int(width)) + sret = generic_setup( + orig, + runtime_generic_fwd, + AnyArray(1 + Int(width)), + gutils, + offset + 1, + B, + false, + ) #=start=# + AT = LLVM.ArrayType(T_prjlvalue, 1 + Int(width)) if unsafe_load(shadowR) != C_NULL if width == 1 - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) + gep = + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) shadow = LLVM.load!(B, T_prjlvalue, gep) else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for i in 1:width - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(i)]) + for i = 1:width + gep = LLVM.inbounds_gep!( + B, + AT, + sret, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) - shadow = insert_value!(B, shadow, ld, i-1) + shadow = insert_value!(B, shadow, ld, i - 1) end end unsafe_store!(shadowR, shadow.ref) end if 
unsafe_load(normalR) != C_NULL - normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) + normal = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)]), + ) unsafe_store!(normalR, normal.ref) else # Delete the primal code @@ -1865,44 +2635,75 @@ end function common_invoke_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return true end - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing - shadow = (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing - + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + shadow = + (unsafe_load(shadowR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(shadowR)) : nothing + T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) conv = LLVM.callconv(orig) width = get_width(gutils) - sret = generic_setup(orig, runtime_generic_augfwd, AnyArray(2+Int(width)), gutils, #=start=#offset+1, B, false) - AT = LLVM.ArrayType(T_prjlvalue, 2+Int(width)) + sret = generic_setup( + orig, + runtime_generic_augfwd, + AnyArray(2 + Int(width)), + gutils, + offset + 1, + B, + false, + ) #=start=# + AT = LLVM.ArrayType(T_prjlvalue, 2 + Int(width)) if unsafe_load(shadowR) != C_NULL if width == 1 - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) + gep = + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) shadow = LLVM.load!(B, T_prjlvalue, gep) else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for i in 1:width - gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(i)]) + for i = 1:width + gep = LLVM.inbounds_gep!( + B, + AT, + sret, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) - shadow = insert_value!(B, shadow, ld, i-1) + shadow = insert_value!(B, shadow, ld, i - 1) end end unsafe_store!(shadowR, shadow.ref) end - tape = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1+width)])) + tape = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1 + width)]), + ) unsafe_store!(tapeR, tape.ref) if unsafe_load(normalR) != C_NULL - normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) + normal = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)]), + ) unsafe_store!(normalR, normal.ref) else # Delete the primal code @@ -1916,14 +2717,21 @@ end function common_invoke_rev(offset, B, orig, 
gutils, tape) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + API.DEM_ReverseModePrimal, + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return nothing end - + width = get_width(gutils) - generic_setup(orig, runtime_generic_rev, Nothing, gutils, #=start=#offset+1, B, true; tape) + generic_setup(orig, runtime_generic_rev, Nothing, gutils, offset + 1, B, true; tape) #=start=# return nothing end diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 965114447cf..a41912cf829 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -1,14 +1,30 @@ macro register_aug(expr) decl = string(expr.args[1]) name = decl[1:prevind(decl, findfirst('(', decl))] - cname = name*"_cfunc" + cname = name * "_cfunc" name = Symbol(name) cname = Symbol(cname) expr2 = :(@inline $expr) res = quote - function $cname(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef}, tapeR::Ptr{LLVM.API.LLVMValueRef})::UInt8 - return UInt8($name(LLVM.IRBuilder(B), LLVM.CallInst(OrigCI), GradientUtils(gutils), normalR, shadowR, tapeR)::Bool) + function $cname( + B::LLVM.API.LLVMBuilderRef, + OrigCI::LLVM.API.LLVMValueRef, + gutils::API.EnzymeGradientUtilsRef, + normalR::Ptr{LLVM.API.LLVMValueRef}, + shadowR::Ptr{LLVM.API.LLVMValueRef}, + tapeR::Ptr{LLVM.API.LLVMValueRef}, + )::UInt8 + return UInt8( + $name( + LLVM.IRBuilder(B), + LLVM.CallInst(OrigCI), + GradientUtils(gutils), + normalR, + shadowR, + tapeR, + )::Bool, + ) end end return Expr(:block, esc(expr2), esc(res)) @@ -17,14 
+33,24 @@ end macro register_rev(expr) decl = string(expr.args[1]) name = decl[1:prevind(decl, findfirst('(', decl))] - cname = name*"_cfunc" + cname = name * "_cfunc" name = Symbol(name) cname = Symbol(cname) expr2 = :(@inline $expr) res = quote - function $cname(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, tape::LLVM.API.LLVMValueRef)::Cvoid - $name(LLVM.IRBuilder(B), LLVM.CallInst(OrigCI), GradientUtils(gutils), tape == C_NULL ? nothing : LLVM.Value(tape)) + function $cname( + B::LLVM.API.LLVMBuilderRef, + OrigCI::LLVM.API.LLVMValueRef, + gutils::API.EnzymeGradientUtilsRef, + tape::LLVM.API.LLVMValueRef, + )::Cvoid + $name( + LLVM.IRBuilder(B), + LLVM.CallInst(OrigCI), + GradientUtils(gutils), + tape == C_NULL ? nothing : LLVM.Value(tape), + ) return end end @@ -34,13 +60,27 @@ end macro register_fwd(expr) decl = string(expr.args[1]) name = decl[1:prevind(decl, findfirst('(', decl))] - cname = name*"_cfunc" + cname = name * "_cfunc" name = Symbol(name) cname = Symbol(cname) expr2 = :(@inline $expr) res = quote - function $cname(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef})::UInt8 - return UInt8($name(LLVM.IRBuilder(B), LLVM.CallInst(OrigCI), GradientUtils(gutils), normalR, shadowR)::Bool) + function $cname( + B::LLVM.API.LLVMBuilderRef, + OrigCI::LLVM.API.LLVMValueRef, + gutils::API.EnzymeGradientUtilsRef, + normalR::Ptr{LLVM.API.LLVMValueRef}, + shadowR::Ptr{LLVM.API.LLVMValueRef}, + )::UInt8 + return UInt8( + $name( + LLVM.IRBuilder(B), + LLVM.CallInst(OrigCI), + GradientUtils(gutils), + normalR, + shadowR, + )::Bool, + ) end end return Expr(:block, esc(expr2), esc(res)) @@ -49,13 +89,26 @@ end macro register_diffuse(expr) decl = string(expr.args[1]) name = decl[1:prevind(decl, findfirst('(', decl))] - cname = name*"_cfunc" + cname = name * "_cfunc" name = Symbol(name) cname = Symbol(cname) 
expr2 = :(@inline $expr) res = quote - function $cname(OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, val::LLVM.API.LLVMValueRef, shadow::UInt8, mode::API.CDerivativeMode, useDefault::Ptr{UInt8})::UInt8 - res = $name(LLVM.CallInst(OrigCI), GradientUtils(gutils), LLVM.Value(val), shadow != 0, mode)::Tuple{Bool, Bool} + function $cname( + OrigCI::LLVM.API.LLVMValueRef, + gutils::API.EnzymeGradientUtilsRef, + val::LLVM.API.LLVMValueRef, + shadow::UInt8, + mode::API.CDerivativeMode, + useDefault::Ptr{UInt8}, + )::UInt8 + res = $name( + LLVM.CallInst(OrigCI), + GradientUtils(gutils), + LLVM.Value(val), + shadow != 0, + mode, + )::Tuple{Bool,Bool} unsafe_store!(useDefault, UInt8(res[2])) return UInt8(res[1]) end @@ -75,7 +128,15 @@ include("parallelrules.jl") if in(name, ("ijl_apply_generic", "jl_apply_generic")) return common_generic_fwd(2, B, orig, gutils, normalR, shadowR) end - if in(name, ("ijl_f__apply_latest", "ijl_f__call_latest", "jl_f__apply_latest", "jl_f__call_latest")) + if in( + name, + ( + "ijl_f__apply_latest", + "ijl_f__call_latest", + "jl_f__apply_latest", + "jl_f__call_latest", + ), + ) return common_apply_latest_fwd(2, B, orig, gutils, normalR, shadowR) end if in(name, ("ijl_new_structv", "jl_new_structv")) @@ -99,17 +160,27 @@ include("parallelrules.jl") if in(name, ("ijl_f_finalizer", "jl_f_finalizer")) return common_finalizer_fwd(2, B, orig, gutils, normalR, shadowR) end - if any(map(k->kind(k)==kind(StringAttribute("enzyme_inactive")), collect(function_attributes(F)))) + if any( + map( + k -> kind(k) == kind(StringAttribute("enzyme_inactive")), + collect(function_attributes(F)), + ), + ) return true end end - err = emit_error(B, orig, "Enzyme: jl_call calling convention not implemented in forward for "*string(orig)) - + err = emit_error( + B, + orig, + "Enzyme: jl_call calling convention not implemented in forward for " * string(orig), + ) + newo = new_from_original(gutils, orig) API.moveBefore(newo, err, B) - normal = 
(unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -131,7 +202,15 @@ end if in(name, ("ijl_apply_generic", "jl_apply_generic")) return common_generic_augfwd(2, B, orig, gutils, normalR, shadowR, tapeR) end - if in(name, ("ijl_f__apply_latest", "ijl_f__call_latest", "jl_f__apply_latest", "jl_f__call_latest")) + if in( + name, + ( + "ijl_f__apply_latest", + "ijl_f__call_latest", + "jl_f__apply_latest", + "jl_f__call_latest", + ), + ) return common_apply_latest_augfwd(2, B, orig, gutils, normalR, shadowR, tapeR) end if in(name, ("ijl_new_structv", "jl_new_structv")) @@ -155,16 +234,27 @@ end if in(name, ("ijl_f_finalizer", "jl_f_finalizer")) return common_finalizer_augfwd(2, B, orig, gutils, normalR, shadowR, tapeR) end - if any(map(k->kind(k)==kind(StringAttribute("enzyme_inactive")), collect(function_attributes(F)))) + if any( + map( + k -> kind(k) == kind(StringAttribute("enzyme_inactive")), + collect(function_attributes(F)), + ), + ) return true end end - err = emit_error(B, orig, "Enzyme: jl_call calling convention not implemented in aug_forward for "*string(orig)) + err = emit_error( + B, + orig, + "Enzyme: jl_call calling convention not implemented in aug_forward for " * + string(orig), + ) newo = new_from_original(gutils, orig) API.moveBefore(newo, err, B) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -187,7 +277,15 @@ end common_generic_rev(2, B, orig, gutils, tape) return nothing end - if in(name, ("ijl_f__apply_latest", "ijl_f__call_latest", "jl_f__apply_latest", "jl_f__call_latest")) + if in( + name, + ( + "ijl_f__apply_latest", + "ijl_f__call_latest", + "jl_f__apply_latest", + "jl_f__call_latest", + ), + ) common_apply_latest_rev(2, B, orig, gutils, tape) return nothing end @@ -219,12 +317,21 @@ end common_finalizer_rev(2, B, orig, gutils, tape) return nothing end - if any(map(k->kind(k)==kind(StringAttribute("enzyme_inactive")), collect(function_attributes(F)))) + if any( + map( + k -> kind(k) == kind(StringAttribute("enzyme_inactive")), + collect(function_attributes(F)), + ), + ) return nothing end end - emit_error(B, orig, "Enzyme: jl_call calling convention not implemented in reverse for "*string(orig)) + emit_error( + B, + orig, + "Enzyme: jl_call calling convention not implemented in reverse for " * string(orig), + ) return nothing end @@ -236,7 +343,12 @@ end if in(name, ("ijl_invoke", "jl_invoke")) return common_invoke_fwd(2, B, orig, gutils, normalR, shadowR) end - if any(map(k->kind(k)==kind(StringAttribute("enzyme_inactive")), collect(function_attributes(F)))) + if any( + map( + k -> kind(k) == kind(StringAttribute("enzyme_inactive")), + collect(function_attributes(F)), + ), + ) return true end end @@ -253,7 +365,12 @@ end if in(name, ("ijl_invoke", "jl_invoke")) return common_invoke_augfwd(2, B, orig, gutils, normalR, shadowR, tapeR) end - if any(map(k->kind(k)==kind(StringAttribute("enzyme_inactive")), collect(function_attributes(F)))) + if any( + map( + k -> kind(k) == kind(StringAttribute("enzyme_inactive")), + collect(function_attributes(F)), + ), + ) return true end end @@ -271,7 +388,12 @@ end common_invoke_rev(2, B, orig, gutils, tape) return nothing end - if 
any(map(k->kind(k)==kind(StringAttribute("enzyme_inactive")), collect(function_attributes(F)))) + if any( + map( + k -> kind(k) == kind(StringAttribute("enzyme_inactive")), + collect(function_attributes(F)), + ), + ) return nothing end end @@ -295,8 +417,15 @@ end real_ops = collect(operands(orig))[1:end-1] ops = [lookup_value(gutils, new_from_original(gutils, o), B) for o in real_ops] - - c = call_samefunc_with_inverted_bundles!(B, gutils, orig, ops, [API.VT_Primal for _ in ops], #=lookup=#false) + + c = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + ops, + [API.VT_Primal for _ in ops], + false, + ) #=lookup=# callconv!(c, callconv(orig)) return nothing @@ -319,64 +448,117 @@ end algn = 0 if width == 1 - shadowres = call_samefunc_with_inverted_bundles!(B, gutils, orig, [shadowin], [API.VT_Shadow], #=lookup=#false) + shadowres = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + [shadowin], + [API.VT_Shadow], + false, + ) #=lookup=# # TODO zero based off runtime types, rather than presume floatlike? 
if is_constant_value(gutils, origops[1]) elSize = get_array_elsz(B, shadowin) - elSize = LLVM.zext!(B, elSize, LLVM.IntType(8*sizeof(Csize_t))) + elSize = LLVM.zext!(B, elSize, LLVM.IntType(8 * sizeof(Csize_t))) len = get_array_len(B, shadowin) length = LLVM.mul!(B, len, elSize) - bt = GPUCompiler.backtrace(orig) - btstr = sprint() do io - print(io,"\nCaused by:") - Base.show_backtrace(io, bt) - end + bt = GPUCompiler.backtrace(orig) + btstr = sprint() do io + print(io, "\nCaused by:") + Base.show_backtrace(io, bt) + end GPUCompiler.@safe_warn "TODO forward zero-set of arraycopy used memset rather than runtime type $btstr" - LLVM.memset!(B, get_array_data(B, shadowres), LLVM.ConstantInt(i8, 0, false), length, algn) + LLVM.memset!( + B, + get_array_data(B, shadowres), + LLVM.ConstantInt(i8, 0, false), + length, + algn, + ) end if get_runtime_activity(gutils) prev = new_from_original(gutils, orig) - shadowres = LLVM.select!(B, LLVM.icmp!(B, LLVM.API.LLVMIntNE, shadowin, new_from_original(gutils, origops[1])), shadowres, prev) + shadowres = LLVM.select!( + B, + LLVM.icmp!( + B, + LLVM.API.LLVMIntNE, + shadowin, + new_from_original(gutils, origops[1]), + ), + shadowres, + prev, + ) API.moveBefore(prev, shadowres, B) end else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) - for idx in 1:width - ev = extract_value!(B, shadowin, idx-1) - callv = call_samefunc_with_inverted_bundles!(B, gutils, orig, [ev], [API.VT_Shadow], #=lookup=#false) + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width + ev = extract_value!(B, shadowin, idx - 1) + callv = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + [ev], + [API.VT_Shadow], + false, + ) #=lookup=# if is_constant_value(gutils, origops[1]) elSize = get_array_elsz(B, ev) - elSize = LLVM.zext!(B, elSize, LLVM.IntType(8*sizeof(Csize_t))) + elSize = LLVM.zext!(B, elSize, LLVM.IntType(8 * sizeof(Csize_t))) len = 
get_array_len(B, ev) length = LLVM.mul!(B, len, elSize) - bt = GPUCompiler.backtrace(orig) - btstr = sprint() do io - print(io,"\nCaused by:") - Base.show_backtrace(io, bt) - end + bt = GPUCompiler.backtrace(orig) + btstr = sprint() do io + print(io, "\nCaused by:") + Base.show_backtrace(io, bt) + end GPUCompiler.@safe_warn "TODO forward zero-set of arraycopy used memset rather than runtime type $btstr" - LLVM.memset!(B, get_array_data(B, callv), LLVM.ConstantInt(i8, 0, false), length, algn) + LLVM.memset!( + B, + get_array_data(B, callv), + LLVM.ConstantInt(i8, 0, false), + length, + algn, + ) end if get_runtime_activity(gutils) prev = new_from_original(gutils, orig) - callv = LLVM.select!(B, LLVM.icmp!(B, LLVM.API.LLVMIntNE, ev, new_from_original(gutils, origops[1])), callv, prev) + callv = LLVM.select!( + B, + LLVM.icmp!( + B, + LLVM.API.LLVMIntNE, + ev, + new_from_original(gutils, origops[1]), + ), + callv, + prev, + ) if idx == 1 API.moveBefore(prev, callv, B) end end - shadowres = insert_value!(B, shadowres, callv, idx-1) + shadowres = insert_value!(B, shadowres, callv, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) - return false + return false end function arraycopy_common(fwd, B, orig, origArg, gutils, shadowdst) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + API.DEM_ReverseModePrimal, + ) needsPrimal = needsPrimalP[] != 0 needsShadow = needsShadowP[] != 0 if !needsShadow @@ -390,20 +572,20 @@ function arraycopy_common(fwd, B, orig, origArg, gutils, shadowdst) # size_t len = jl_array_len(ary); # size_t elsz = ary->elsize; # memcpy(new_ary->data, ary->data, len * elsz); - # JL_EXTENSION typedef struct { - # JL_DATA_TYPE - # void *data; - # #ifdef STORE_ARRAY_LEN - # size_t length; - # #endif - # jl_array_flags_t 
flags; - # uint16_t elsize; // element size including alignment (dim 1 memory stride) - - tt = TypeTree(API.EnzymeGradientUtilsAllocAndGetTypeTree(gutils, orig)) + # JL_EXTENSION typedef struct { + # JL_DATA_TYPE + # void *data; + # #ifdef STORE_ARRAY_LEN + # size_t length; + # #endif + # jl_array_flags_t flags; + # uint16_t elsize; // element size including alignment (dim 1 memory stride) + + tt = TypeTree(API.EnzymeGradientUtilsAllocAndGetTypeTree(gutils, orig)) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) - dl = string(LLVM.datalayout(mod)) - API.EnzymeTypeTreeLookupEq(tt, 1, dl) - data0!(tt) + dl = string(LLVM.datalayout(mod)) + API.EnzymeTypeTreeLookupEq(tt, 1, dl) + data0!(tt) ct = API.EnzymeTypeTreeInner0(tt) if ct == API.DT_Unknown @@ -411,7 +593,11 @@ function arraycopy_common(fwd, B, orig, origArg, gutils, shadowdst) # ip = API.EnzymeTypeAnalyzerToString(analyzer) # sval = Base.unsafe_string(ip) # API.EnzymeStringFree(ip) - emit_error(B, orig, "Enzyme: Unknown concrete type in arraycopy_common. tt: " * string(tt)) + emit_error( + B, + orig, + "Enzyme: Unknown concrete type in arraycopy_common. 
tt: " * string(tt), + ) return nothing end @@ -431,7 +617,14 @@ function arraycopy_common(fwd, B, orig, origArg, gutils, shadowdst) B0 = B elseif typeof(actualOp) <: LLVM.Argument B0 = LLVM.IRBuilder() - position!(B0, first(instructions(new_from_original(gutils, LLVM.entry(LLVM.parent(LLVM.parent(orig))))))) + position!( + B0, + first( + instructions( + new_from_original(gutils, LLVM.entry(LLVM.parent(LLVM.parent(orig)))), + ), + ), + ) else B0 = LLVM.IRBuilder() nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(actualOp)) @@ -442,7 +635,7 @@ function arraycopy_common(fwd, B, orig, origArg, gutils, shadowdst) end elSize = get_array_elsz(B0, actualOp) - elSize = LLVM.zext!(B0, elSize, LLVM.IntType(8*sizeof(Csize_t))) + elSize = LLVM.zext!(B0, elSize, LLVM.IntType(8 * sizeof(Csize_t))) len = get_array_len(B0, actualOp) @@ -478,30 +671,64 @@ function arraycopy_common(fwd, B, orig, origArg, gutils, shadowdst) if width == 1 - shadowsrc = get_array_data(B, shadowsrc) - shadowdst = get_array_data(B, shadowdst) - - if fwd && secretty != nothing - LLVM.memset!(B, shadowdst, LLVM.ConstantInt(i8, 0, false), length, algn) - end - - API.sub_transfer(gutils, fwd ? API.DEM_ReverseModePrimal : API.DEM_ReverseModeGradient, secretty, intrinsic, #=dstAlign=#1, #=srcAlign=#1, #=offset=#0, false, shadowdst, false, shadowsrc, length, isVolatile, orig, allowForward, #=shadowsLookedUp=#!fwd) + shadowsrc = get_array_data(B, shadowsrc) + shadowdst = get_array_data(B, shadowdst) + + if fwd && secretty != nothing + LLVM.memset!(B, shadowdst, LLVM.ConstantInt(i8, 0, false), length, algn) + end + + API.sub_transfer( + gutils, + fwd ? 
API.DEM_ReverseModePrimal : API.DEM_ReverseModeGradient, + secretty, + intrinsic, + 1, + 1, + 0, + false, + shadowdst, + false, + shadowsrc, + length, + isVolatile, + orig, + allowForward, + !fwd, + ) #=shadowsLookedUp=# else - for i in 1:width + for i = 1:width - evsrc = extract_value!(B, shadowsrc, i-1) - evdst = extract_value!(B, shadowdst, i-1) + evsrc = extract_value!(B, shadowsrc, i - 1) + evdst = extract_value!(B, shadowdst, i - 1) - shadowsrc0 = get_array_data(B, evsrc) - shadowdst0 = get_array_data(B, evdst) + shadowsrc0 = get_array_data(B, evsrc) + shadowdst0 = get_array_data(B, evdst) - if fwd && secretty != nothing - LLVM.memset!(B, shadowdst0, LLVM.ConstantInt(i8, 0, false), length, algn) - end + if fwd && secretty != nothing + LLVM.memset!(B, shadowdst0, LLVM.ConstantInt(i8, 0, false), length, algn) + end - API.sub_transfer(gutils, fwd ? API.DEM_ReverseModePrimal : API.DEM_ReverseModeGradient, secretty, intrinsic, #=dstAlign=#1, #=srcAlign=#1, #=offset=#0, false, shadowdst0, false, shadowsrc0, length, isVolatile, orig, allowForward, #=shadowsLookedUp=#!fwd) - end + API.sub_transfer( + gutils, + fwd ? 
API.DEM_ReverseModePrimal : API.DEM_ReverseModeGradient, + secretty, + intrinsic, + 1, + 1, + 0, + false, + shadowdst0, + false, + shadowsrc0, + length, + isVolatile, + orig, + allowForward, + !fwd, + ) #=shadowsLookedUp=# + end end @@ -517,18 +744,18 @@ end origops = LLVM.operands(orig) if !is_constant_value(gutils, origops[1]) && !is_constant_value(gutils, orig) - shadowres = LLVM.Value(unsafe_load(shadowR)) + shadowres = LLVM.Value(unsafe_load(shadowR)) - arraycopy_common(#=fwd=#true, B, orig, origops[1], gutils, shadowres) + arraycopy_common(true, B, orig, origops[1], gutils, shadowres) #=fwd=# end - return false + return false end @register_rev function arraycopy_rev(B, orig, gutils, tape) origops = LLVM.operands(orig) if !is_constant_value(gutils, origops[1]) && !is_constant_value(gutils, orig) - arraycopy_common(#=fwd=#false, B, orig, origops[1], gutils, nothing) + arraycopy_common(false, B, orig, origops[1], gutils, nothing) #=fwd=# end return nothing @@ -548,25 +775,41 @@ end shadowin = invert_pointer(gutils, origops[2], B) if width == 1 args = LLVM.Value[ - new_from_original(gutils, origops[1]) - shadowin - new_from_original(gutils, origops[3]) - ] - shadowres = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, [API.VT_Primal, API.VT_Shadow, API.VT_Primal], #=lookup=#false) + new_from_original(gutils, origops[1]) + shadowin + new_from_original(gutils, origops[3]) + ] + shadowres = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + args, + [API.VT_Primal, API.VT_Shadow, API.VT_Primal], + false, + ) #=lookup=# else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) - for idx in 1:width - args = LLVM.Value[new_from_original(gutils, origops[1]) - extract_value!(B, shadowin, idx-1) - new_from_original(gutils, origops[3]) - ] - tmp = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, [API.VT_Primal, API.VT_Shadow, API.VT_Primal], #=lookup=#false) - shadowres = insert_value!(B, shadowres, tmp, 
idx-1) + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width + args = LLVM.Value[ + new_from_original(gutils, origops[1]) + extract_value!(B, shadowin, idx - 1) + new_from_original(gutils, origops[3]) + ] + tmp = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + args, + [API.VT_Primal, API.VT_Shadow, API.VT_Primal], + false, + ) #=lookup=# + shadowres = insert_value!(B, shadowres, tmp, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) - return false + return false end @register_aug function arrayreshape_augfwd(B, orig, gutils, normalR, shadowR, tapeR) @@ -580,12 +823,18 @@ end @register_fwd function gcloaded_fwd(B, orig, gutils, normalR, shadowR) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) return true end - + origops = LLVM.operands(orig) if is_constant_value(gutils, origops[1]) emit_error(B, orig, "Enzyme: gcloaded has active return, but inactive input(1)") @@ -600,21 +849,36 @@ end shadowin2 = invert_pointer(gutils, origops[2], B) if width == 1 args = LLVM.Value[shadowin1, shadowin2] - shadowres = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, [API.VT_Shadow, API.VT_Shadow], #=lookup=#false) + shadowres = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + args, + [API.VT_Shadow, API.VT_Shadow], + false, + ) #=lookup=# else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) - for idx in 1:width + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width args = LLVM.Value[ - extract_value!(B, 
shadowin1, idx-1) - extract_value!(B, shadowin2, idx-1) - ] - tmp = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, [API.VT_Shadow, API.VT_Shadow], #=lookup=#false) - shadowres = insert_value!(B, shadowres, tmp, idx-1) + extract_value!(B, shadowin1, idx - 1) + extract_value!(B, shadowin2, idx - 1) + ] + tmp = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + args, + [API.VT_Shadow, API.VT_Shadow], + false, + ) #=lookup=# + shadowres = insert_value!(B, shadowres, tmp, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) - return false + return false end @register_aug function gcloaded_augfwd(B, orig, gutils, normalR, shadowR, tapeR) @@ -628,10 +892,16 @@ end @register_fwd function boxfloat_fwd(B, orig, gutils, normalR, shadowR) origops = collect(operands(orig)) width = get_width(gutils) - + needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) if is_constant_value(gutils, orig) || needsShadowP[] == 0 return true @@ -643,14 +913,13 @@ end shadowres = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), shadowsin) callconv!(shadowres, callconv(orig)) else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) - for idx in 1:width - args = LLVM.Value[ - extract_value!(B, s, idx-1) for s in shadowsin - ] + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width + args = LLVM.Value[extract_value!(B, s, idx - 1) for s in shadowsin] tmp = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), args) callconv!(tmp, callconv(orig)) - shadowres = insert_value!(B, shadowres, tmp, idx-1) + shadowres = insert_value!(B, shadowres, tmp, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) @@ -660,10 +929,16 @@ 
end @register_aug function boxfloat_augfwd(B, orig, gutils, normalR, shadowR, tapeR) origops = collect(operands(orig)) width = get_width(gutils) - + needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) if is_constant_value(gutils, orig) || needsShadowP[] == 0 return true @@ -679,11 +954,11 @@ end shadowres = obj else shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, flt))) - for idx in 1:width + for idx = 1:width obj = emit_allocobj!(B, Base.RefValue{TT}) o2 = bitcast!(B, obj, LLVM.PointerType(flt, addrspace(value_type(obj)))) store!(B, ConstantFP(flt, 0.0), o2) - shadowres = insert_value!(B, shadowres, obj, idx-1) + shadowres = insert_value!(B, shadowres, obj, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) @@ -691,10 +966,16 @@ end end @register_rev function boxfloat_rev(B, orig, gutils, tape) - + needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + API.DEM_ReverseModePrimal, + ) if is_constant_value(gutils, orig) || needsShadowP[] == 0 return nothing @@ -713,12 +994,12 @@ end end else shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, flt))) - for idx in 1:width - ipc = extract_value!(B, ip, idx-1) + for idx = 1:width + ipc = extract_value!(B, ip, idx - 1) ipc = bitcast!(B, ipc, LLVM.PointerType(flt, addrspace(value_type(orig)))) ld = load!(B, flt, ipc) store!(B, ConstantFP(flt, 0.0), ipc) - shadowres = insert_value!(B, shadowres, ld, idx-1) + shadowres = insert_value!(B, shadowres, ld, idx - 1) end if !is_constant_value(gutils, origops[1]) 
API.EnzymeGradientUtilsAddToDiffe(gutils, origops[1], shadowret, B, flt) @@ -734,7 +1015,8 @@ end emit_error(B, orig, "Enzyme: Not yet implemented forward for jl_eqtable_get") - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -742,11 +1024,13 @@ end return false end -function error_if_active(::Type{T}) where T +function error_if_active(::Type{T}) where {T} seen = () - areg = active_reg_inner(T, seen, nothing, #=justActive=#Val(true)) + areg = active_reg_inner(T, seen, nothing, Val(true)) #=justActive=# if areg == ActiveState - throw(AssertionError("Found unhandled active variable in tuple splat, jl_eqtable $T")) + throw( + AssertionError("Found unhandled active variable in tuple splat, jl_eqtable $T"), + ) end nothing end @@ -755,12 +1039,18 @@ end if is_constant_value(gutils, orig) return true end - + mode = get_mode(gutils) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + mode, + ) if needsShadowP[] == 0 return false end @@ -772,14 +1062,49 @@ end origh, origkey, origdflt = operands(orig)[1:end-1] if is_constant_value(gutils, origh) - emit_error(B, orig, "Enzyme: Not yet implemented constant table in jl_eqtable_get "*string(origh)*" "*string(orig)*" result: "*string(absint(orig))*" "*string(abs_typeof(orig, true))*" dict: "*string(absint(origh))*" "*string(abs_typeof(origh, true))*" key "*string(absint(origkey))*" "*string(abs_typeof(origkey, true))*" dflt "*string(absint(origdflt))*" "*string(abs_typeof(origdflt, true))) + emit_error( + B, + orig, + "Enzyme: Not yet implemented constant table in jl_eqtable_get " * + 
string(origh) * + " " * + string(orig) * + " result: " * + string(absint(orig)) * + " " * + string(abs_typeof(orig, true)) * + " dict: " * + string(absint(origh)) * + " " * + string(abs_typeof(origh, true)) * + " key " * + string(absint(origkey)) * + " " * + string(abs_typeof(origkey, true)) * + " dflt " * + string(absint(origdflt)) * + " " * + string(abs_typeof(origdflt, true)), + ) end - + shadowh = invert_pointer(gutils, origh, B) shadowdflt = if is_constant_value(gutils, origdflt) - shadowdflt2 = julia_error(Base.unsafe_convert(Cstring, "Mixed activity for default of jl_eqtable_get "*string(orig)*" "*string(origdflt)), - orig.ref, API.ET_MixedActivityError, gutils.ref, origdflt.ref, B.ref) + shadowdflt2 = julia_error( + Base.unsafe_convert( + Cstring, + "Mixed activity for default of jl_eqtable_get " * + string(orig) * + " " * + string(origdflt), + ), + orig.ref, + API.ET_MixedActivityError, + gutils.ref, + origdflt.ref, + B.ref, + ) if shadowdflt2 != C_NULL LLVM.Value(shadowdflt2) else @@ -789,8 +1114,8 @@ end else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(nop))) shadowm = LLVM.UndefValue(ST) - for j in 1:width - shadowm = insert_value!(B, shadowm, nop, j-1) + for j = 1:width + shadowm = insert_value!(B, shadowm, nop, j - 1) end shadowm end @@ -798,24 +1123,41 @@ end else invert_pointer(gutils, origdflt, B) end - + newvals = API.CValueType[API.VT_Shadow, API.VT_Primal, API.VT_Shadow] - + shadowres = if width == 1 newops = LLVM.Value[shadowh, new_from_original(gutils, origkey), shadowdflt] - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, false) #=lookup=# callconv!(cal, callconv(orig)) - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, error_if_active), emit_jltypeof!(B, cal)]) + emit_apply_generic!( + B, + LLVM.Value[unsafe_to_llvm(B, error_if_active), emit_jltypeof!(B, cal)], + ) cal else ST = 
LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for j in 1:width - newops = LLVM.Value[extract_value!(B, shadowh, j-1), new_from_original(gutils, origkey), extract_value!(B, shadowdflt, j-1)] - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + for j = 1:width + newops = LLVM.Value[ + extract_value!(B, shadowh, j - 1), + new_from_original(gutils, origkey), + extract_value!(B, shadowdflt, j - 1), + ] + cal = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + newops, + newvals, + false, + ) #=lookup=# callconv!(cal, callconv(orig)) - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, error_if_active), emit_jltypeof!(B, cal)]) - shadow = insert_value!(B, shadow, cal, j-1) + emit_apply_generic!( + B, + LLVM.Value[unsafe_to_llvm(B, error_if_active), emit_jltypeof!(B, cal)], + ) + shadow = insert_value!(B, shadow, cal, j - 1) end shadow end @@ -834,7 +1176,8 @@ end end emit_error(B, orig, "Enzyme: Not yet implemented forward for jl_eqtable_put") - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -857,8 +1200,20 @@ end shadowval = invert_pointer(gutils, origval, B) shadowval = if is_constant_value(gutils, origval) - shadowdflt2 = julia_error(Base.unsafe_convert(Cstring, "Mixed activity for val of jl_eqtable_put "*string(orig)*" "*string(origval)), - orig.ref, API.ET_MixedActivityError, gutils.ref, origval.ref, B.ref) + shadowdflt2 = julia_error( + Base.unsafe_convert( + Cstring, + "Mixed activity for val of jl_eqtable_put " * + string(orig) * + " " * + string(origval), + ), + orig.ref, + API.ET_MixedActivityError, + gutils.ref, + origval.ref, + B.ref, + ) if shadowdflt2 != C_NULL LLVM.Value(shadowdflt2) else @@ -868,8 +1223,8 @@ end else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(nop))) shadowm = LLVM.UndefValue(ST) - for j in 1:width - shadowm = insert_value!(B, shadowm, nop, j-1) + for j = 1:width + shadowm = insert_value!(B, shadowm, nop, j - 1) end shadowm end @@ -881,23 +1236,46 @@ end mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) newvals = API.CValueType[API.VT_Shadow, API.VT_Primal, API.VT_Shadow, API.VT_None] - + shadowres = if width == 1 - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, error_if_active), emit_jltypeof!(B, shadowval)]) - newops = LLVM.Value[shadowh, new_from_original(gutils, origkey), shadowval, LLVM.null(value_type(originserted))] - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + emit_apply_generic!( + B, + LLVM.Value[unsafe_to_llvm(B, error_if_active), emit_jltypeof!(B, shadowval)], + ) + newops = LLVM.Value[ + shadowh, + new_from_original(gutils, origkey), + shadowval, + LLVM.null(value_type(originserted)), + ] + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, false) #=lookup=# callconv!(cal, callconv(orig)) cal else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = 
LLVM.UndefValue(ST) - for j in 1:width - sval2 = extract_value!(B, shadowval, j-1) - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, error_if_active), emit_jltypeof!(B, sval2)]) - newops = LLVM.Value[extract_value!(B, shadowh, j-1), new_from_original(gutils, origkey), sval2, LLVM.null(value_type(originserted))] - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + for j = 1:width + sval2 = extract_value!(B, shadowval, j - 1) + emit_apply_generic!( + B, + LLVM.Value[unsafe_to_llvm(B, error_if_active), emit_jltypeof!(B, sval2)], + ) + newops = LLVM.Value[ + extract_value!(B, shadowh, j - 1), + new_from_original(gutils, origkey), + sval2, + LLVM.null(value_type(originserted)), + ] + cal = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + newops, + newvals, + false, + ) #=lookup=# callconv!(cal, callconv(orig)) - shadow = insert_value!(B, shadow, cal, j-1) + shadow = insert_value!(B, shadow, cal, j - 1) end shadow end @@ -917,7 +1295,8 @@ end end emit_error(B, orig, "Enzyme: Not yet implemented forward for jl_idtable_rehash") - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -929,9 +1308,14 @@ end if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - emit_error(B, orig, "Enzyme: Not yet implemented augmented forward for jl_idtable_rehash") + emit_error( + B, + orig, + "Enzyme: Not yet implemented augmented forward for jl_idtable_rehash", + ) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -955,17 +1339,31 @@ end shadowin = invert_pointer(gutils, origops[1], B) if width == 1 args = LLVM.Value[ - shadowin - new_from_original(gutils, origops[2]) - ] - call_samefunc_with_inverted_bundles!(B, gutils, orig, args, [API.VT_Shadow, API.VT_Primal], #=lookup=#false) + shadowin + new_from_original(gutils, origops[2]) + ] + call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + args, + [API.VT_Shadow, API.VT_Primal], + false, + ) #=lookup=# else - for idx in 1:width + for idx = 1:width args = LLVM.Value[ - extract_value!(B, shadowin, idx-1) - new_from_original(gutils, origops[2]) - ] - call_samefunc_with_inverted_bundles!(B, gutils, orig, args, [API.VT_Shadow, API.VT_Primal], #=lookup=#false) + extract_value!(B, shadowin, idx - 1) + new_from_original(gutils, origops[2]) + ] + call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + args, + [API.VT_Shadow, API.VT_Primal], + false, + ) #=lookup=# end end return false @@ -997,14 +1395,21 @@ end tot = mul!(B, inc, elsz) args = LLVM.Value[anti, inc] - call_samefunc_with_inverted_bundles!(B, gutils, orig, args, [API.VT_Shadow, API.VT_Primal], #=lookup=#false) + call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + args, + [API.VT_Shadow, API.VT_Primal], + false, + ) #=lookup=# toset = get_array_data(B, anti) toset = gep!(B, i8, toset, LLVM.Value[off]) mcall = LLVM.memset!(B, toset, LLVM.ConstantInt(i8, 0, false), tot, al) else - for idx in 1:width - anti = extract_value!(B, shadowin, idx-1) + for idx = 1:width + anti = extract_value!(B, shadowin, idx - 1) idx = get_array_nrows(B, anti) elsz = zext!(B, get_array_elsz(B, anti), value_type(idx)) @@ -1012,7 +1417,14 @@ end tot = mul!(B, inc, elsz) args = LLVM.Value[anti, inc] - call_samefunc_with_inverted_bundles!(B, gutils, orig, args, [API.VT_Shadow, API.VT_Primal], #=lookup=#false) + 
call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + args, + [API.VT_Shadow, API.VT_Primal], + false, + ) #=lookup=# toset = get_array_data(B, anti) toset = gep!(B, i8, toset, LLVM.Value[off]) @@ -1042,17 +1454,18 @@ end if width == 1 args = LLVM.Value[ - shadowin - offset - ] + shadowin + offset + ] LLVM.call!(B, fty, delF, args) else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) - for idx in 1:width + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width args = LLVM.Value[ - extract_value!(B, shadowin, idx-1) - offset - ] + extract_value!(B, shadowin, idx - 1) + offset + ] LLVM.call!(B, fty, delF, args) end end @@ -1086,32 +1499,40 @@ end # TODO get actual alignment algn = 0 - + i8 = LLVM.IntType(8) - for idx in 1:width + for idx = 1:width anti = if width == 1 shadowin else - extract_value!(B, shadowin, idx-1) + extract_value!(B, shadowin, idx - 1) end if get_runtime_activity(gutils) - emit_error(B, orig, "Enzyme: Not yet implemented runtime activity for reverse of jl_array_del_end") + emit_error( + B, + orig, + "Enzyme: Not yet implemented runtime activity for reverse of jl_array_del_end", + ) end args = LLVM.Value[anti, offset] - - found, arty = abs_typeof(origops[1]) + + found, arty, byref = abs_typeof(origops[1]) anti = shadowin elSize = if found LLVM.ConstantInt(Csize_t(sizeof(eltype(arty)))) else - elSize = LLVM.zext!(B, get_array_elsz(B, anti), LLVM.IntType(8*sizeof(Csize_t))) + elSize = LLVM.zext!( + B, + get_array_elsz(B, anti), + LLVM.IntType(8 * sizeof(Csize_t)), + ) end len = get_array_len(B, anti) - + LLVM.call!(B, fty, delF, args) - + length = LLVM.mul!(B, len, elSize) - + if !found && !(eltype(arty) <: Base.IEEEFloat) GPUCompiler.@safe_warn "TODO reverse jl_array_del_end zero-set used memset rather than runtime type of $((found, arty)) in $(string(origops[1]))" end @@ -1138,22 +1559,30 @@ end push!(args, v) end push!(args, 
new_from_original(gutils, origops[end-1])) - valTys = API.CValueType[API.VT_Shadow, API.VT_Shadow, API.VT_Shadow, API.VT_Shadow, API.VT_Primal] + valTys = API.CValueType[ + API.VT_Shadow, + API.VT_Shadow, + API.VT_Shadow, + API.VT_Shadow, + API.VT_Primal, + ] if width == 1 vargs = args - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, vargs, valTys, #=lookup=#false) + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, vargs, valTys, false) #=lookup=# debug_from_orig!(gutils, cal, orig) callconv!(cal, callconv(orig)) else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) - for idx in 1:width + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width vargs = LLVM.Value[] for a in args[1:end-1] - push!(vargs, extract_value!(B, a, idx-1)) + push!(vargs, extract_value!(B, a, idx - 1)) end push!(vargs, args[end]) - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, vargs, valTys, #=lookup=#false) + cal = + call_samefunc_with_inverted_bundles!(B, gutils, orig, vargs, valTys, false) #=lookup=# debug_from_orig!(gutils, cal, orig) callconv!(cal, callconv(orig)) end @@ -1162,7 +1591,7 @@ end return false end @register_aug function jl_array_ptr_copy_augfwd(B, orig, gutils, normalR, shadowR, tapeR) - jl_array_ptr_copy_fwd(B, orig, gutils, normalR, shadowR) + jl_array_ptr_copy_fwd(B, orig, gutils, normalR, shadowR) end @register_rev function jl_array_ptr_copy_rev(B, orig, gutils, tape) return nothing @@ -1178,18 +1607,33 @@ end shadowin = invert_pointer(gutils, origops[1], B) if width == 1 args = LLVM.Value[ - shadowin - new_from_original(gutils, origops[2]) - ] - call_samefunc_with_inverted_bundles!(B, gutils, orig, args, [API.VT_Shadow, API.VT_Primal], #=lookup=#false) + shadowin + new_from_original(gutils, origops[2]) + ] + call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + args, + [API.VT_Shadow, API.VT_Primal], + false, + ) #=lookup=# else - 
shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) - for idx in 1:width + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width args = LLVM.Value[ - extract_value!(B, shadowin, idx-1) - new_from_original(gutils, origops[2]) - ] - call_samefunc_with_inverted_bundles!(B, gutils, orig, args, [API.VT_Shadow, API.VT_Primal], #=lookup=#false) + extract_value!(B, shadowin, idx - 1) + new_from_original(gutils, origops[2]) + ] + call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + args, + [API.VT_Shadow, API.VT_Primal], + false, + ) #=lookup=# end end return false @@ -1206,9 +1650,10 @@ end @register_fwd function jl_unhandled_fwd(B, orig, gutils, normalR, shadowR) newo = new_from_original(gutils, orig) origops = collect(operands(orig)) - err = emit_error(B, orig, "Enzyme: unhandled forward for "*string(origops[end])) + err = emit_error(B, orig, "Enzyme: unhandled forward for " * string(origops[end])) API.moveBefore(newo, err, C_NULL) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing width = get_width(gutils) @@ -1216,9 +1661,11 @@ end shadowres = normal else position!(B, LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(normal))) - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal)))) - for idx in 1:width - shadowres = insert_value!(B, shadowres, normal, idx-1) + shadowres = UndefValue( + LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal))), + ) + for idx = 1:width + shadowres = insert_value!(B, shadowres, normal, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) @@ -1226,7 +1673,7 @@ end return false end @register_aug function jl_unhandled_augfwd(B, orig, gutils, normalR, shadowR, tapeR) - jl_unhandled_fwd(B, orig, gutils, normalR, shadowR) + jl_unhandled_fwd(B, orig, gutils, normalR, shadowR) end @register_rev function jl_unhandled_rev(B, orig, gutils, tape) return nothing @@ -1241,16 +1688,21 @@ end API.moveBefore(newo, err, B) if unsafe_load(shadowR) != C_NULL - valTys = API.CValueType[API.VT_Primal, API.VT_Primal] - args = [new_from_original(gutils, operands(orig)[1]), new_from_original(gutils, operands(orig)[2])] - normal = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, #=lookup=#false) + valTys = API.CValueType[API.VT_Primal, API.VT_Primal] + args = [ + new_from_original(gutils, operands(orig)[1]), + new_from_original(gutils, operands(orig)[2]), + ] + normal = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, false) #=lookup=# width = get_width(gutils) if width == 1 shadowres = normal else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal)))) - for idx in 1:width - shadowres = insert_value!(B, shadowres, normal, idx-1) + shadowres = UndefValue( + LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal))), + ) + for idx = 1:width + shadowres = insert_value!(B, shadowres, normal, idx - 1) end end 
unsafe_store!(shadowR, shadowres.ref) @@ -1262,20 +1714,29 @@ end if is_constant_value(gutils, orig) return true end - err = emit_error(B, orig, "Enzyme: unhandled augmented forward for jl_get_binding_or_error") + err = emit_error( + B, + orig, + "Enzyme: unhandled augmented forward for jl_get_binding_or_error", + ) newo = new_from_original(gutils, orig) API.moveBefore(newo, err, B) if unsafe_load(shadowR) != C_NULL - valTys = API.CValueType[API.VT_Primal, API.VT_Primal] - args = [new_from_original(gutils, operands(orig)[1]), new_from_original(gutils, operands(orig)[2])] - normal = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, #=lookup=#false) + valTys = API.CValueType[API.VT_Primal, API.VT_Primal] + args = [ + new_from_original(gutils, operands(orig)[1]), + new_from_original(gutils, operands(orig)[2]), + ] + normal = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, false) #=lookup=# width = get_width(gutils) if width == 1 shadowres = normal else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal)))) - for idx in 1:width - shadowres = insert_value!(B, shadowres, normal, idx-1) + shadowres = UndefValue( + LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal))), + ) + for idx = 1:width + shadowres = insert_value!(B, shadowres, normal, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) @@ -1292,10 +1753,15 @@ end if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - err = emit_error(B, orig, "Enzyme: unhandled forward for jl_gc_add_finalizer_th or jl_gc_add_ptr_finalizer") + err = emit_error( + B, + orig, + "Enzyme: unhandled forward for jl_gc_add_finalizer_th or jl_gc_add_ptr_finalizer", + ) newo = new_from_original(gutils, orig) API.moveBefore(newo, err, B) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -1306,10 +1772,15 @@ end if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - err = emit_error(B, orig, "Enzyme: unhandled augmented forward for jl_gc_add_finalizer_th") + err = emit_error( + B, + orig, + "Enzyme: unhandled augmented forward for jl_gc_add_finalizer_th", + ) newo = new_from_original(gutils, orig) API.moveBefore(newo, err, B) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -1333,10 +1804,15 @@ end if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - err = emit_error(B, orig, "There is a known issue in GPUCompiler.jl which is preventing higher-order AD of this code.\nPlease see https://github.com/JuliaGPU/GPUCompiler.jl/issues/629 for more information and to alert the GPUCompiler authors of your use case and need.") + err = emit_error( + B, + orig, + "There is a known issue in GPUCompiler.jl which is preventing higher-order AD of this code.\nPlease see https://github.com/JuliaGPU/GPUCompiler.jl/issues/629 for more information and to alert the GPUCompiler authors of your use case and need.", + ) newo = new_from_original(gutils, orig) API.moveBefore(newo, err, B) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -1347,10 +1823,15 @@ end if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - err = emit_error(B, orig, "There is a known issue in GPUCompiler.jl which is preventing higher-order AD of this code.\nPlease see https://github.com/JuliaGPU/GPUCompiler.jl/issues/629 for more information and to alert the GPUCompiler authors of your use case and need.") + err = emit_error( + B, + orig, + "There is a known issue in GPUCompiler.jl which is preventing higher-order AD of this code.\nPlease see https://github.com/JuliaGPU/GPUCompiler.jl/issues/629 for more information and to alert the GPUCompiler authors of your use case and need.", + ) newo = new_from_original(gutils, orig) API.moveBefore(newo, err, B) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -1369,7 +1850,7 @@ end end -function register_handler!(variants, augfwd_handler, rev_handler, fwd_handler=nothing) +function register_handler!(variants, augfwd_handler, rev_handler, fwd_handler = nothing) for variant in variants if augfwd_handler !== nothing && rev_handler !== nothing API.EnzymeRegisterCallHandler(variant, augfwd_handler, rev_handler) @@ -1381,31 +1862,71 @@ function register_handler!(variants, augfwd_handler, rev_handler, fwd_handler=no end macro augfunc(f) - cname = Symbol(string(f)*"_cfunc") - :(@cfunction($cname, UInt8, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}) + cname = Symbol(string(f) * "_cfunc") + :(@cfunction( + $cname, + UInt8, + ( + LLVM.API.LLVMBuilderRef, + LLVM.API.LLVMValueRef, + API.EnzymeGradientUtilsRef, + Ptr{LLVM.API.LLVMValueRef}, + Ptr{LLVM.API.LLVMValueRef}, + Ptr{LLVM.API.LLVMValueRef}, + ) )) end macro revfunc(f) - cname = Symbol(string(f)*"_cfunc") - :(@cfunction($cname, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef) + cname = Symbol(string(f) * "_cfunc") + :(@cfunction( + $cname, + Cvoid, + ( + LLVM.API.LLVMBuilderRef, + LLVM.API.LLVMValueRef, + API.EnzymeGradientUtilsRef, + LLVM.API.LLVMValueRef, + ) )) end macro fwdfunc(f) - cname = Symbol(string(f)*"_cfunc") - :(@cfunction($cname, UInt8, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}) + cname = Symbol(string(f) * "_cfunc") + :(@cfunction( + $cname, + UInt8, + ( + LLVM.API.LLVMBuilderRef, + LLVM.API.LLVMValueRef, + API.EnzymeGradientUtilsRef, + Ptr{LLVM.API.LLVMValueRef}, + Ptr{LLVM.API.LLVMValueRef}, + ) )) end macro diffusefunc(f) - cname = Symbol(string(f)*"_cfunc") - 
:(@cfunction(Compiler.$cname, UInt8, (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8}) + cname = Symbol(string(f) * "_cfunc") + :(@cfunction( + Compiler.$cname, + UInt8, + ( + LLVM.API.LLVMValueRef, + API.EnzymeGradientUtilsRef, + LLVM.API.LLVMValueRef, + UInt8, + API.CDerivativeMode, + Ptr{UInt8}, + ) )) end @noinline function register_llvm_rules() - API.EnzymeRegisterDiffUseCallHandler("enzyme_custom", @diffusefunc(enzyme_custom_diffuse)) + API.EnzymeRegisterDiffUseCallHandler( + "enzyme_custom", + @diffusefunc(enzyme_custom_diffuse) + ) register_handler!( ("julia.call",), @augfunc(jlcall_augfwd), @@ -1473,79 +1994,79 @@ end @fwdfunc(wait_fwd), ) register_handler!( - ("jl_","jl_breakpoint"), + ("jl_", "jl_breakpoint"), @augfunc(noop_augfwd), @revfunc(duplicate_rev), @fwdfunc(noop_fwd), ) register_handler!( - ("jl_array_copy","ijl_array_copy"), + ("jl_array_copy", "ijl_array_copy"), @augfunc(arraycopy_augfwd), @revfunc(arraycopy_rev), @fwdfunc(arraycopy_fwd), ) register_handler!( - ("jl_reshape_array","ijl_reshape_array"), + ("jl_reshape_array", "ijl_reshape_array"), @augfunc(arrayreshape_augfwd), @revfunc(arrayreshape_rev), @fwdfunc(arrayreshape_fwd), ) register_handler!( - ("jl_f_setfield","ijl_f_setfield"), + ("jl_f_setfield", "ijl_f_setfield"), @augfunc(setfield_augfwd), @revfunc(setfield_rev), @fwdfunc(setfield_fwd), ) register_handler!( - ("jl_box_float32","ijl_box_float32", "jl_box_float64", "ijl_box_float64"), + ("jl_box_float32", "ijl_box_float32", "jl_box_float64", "ijl_box_float64"), @augfunc(boxfloat_augfwd), @revfunc(boxfloat_rev), @fwdfunc(boxfloat_fwd), ) register_handler!( - ("jl_f_tuple","ijl_f_tuple"), + ("jl_f_tuple", "ijl_f_tuple"), @augfunc(f_tuple_augfwd), @revfunc(f_tuple_rev), @fwdfunc(f_tuple_fwd), ) register_handler!( - ("jl_eqtable_get","ijl_eqtable_get"), + ("jl_eqtable_get", "ijl_eqtable_get"), @augfunc(eqtableget_augfwd), @revfunc(eqtableget_rev), 
@fwdfunc(eqtableget_fwd), ) register_handler!( - ("jl_eqtable_put","ijl_eqtable_put"), + ("jl_eqtable_put", "ijl_eqtable_put"), @augfunc(eqtableput_augfwd), @revfunc(eqtableput_rev), @fwdfunc(eqtableput_fwd), ) register_handler!( - ("jl_idtable_rehash","ijl_idtable_rehash"), + ("jl_idtable_rehash", "ijl_idtable_rehash"), @augfunc(idtablerehash_augfwd), @revfunc(idtablerehash_rev), @fwdfunc(idtablerehash_fwd), ) register_handler!( - ("jl_f__apply_iterate","ijl_f__apply_iterate"), + ("jl_f__apply_iterate", "ijl_f__apply_iterate"), @augfunc(apply_iterate_augfwd), @revfunc(apply_iterate_rev), @fwdfunc(apply_iterate_fwd), ) register_handler!( - ("jl_f__svec_ref","ijl_f__svec_ref"), + ("jl_f__svec_ref", "ijl_f__svec_ref"), @augfunc(f_svec_ref_augfwd), @revfunc(f_svec_ref_rev), @fwdfunc(f_svec_ref_fwd), ) register_handler!( - ("jl_new_structv","ijl_new_structv"), + ("jl_new_structv", "ijl_new_structv"), @augfunc(new_structv_augfwd), @revfunc(new_structv_rev), @fwdfunc(new_structv_fwd), ) register_handler!( - ("jl_new_structt","ijl_new_structt"), + ("jl_new_structt", "ijl_new_structt"), @augfunc(new_structt_augfwd), @revfunc(new_structt_rev), @fwdfunc(new_structt_fwd), @@ -1557,7 +2078,12 @@ end @fwdfunc(get_binding_or_error_fwd), ) register_handler!( - ("jl_gc_add_finalizer_th","ijl_gc_add_finalizer_th", "jl_gc_add_ptr_finalizer","ijl_gc_add_ptr_finalizer"), + ( + "jl_gc_add_finalizer_th", + "ijl_gc_add_finalizer_th", + "jl_gc_add_ptr_finalizer", + "ijl_gc_add_ptr_finalizer", + ), @augfunc(finalizer_augfwd), @revfunc(finalizer_rev), @fwdfunc(finalizer_fwd), @@ -1569,37 +2095,37 @@ end @fwdfunc(deferred_fwd), ) register_handler!( - ("jl_array_grow_end","ijl_array_grow_end"), + ("jl_array_grow_end", "ijl_array_grow_end"), @augfunc(jl_array_grow_end_augfwd), @revfunc(jl_array_grow_end_rev), @fwdfunc(jl_array_grow_end_fwd), ) register_handler!( - ("jl_array_del_end","ijl_array_del_end"), + ("jl_array_del_end", "ijl_array_del_end"), @augfunc(jl_array_del_end_augfwd), 
@revfunc(jl_array_del_end_rev), @fwdfunc(jl_array_del_end_fwd), ) register_handler!( - ("jl_f_getfield","ijl_f_getfield"), + ("jl_f_getfield", "ijl_f_getfield"), @augfunc(jl_getfield_augfwd), @revfunc(jl_getfield_rev), @fwdfunc(jl_getfield_fwd), ) register_handler!( - ("ijl_get_nth_field_checked","jl_get_nth_field_checked"), + ("ijl_get_nth_field_checked", "jl_get_nth_field_checked"), @augfunc(jl_nthfield_augfwd), @revfunc(jl_nthfield_rev), @fwdfunc(jl_nthfield_fwd), ) register_handler!( - ("jl_array_sizehint","ijl_array_sizehint"), + ("jl_array_sizehint", "ijl_array_sizehint"), @augfunc(jl_array_sizehint_augfwd), @revfunc(jl_array_sizehint_rev), @fwdfunc(jl_array_sizehint_fwd), ) register_handler!( - ("jl_array_ptr_copy","ijl_array_ptr_copy"), + ("jl_array_ptr_copy", "ijl_array_ptr_copy"), @augfunc(jl_array_ptr_copy_augfwd), @revfunc(jl_array_ptr_copy_rev), @fwdfunc(jl_array_ptr_copy_fwd), diff --git a/src/rules/parallelrules.jl b/src/rules/parallelrules.jl index 29648389474..d882fc26721 100644 --- a/src/rules/parallelrules.jl +++ b/src/rules/parallelrules.jl @@ -1,9 +1,30 @@ -function runtime_newtask_fwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ssize::Int, runtimeActivity::Val{RuntimeActivity}, ::Val{width}) where {FT1, FT2, World, width, RuntimeActivity} +function runtime_newtask_fwd( + world::Val{World}, + fn::FT1, + dfn::FT2, + post::Any, + ssize::Int, + runtimeActivity::Val{RuntimeActivity}, + ::Val{width}, +) where {FT1,FT2,World,width,RuntimeActivity} FT = Core.Typeof(fn) ghos = guaranteed_const(FT) opt_mi = world - forward = thunk(opt_mi, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ForwardMode), Val(width), Val((false,)), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) + forward = thunk( + opt_mi, + (ghos ? 
Const : Duplicated){FT}, + Const, + Tuple{}, + Val(API.DEM_ForwardMode), + Val(width), + Val((false,)), + Val(true), + Val(false), + FFIABI, + Val(false), + runtimeActivity, + ) #=erriffuncwritten=# ft = ghos ? Const(fn) : Duplicated(fn, dfn) function fclosure() res = forward(ft) @@ -13,12 +34,34 @@ function runtime_newtask_fwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ss return ccall(:jl_new_task, Ref{Task}, (Any, Any, Int), fclosure, post, ssize) end -function runtime_newtask_augfwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ssize::Int, runtimeActivity::Val{RuntimeActivity}, ::Val{width}, ::Val{ModifiedBetween}) where {FT1, FT2, World, width, ModifiedBetween, RuntimeActivity} +function runtime_newtask_augfwd( + world::Val{World}, + fn::FT1, + dfn::FT2, + post::Any, + ssize::Int, + runtimeActivity::Val{RuntimeActivity}, + ::Val{width}, + ::Val{ModifiedBetween}, +) where {FT1,FT2,World,width,ModifiedBetween,RuntimeActivity} # TODO make this AD subcall type stable FT = Core.Typeof(fn) ghos = guaranteed_const(FT) opt_mi = world - forward, adjoint = thunk(opt_mi, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ReverseModePrimal), Val(width), Val(ModifiedBetween), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI, #=erriffuncwritten=#Val(false), runtimeActivity) + forward, adjoint = thunk( + opt_mi, + (ghos ? Const : Duplicated){FT}, + Const, + Tuple{}, + Val(API.DEM_ReverseModePrimal), + Val(width), + Val(ModifiedBetween), + Val(true), + Val(false), + FFIABI, + Val(false), + runtimeActivity, + ) #=erriffuncwritten=# ft = ghos ? Const(fn) : Duplicated(fn, dfn) taperef = Ref{Any}() @@ -41,13 +84,17 @@ function runtime_newtask_augfwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, end -function referenceCaller(fn::Ref{Clos}, args...) where Clos +function referenceCaller(fn::Ref{Clos}, args...) where {Clos} fval = fn[] fval = fval::Clos fval(args...) 
end -function runtime_pfor_fwd(thunk::ThunkTy, ft::FT, threading_args...)::Cvoid where {ThunkTy, FT} +function runtime_pfor_fwd( + thunk::ThunkTy, + ft::FT, + threading_args..., +)::Cvoid where {ThunkTy,FT} function fwd(tid_args...) if length(tid_args) == 0 thunk(ft) @@ -59,12 +106,21 @@ function runtime_pfor_fwd(thunk::ThunkTy, ft::FT, threading_args...)::Cvoid wher return end -function runtime_pfor_augfwd(thunk::ThunkTy, ft::FT, ::Val{AnyJL}, ::Val{byRef}, threading_args...) where {ThunkTy, FT, AnyJL, byRef} +function runtime_pfor_augfwd( + thunk::ThunkTy, + ft::FT, + ::Val{AnyJL}, + ::Val{byRef}, + threading_args..., +) where {ThunkTy,FT,AnyJL,byRef} TapeType = EnzymeRules.tape_type(ThunkTy) tapes = if AnyJL Vector{TapeType}(undef, Base.Threads.nthreads()) else - Base.unsafe_convert(Ptr{TapeType}, Libc.malloc(sizeof(TapeType)*Base.Threads.nthreads())) + Base.unsafe_convert( + Ptr{TapeType}, + Libc.malloc(sizeof(TapeType) * Base.Threads.nthreads()), + ) end function fwd(tid_args...) @@ -94,7 +150,14 @@ function runtime_pfor_augfwd(thunk::ThunkTy, ft::FT, ::Val{AnyJL}, ::Val{byRef}, return tapes end -function runtime_pfor_rev(thunk::ThunkTy, ft::FT, ::Val{AnyJL}, ::Val{byRef}, tapes, threading_args...) where {ThunkTy, FT, AnyJL, byRef} +function runtime_pfor_rev( + thunk::ThunkTy, + ft::FT, + ::Val{AnyJL}, + ::Val{byRef}, + tapes, + threading_args..., +) where {ThunkTy,FT,AnyJL,byRef} function rev(tid_args...) 
tid = if length(tid_args) == 0 tid = Base.Threads.threadid() @@ -130,7 +193,7 @@ function runtime_pfor_rev(thunk::ThunkTy, ft::FT, ::Val{AnyJL}, ::Val{byRef}, ta return nothing end -@inline function threadsfor_common(orig, gutils, B, mode, tape=nothing) +@inline function threadsfor_common(orig, gutils, B, mode, tape = nothing) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) @@ -182,25 +245,54 @@ end width = get_width(gutils) ops = collect(operands(orig))[1:end-1] - dupClosure = !guaranteed_const_nongen(funcT, world) && !is_constant_value(gutils, ops[1]) + dupClosure = + !guaranteed_const_nongen(funcT, world) && !is_constant_value(gutils, ops[1]) pdupClosure = dupClosure subfunc = nothing if mode == API.DEM_ForwardMode if fwdmodenm === nothing etarget = Compiler.EnzymeTarget() - eparams = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){funcT}, e_tt.parameters...}, API.DEM_ForwardMode, width, Const{Nothing}, #=runEnzyme=#true, #=abiwrap=#true, modifiedBetween, #=returnPrimal=#false, #=shadowInit=#false, UnknownTapeType, FFIABI, #=ErrIfFuncWritten=#false, get_runtime_activity(gutils)) - ejob = Compiler.CompilerJob(mi2, CompilerConfig(etarget, eparams; kernel=false), world) + eparams = Compiler.EnzymeCompilerParams( + Tuple{(dupClosure ? 
Duplicated : Const){funcT},e_tt.parameters...}, + API.DEM_ForwardMode, + width, + Const{Nothing}, + true, + true, + modifiedBetween, + false, + false, + UnknownTapeType, + FFIABI, + false, + get_runtime_activity(gutils), + ) #=ErrIfFuncWritten=# + ejob = Compiler.CompilerJob( + mi2, + CompilerConfig(etarget, eparams; kernel = false), + world, + ) + + cmod, fwdmodenm, _, _ = _thunk(ejob, false) #=postopt=# - cmod, fwdmodenm, _, _ = _thunk(ejob, #=postopt=#false) - LLVM.link!(mod, cmod) push!(attributes, StringAttribute("enzymejl_forward", fwdmodenm)) - push!(function_attributes(functions(mod)[fwdmodenm]), EnumAttribute("alwaysinline")) + push!( + function_attributes(functions(mod)[fwdmodenm]), + EnumAttribute("alwaysinline"), + ) permit_inlining!(functions(mod)[fwdmodenm]) end - thunkTy = ForwardModeThunk{Ptr{Cvoid}, dupClosure ? Duplicated{funcT} : Const{funcT}, Const{Nothing}, e_tt, width, #=returnPrimal=#false} + thunkTy = ForwardModeThunk{ + Ptr{Cvoid}, + dupClosure ? Duplicated{funcT} : Const{funcT}, + Const{Nothing}, + e_tt, + width, + false, + } #=returnPrimal=# subfunc = functions(mod)[fwdmodenm] elseif mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient @@ -209,7 +301,7 @@ end has_active = ty == MixedState || ty == ActiveState if has_active refed = true - e_tt = Tuple{Duplicated{Base.RefValue{funcT}}, e_tt.parameters...} + e_tt = Tuple{Duplicated{Base.RefValue{funcT}},e_tt.parameters...} funcT = Core.Typeof(referenceCaller) dupClosure = false modifiedBetween = (false, modifiedBetween...) @@ -220,30 +312,75 @@ end if augfwdnm === nothing || adjointnm === nothing etarget = Compiler.EnzymeTarget() # TODO modifiedBetween - eparams = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? 
Duplicated : Const){funcT}, e_tt.parameters...}, API.DEM_ReverseModePrimal, width, Const{Nothing}, #=runEnzyme=#true, #=abiwrap=#true, modifiedBetween, #=returnPrimal=#false, #=shadowInit=#false, UnknownTapeType, FFIABI, #=ErrIfFuncWritten=#false, get_runtime_activity(gutils)) - ejob = Compiler.CompilerJob(mi2, CompilerConfig(etarget, eparams; kernel=false), world) - - cmod, adjointnm, augfwdnm, TapeType = _thunk(ejob, #=postopt=#false) + eparams = Compiler.EnzymeCompilerParams( + Tuple{(dupClosure ? Duplicated : Const){funcT},e_tt.parameters...}, + API.DEM_ReverseModePrimal, + width, + Const{Nothing}, + true, + true, + modifiedBetween, + false, + false, + UnknownTapeType, + FFIABI, + false, + get_runtime_activity(gutils), + ) #=ErrIfFuncWritten=# + ejob = Compiler.CompilerJob( + mi2, + CompilerConfig(etarget, eparams; kernel = false), + world, + ) + + cmod, adjointnm, augfwdnm, TapeType = _thunk(ejob, false) #=postopt=# LLVM.link!(mod, cmod) push!(attributes, StringAttribute("enzymejl_augforward", augfwdnm)) - push!(function_attributes(functions(mod)[augfwdnm]), EnumAttribute("alwaysinline")) + push!( + function_attributes(functions(mod)[augfwdnm]), + EnumAttribute("alwaysinline"), + ) permit_inlining!(functions(mod)[augfwdnm]) push!(attributes, StringAttribute("enzymejl_adjoint", adjointnm)) - push!(function_attributes(functions(mod)[adjointnm]), EnumAttribute("alwaysinline")) + push!( + function_attributes(functions(mod)[adjointnm]), + EnumAttribute("alwaysinline"), + ) permit_inlining!(functions(mod)[adjointnm]) - push!(attributes, StringAttribute("enzymejl_tapetype", string(convert(UInt, unsafe_to_pointer(TapeType))))) - + push!( + attributes, + StringAttribute( + "enzymejl_tapetype", + string(convert(UInt, unsafe_to_pointer(TapeType))), + ), + ) + end if mode == API.DEM_ReverseModePrimal - thunkTy = AugmentedForwardThunk{Ptr{Cvoid}, dupClosure ? 
Duplicated{funcT} : Const{funcT}, Const{Nothing}, e_tt, width, #=returnPrimal=#true, TapeType} + thunkTy = AugmentedForwardThunk{ + Ptr{Cvoid}, + dupClosure ? Duplicated{funcT} : Const{funcT}, + Const{Nothing}, + e_tt, + width, + true, + TapeType, + } #=returnPrimal=# subfunc = functions(mod)[augfwdnm] - else - thunkTy = AdjointThunk{Ptr{Cvoid}, dupClosure ? Duplicated{funcT} : Const{funcT}, Const{Nothing}, e_tt, width, TapeType} + else + thunkTy = AdjointThunk{ + Ptr{Cvoid}, + dupClosure ? Duplicated{funcT} : Const{funcT}, + Const{Nothing}, + e_tt, + width, + TapeType, + } subfunc = functions(mod)[adjointnm] end else @@ -251,7 +388,7 @@ end end ppfuncT = pfuncT - dpfuncT = width == 1 ? pfuncT : NTuple{(Int)width, pfuncT} + dpfuncT = width == 1 ? pfuncT : NTuple{(Int)width,pfuncT} if refed dpfuncT = Base.RefValue{dpfuncT} @@ -263,7 +400,7 @@ end if width == 1 dfuncT = Duplicated{dfuncT} else - dfuncT = BatchDuplicated{dfuncT, Int(width)} + dfuncT = BatchDuplicated{dfuncT,Int(width)} end else dfuncT = Const{dfuncT} @@ -273,7 +410,7 @@ end alloctx = LLVM.IRBuilder() position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) - ll_th = convert(LLVMType, thunkTy) + ll_th = convert(LLVMType, thunkTy) al = alloca!(alloctx, ll_th) al = addrspacecast!(B, al, LLVM.PointerType(ll_th, Tracked)) al = addrspacecast!(B, al, LLVM.PointerType(ll_th, Derived)) @@ -320,7 +457,12 @@ end val0 = v end - ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)]) + ptr = inbounds_gep!( + B, + llty, + al, + [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)], + ) store!(B, val0, ptr) if pdupClosure @@ -343,7 +485,8 @@ end if refed dval0 = dval = emit_allocobj!(B, dpfuncT) - dval = bitcast!(B, dval, LLVM.PointerType(spllty, addrspace(value_type(dval)))) + dval = + bitcast!(B, dval, LLVM.PointerType(spllty, addrspace(value_type(dval)))) dval = addrspacecast!(B, dval, 
LLVM.PointerType(spllty, Derived)) store!(B, dv, dval) if pv !== nothing @@ -356,7 +499,15 @@ end dval0 = dv end - dptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 1)]) + dptr = inbounds_gep!( + B, + llty, + al, + [ + LLVM.ConstantInt(LLVM.IntType(64), 0), + LLVM.ConstantInt(LLVM.IntType(32), 1), + ], + ) store!(B, dval0, dptr) end @@ -379,12 +530,15 @@ end end mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing - shadow = (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + shadow = + (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing - _, sname, dfuncT, vals, thunkTy, _, _ = threadsfor_common(orig, gutils, B, API.DEM_ForwardMode) + _, sname, dfuncT, vals, thunkTy, _, _ = + threadsfor_common(orig, gutils, B, API.DEM_ForwardMode) - tt = Tuple{thunkTy, dfuncT, Bool} + tt = Tuple{thunkTy,dfuncT,Bool} mode = get_mode(gutils) world = enzyme_extract_world(LLVM.parent(position(B))) entry = nested_codegen!(mode, mod, runtime_pfor_fwd, tt, world) @@ -414,12 +568,21 @@ end return true end - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing - shadow = (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + shadow = + (unsafe_load(shadowR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(shadowR)) : nothing - byRef, sname, dfuncT, vals, thunkTy, _, copies = threadsfor_common(orig, gutils, B, API.DEM_ReverseModePrimal) + byRef, sname, dfuncT, vals, thunkTy, _, copies = + threadsfor_common(orig, gutils, B, API.DEM_ReverseModePrimal) - tt = Tuple{thunkTy, dfuncT, Val{any_jltypes(EnzymeRules.tape_type(thunkTy))}, Val{byRef}, Bool} + tt = Tuple{ + thunkTy, + dfuncT, + Val{any_jltypes(EnzymeRules.tape_type(thunkTy))}, + Val{byRef}, + Bool, + } mode = get_mode(gutils) world = enzyme_extract_world(LLVM.parent(position(B))) entry = nested_codegen!(mode, mod, runtime_pfor_augfwd, tt, world) @@ -459,7 +622,8 @@ end return end - byRef, sname, dfuncT, vals, thunkTy, TapeType, copies = threadsfor_common(orig, gutils, B, API.DEM_ReverseModeGradient, tape) + byRef, sname, dfuncT, vals, thunkTy, TapeType, copies = + threadsfor_common(orig, gutils, B, API.DEM_ReverseModeGradient, tape) STT = if !any_jltypes(TapeType) Ptr{TapeType} @@ -467,7 +631,14 @@ end Vector{TapeType} end - tt = Tuple{thunkTy, dfuncT, Val{any_jltypes(EnzymeRules.tape_type(thunkTy))}, Val{byRef}, STT, Bool} + tt = Tuple{ + thunkTy, + dfuncT, + Val{any_jltypes(EnzymeRules.tape_type(thunkTy))}, + Val{byRef}, + STT, + Bool, + } mode = get_mode(gutils) entry = nested_codegen!(mode, mod, runtime_pfor_rev, tt, world) push!(function_attributes(entry), EnumAttribute("alwaysinline")) @@ -499,15 +670,18 @@ end ops = collect(operands(orig)) vals = LLVM.Value[ - unsafe_to_llvm(B, runtime_newtask_fwd), - unsafe_to_llvm(B, Val(world)), - new_from_original(gutils, ops[1]), - invert_pointer(gutils, ops[1], B), - new_from_original(gutils, ops[2]), - (sizeof(Int) == sizeof(Int64) ? emit_box_int64! 
: emit_box_int32!)(B, new_from_original(gutils, ops[3])), - unsafe_to_llvm(B, Val(get_runtime_activity(gutils))), - unsafe_to_llvm(B, Val(width)), - ] + unsafe_to_llvm(B, runtime_newtask_fwd), + unsafe_to_llvm(B, Val(world)), + new_from_original(gutils, ops[1]), + invert_pointer(gutils, ops[1], B), + new_from_original(gutils, ops[2]), + (sizeof(Int) == sizeof(Int64) ? emit_box_int64! : emit_box_int32!)( + B, + new_from_original(gutils, ops[3]), + ), + unsafe_to_llvm(B, Val(get_runtime_activity(gutils))), + unsafe_to_llvm(B, Val(width)), + ] ntask = emit_apply_generic!(B, vals) debug_from_orig!(gutils, ntask, orig) @@ -532,9 +706,11 @@ end if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing - shadow = (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing - + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + shadow = + (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing + T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) @@ -550,15 +726,19 @@ end ops = collect(operands(orig)) vals = LLVM.Value[ - unsafe_to_llvm(B, runtime_newtask_augfwd), - unsafe_to_llvm(B, Val(world)), - new_from_original(gutils, ops[1]), - invert_pointer(gutils, ops[1], B), - new_from_original(gutils, ops[2]), - (sizeof(Int) == sizeof(Int64) ? emit_box_int64! : emit_box_int32!)(B, new_from_original(gutils, ops[3])), - unsafe_to_llvm(B, Val(get_runtime_activity(gutils))), unsafe_to_llvm(B, Val(width)), - unsafe_to_llvm(B, Val(ModifiedBetween)), - ] + unsafe_to_llvm(B, runtime_newtask_augfwd), + unsafe_to_llvm(B, Val(world)), + new_from_original(gutils, ops[1]), + invert_pointer(gutils, ops[1], B), + new_from_original(gutils, ops[2]), + (sizeof(Int) == sizeof(Int64) ? emit_box_int64! 
: emit_box_int32!)( + B, + new_from_original(gutils, ops[3]), + ), + unsafe_to_llvm(B, Val(get_runtime_activity(gutils))), + unsafe_to_llvm(B, Val(width)), + unsafe_to_llvm(B, Val(ModifiedBetween)), + ] ntask = emit_apply_generic!(B, vals) debug_from_orig!(gutils, ntask, orig) @@ -569,12 +749,20 @@ end sret = LLVM.pointercast!(B, sret, LLVM.PointerType(AT, Derived)) if shadowR != C_NULL - shadow = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)])) + shadow = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]), + ) unsafe_store!(shadowR, shadow.ref) end if normalR != C_NULL - normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) + normal = LLVM.load!( + B, + T_prjlvalue, + LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)]), + ) unsafe_store!(normalR, normal.ref) end @@ -596,16 +784,18 @@ end if width == 1 nops = LLVM.Value[inv, new_from_original(gutils, ops[2])] valTys = API.CValueType[API.VT_Shadow, API.VT_Primal] - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, nops, valTys, #=lookup=#false) + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, nops, valTys, false) #=lookup=# debug_from_orig!(gutils, cal, orig) callconv!(cal, callconv(orig)) else - for idx in 1:width - nops = LLVM.Value[extract_value(B, inv, idx-1), - new_from_original(gutils, ops[2])] + for idx = 1:width + nops = LLVM.Value[ + extract_value(B, inv, idx - 1), + new_from_original(gutils, ops[2]), + ] valTys = API.CValueType[API.VT_Shadow, API.VT_Primal] - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, nops, valTys, #=lookup=#false) - + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, nops, valTys, false) #=lookup=# + debug_from_orig!(gutils, cal, orig) callconv!(cal, callconv(orig)) end @@ -626,7 +816,8 @@ end if is_constant_value(gutils, orig) && 
is_constant_inst(gutils, orig) return true end - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -663,7 +854,11 @@ end mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) waitfn = find_match(mod, "jl_wait") if waitfn === nothing - emit_error(B, orig, "Enzyme: could not find jl_wait fn to create shadow of jl_enq_work") + emit_error( + B, + orig, + "Enzyme: could not find jl_wait fn to create shadow of jl_enq_work", + ) return nothing end @assert waitfn !== nothing @@ -678,7 +873,8 @@ end if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -689,7 +885,8 @@ end if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true end - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -702,7 +899,11 @@ end mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) enq_work_fn = find_match(mod, "jl_enq_work") if enq_work_fn === nothing - emit_error(B, orig, "Enzyme: could not find jl_enq_work fn to create shadow of wait") + emit_error( + B, + orig, + "Enzyme: could not find jl_enq_work fn to create shadow of wait", + ) return nothing end @assert enq_work_fn !== nothing diff --git a/src/rules/typerules.jl b/src/rules/typerules.jl index 569ef87323f..2a2d1032c1d 100644 --- a/src/rules/typerules.jl +++ b/src/rules/typerules.jl @@ -1,12 +1,26 @@ -function int_return_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8 +function int_return_rule( + direction::Cint, + ret::API.CTypeTreeRef, + args::Ptr{API.CTypeTreeRef}, + known_values::Ptr{API.IntList}, + numArgs::Csize_t, + val::LLVM.API.LLVMValueRef, +)::UInt8 TT = TypeTree(API.DT_Integer, LLVM.context(LLVM.Value(val))) only!(TT, -1) API.EnzymeMergeTypeTree(ret, TT) return UInt8(false) end -function inout_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8 +function inout_rule( + direction::Cint, + ret::API.CTypeTreeRef, + args::Ptr{API.CTypeTreeRef}, + known_values::Ptr{API.IntList}, + numArgs::Csize_t, + val::LLVM.API.LLVMValueRef, +)::UInt8 if numArgs != 1 return UInt8(false) end diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 6117e464d81..500ff53d9a0 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -1,4 +1,12 @@ -function body_construct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs, tuple) +function body_construct_augfwd( + N, + Width, + primtypes, + active_refs, + primargs, + 
batchshadowargs, + tuple, +) shadow_rets = Vector{Expr}[] results = quote $(active_refs...) @@ -6,29 +14,35 @@ function body_construct_augfwd(N, Width, primtypes, active_refs, primargs, batch @assert length(primtypes) == N @assert length(primargs) == N @assert length(batchshadowargs) == N - for i in 1:N + for i = 1:N @assert length(batchshadowargs[i]) == Width shadow_rets_i = Expr[] aref = Symbol("active_ref_$i") - for w in 1:Width - sref = Symbol("sub_shadow_"*string(i)*"_"*string(w)) - push!(shadow_rets_i, quote - $sref = if $aref == AnyState - $(primargs[i]); - else - if !ActivityTup[$i] - if ($aref == DupState || $aref == MixedState) && $(batchshadowargs[i][w]) === nothing - prim = $(primargs[i]) - throw("Error cannot store inactive but differentiable variable $prim into active tuple") - end - end - if $aref == DupState - $(batchshadowargs[i][w]) + for w = 1:Width + sref = Symbol("sub_shadow_" * string(i) * "_" * string(w)) + push!( + shadow_rets_i, + quote + $sref = if $aref == AnyState + $(primargs[i]) else - $(batchshadowargs[i][w])[] + if !ActivityTup[$i] + if ($aref == DupState || $aref == MixedState) && + $(batchshadowargs[i][w]) === nothing + prim = $(primargs[i]) + throw( + "Error cannot store inactive but differentiable variable $prim into active tuple", + ) + end + end + if $aref == DupState + $(batchshadowargs[i][w]) + else + $(batchshadowargs[i][w])[] + end end - end - end) + end, + ) end push!(shadow_rets, shadow_rets_i) end @@ -36,11 +50,11 @@ function body_construct_augfwd(N, Width, primtypes, active_refs, primargs, batch refs = Expr[] ref_syms = Symbol[] res_syms = Symbol[] - for w in 1:Width + for w = 1:Width sres = Symbol("result_$w") ref_res = Symbol("ref_result_$w") combined = Expr[] - for i in 1:N + for i = 1:N push!(combined, shadow_rets[i][w]) end if tuple @@ -85,10 +99,18 @@ function body_construct_augfwd(N, Width, primtypes, active_refs, primargs, batch end -function body_construct_rev(N, Width, primtypes, active_refs, primargs, 
batchshadowargs, tuple) +function body_construct_rev( + N, + Width, + primtypes, + active_refs, + primargs, + batchshadowargs, + tuple, +) outs = [] - for i in 1:N - for w in 1:Width + for i = 1:N + for w = 1:Width tsym = Symbol("tval_$w") expr = if tuple :($tsym[$i]) @@ -96,20 +118,25 @@ function body_construct_rev(N, Width, primtypes, active_refs, primargs, batchsha :(getfield($tsym, $i)) end shad = batchshadowargs[i][w] - out = :(if $(Symbol("active_ref_$i")) == MixedState || $(Symbol("active_ref_$i")) == ActiveState - if $shad isa Base.RefValue - $shad[] = recursive_add($shad[], $expr, identity, guaranteed_nonactive) - else - error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string($shad)) + out = :( + if $(Symbol("active_ref_$i")) == MixedState || + $(Symbol("active_ref_$i")) == ActiveState + if $shad isa Base.RefValue + $shad[] = recursive_add($shad[], $expr, identity, guaranteed_nonactive) + else + error( + "Enzyme Mutability Error: Cannot add one in place to immutable value " * + string($shad), + ) + end end - end ) push!(outs, out) end end - tapes = Expr[:(tval_1 = tape[])] - for w in 2:Width + tapes = Expr[:(tval_1 = tape[])] + for w = 2:Width sym = Symbol("tval_$w") df = Symbol("df_$w") push!(tapes, :($sym = $df[])) @@ -131,87 +158,226 @@ function body_runtime_tuple_rev(N, Width, primtypes, active_refs, primargs, batc body_construct_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs, true) end -function body_runtime_newstruct_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs) +function body_runtime_newstruct_rev( + N, + Width, + primtypes, + active_refs, + primargs, + batchshadowargs, +) body_construct_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs, false) end -function body_runtime_tuple_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) +function body_runtime_tuple_augfwd( + N, + Width, + primtypes, + active_refs, + primargs, + batchshadowargs, +) 
body_construct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs, true) end function func_runtime_tuple_augfwd(N, Width) - primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width; func=false, mixed_or_active=true) - body = body_runtime_tuple_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = + setup_macro_wraps(false, N, Width; func = false, mixed_or_active = true) + body = body_runtime_tuple_augfwd( + N, + Width, + primtypes, + active_refs, + primargs, + batchshadowargs, + ) quote - function runtime_tuple_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, $(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, $(typeargs...)} + function runtime_tuple_augfwd( + activity::Type{Val{ActivityTup}}, + width::Val{$Width}, + ModifiedBetween::Val{MB}, + RT::Val{ReturnType}, + $(allargs...), + )::ReturnType where {ActivityTup,MB,ReturnType,$(typeargs...)} $body end end end -@generated function runtime_tuple_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, allargs...)::ReturnType where {ActivityTup, MB, Width, ReturnType} +@generated function runtime_tuple_augfwd( + activity::Type{Val{ActivityTup}}, + width::Val{Width}, + ModifiedBetween::Val{MB}, + RT::Val{ReturnType}, + allargs..., +)::ReturnType where {ActivityTup,MB,Width,ReturnType} N = div(length(allargs), Width) - primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs; func=false, mixed_or_active=true) - return body_runtime_tuple_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) + primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = + setup_macro_wraps(false, N, Width, :allargs; func = false, mixed_or_active = true) + 
return body_runtime_tuple_augfwd( + N, + Width, + primtypes, + active_refs, + primargs, + batchshadowargs, + ) end function func_runtime_tuple_rev(N, Width) - primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width; mixed_or_active=true) - body = body_runtime_tuple_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = + setup_macro_wraps(false, N, Width; mixed_or_active = true) + body = + body_runtime_tuple_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs) quote - function runtime_tuple_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, $(allargs...)) where {ActivityTup, MB, TapeType, $(typeargs...)} + function runtime_tuple_rev( + activity::Type{Val{ActivityTup}}, + width::Val{$Width}, + ModifiedBetween::Val{MB}, + tape::TapeType, + $(allargs...), + ) where {ActivityTup,MB,TapeType,$(typeargs...)} $body end end end -@generated function runtime_tuple_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, allargs...) 
where {ActivityTup, MB, Width, TapeType} - N = div(length(allargs)-(Width-1), Width) - primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs; mixed_or_active=true) - return body_runtime_tuple_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs) -end - - -function body_runtime_newstruct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) - body_construct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs, false) +@generated function runtime_tuple_rev( + activity::Type{Val{ActivityTup}}, + width::Val{Width}, + ModifiedBetween::Val{MB}, + tape::TapeType, + allargs..., +) where {ActivityTup,MB,Width,TapeType} + N = div(length(allargs) - (Width - 1), Width) + primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = + setup_macro_wraps(false, N, Width, :allargs; mixed_or_active = true) + return body_runtime_tuple_rev( + N, + Width, + primtypes, + active_refs, + primargs, + batchshadowargs, + ) +end + + +function body_runtime_newstruct_augfwd( + N, + Width, + primtypes, + active_refs, + primargs, + batchshadowargs, +) + body_construct_augfwd( + N, + Width, + primtypes, + active_refs, + primargs, + batchshadowargs, + false, + ) end function func_runtime_newstruct_augfwd(N, Width) - primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width; mixed_or_active=true) - body = body_runtime_newstruct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = + setup_macro_wraps(false, N, Width; mixed_or_active = true) + body = body_runtime_newstruct_augfwd( + N, + Width, + primtypes, + active_refs, + primargs, + batchshadowargs, + ) quote - function runtime_newstruct_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, ::Type{NewType}, RT::Val{ReturnType}, 
$(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, NewType, $(typeargs...)} + function runtime_newstruct_augfwd( + activity::Type{Val{ActivityTup}}, + width::Val{$Width}, + ModifiedBetween::Val{MB}, + ::Type{NewType}, + RT::Val{ReturnType}, + $(allargs...), + )::ReturnType where {ActivityTup,MB,ReturnType,NewType,$(typeargs...)} $body end end end -@generated function runtime_newstruct_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, ::Type{NewType}, RT::Val{ReturnType}, allargs...)::ReturnType where {ActivityTup, MB, Width, ReturnType, NewType} - N = div(length(allargs)+2, Width+1)-1 - primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs; mixed_or_active=true) - return body_runtime_newstruct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) +@generated function runtime_newstruct_augfwd( + activity::Type{Val{ActivityTup}}, + width::Val{Width}, + ModifiedBetween::Val{MB}, + ::Type{NewType}, + RT::Val{ReturnType}, + allargs..., +)::ReturnType where {ActivityTup,MB,Width,ReturnType,NewType} + N = div(length(allargs) + 2, Width + 1) - 1 + primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = + setup_macro_wraps(false, N, Width, :allargs; mixed_or_active = true) + return body_runtime_newstruct_augfwd( + N, + Width, + primtypes, + active_refs, + primargs, + batchshadowargs, + ) end function func_runtime_newstruct_rev(N, Width) - primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width; mixed_or_active=true) - body = body_runtime_newstruct_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = + setup_macro_wraps(false, N, Width; mixed_or_active = true) + body = body_runtime_newstruct_rev( + N, + Width, + primtypes, + active_refs, + primargs, + batchshadowargs, + ) 
quote - function runtime_newstruct_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, ::Type{NewStruct}, tape::TapeType, $(allargs...)) where {ActivityTup, MB, NewStruct, TapeType, $(typeargs...)} + function runtime_newstruct_rev( + activity::Type{Val{ActivityTup}}, + width::Val{$Width}, + ModifiedBetween::Val{MB}, + ::Type{NewStruct}, + tape::TapeType, + $(allargs...), + ) where {ActivityTup,MB,NewStruct,TapeType,$(typeargs...)} $body end end end -@generated function runtime_newstruct_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, ::Type{NewStruct}, tape::TapeType, allargs...) where {ActivityTup, MB, Width, NewStruct, TapeType} - N = div(length(allargs)-(Width-1), Width) - primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs; mixed_or_active=true) - return body_runtime_newstruct_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs) +@generated function runtime_newstruct_rev( + activity::Type{Val{ActivityTup}}, + width::Val{Width}, + ModifiedBetween::Val{MB}, + ::Type{NewStruct}, + tape::TapeType, + allargs..., +) where {ActivityTup,MB,Width,NewStruct,TapeType} + N = div(length(allargs) - (Width - 1), Width) + primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = + setup_macro_wraps(false, N, Width, :allargs; mixed_or_active = true) + return body_runtime_newstruct_rev( + N, + Width, + primtypes, + active_refs, + primargs, + batchshadowargs, + ) end for (N, Width) in Iterators.product(0:30, 1:10) @@ -235,7 +401,8 @@ function newstruct_common(fwd, run, offset, B, orig, gutils, normalR, shadowR) abs = [abs_typeof(v) for v in origops[offset+1:end-1]] @assert length(icvs) == length(abs) - for (icv, (found_partial, typ_partial), (found, typ)) in zip(icvs, abs_partial, abs) + for (icv, (found_partial, typ_partial, byref_partial), (found, typ, byref)) in + zip(icvs, abs_partial, abs) # Constants not handled 
unless known inactive from type if icv if !found_partial @@ -251,7 +418,8 @@ function newstruct_common(fwd, run, offset, B, orig, gutils, normalR, shadowR) if !found_partial return false end - act = active_reg_inner(typ_partial, (), world, #=justactive=#Val(false), #=unionsret=#Val(false), #=abstractismixed=#Val(true)) + act = + active_reg_inner(typ_partial, (), world, Val(false), Val(false), Val(true)) #=abstractismixed=# if act == MixedState || act == ActiveState return false end @@ -262,7 +430,7 @@ function newstruct_common(fwd, run, offset, B, orig, gutils, normalR, shadowR) return true end - shadowsin = LLVM.Value[invert_pointer(gutils, o, B) for o in origops[offset:end-1] ] + shadowsin = LLVM.Value[invert_pointer(gutils, o, B) for o in origops[offset:end-1]] if width == 1 if offset != 1 pushfirst!(shadowsin, origops[1]) @@ -270,17 +438,16 @@ function newstruct_common(fwd, run, offset, B, orig, gutils, normalR, shadowR) shadowres = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), shadowsin) callconv!(shadowres, callconv(orig)) else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) - for idx in 1:width - args = LLVM.Value[ - extract_value!(B, s, idx-1) for s in shadowsin - ] + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width + args = LLVM.Value[extract_value!(B, s, idx - 1) for s in shadowsin] if offset != 1 pushfirst!(args, origops[1]) end tmp = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), args) callconv!(tmp, callconv(orig)) - shadowres = insert_value!(B, shadowres, tmp, idx-1) + shadowres = insert_value!(B, shadowres, tmp, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) @@ -291,17 +458,35 @@ end function common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, 
get_mode(gutils)) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return true end - if !newstruct_common(#=fwd=#true, #=run=#true, offset, B, orig, gutils, normalR, shadowR) + if !newstruct_common(true, true, offset, B, orig, gutils, normalR, shadowR) #=run=# origops = collect(operands(orig)) abs_partial = [abs_typeof(v, true) for v in origops[offset+1:end-1]] icvs = [is_constant_value(gutils, v) for v in origops[offset+1:end-1]] - emit_error(B, orig, "Enzyme: Not yet implemented, mixed activity for jl_new_struct constants="*string(icvs)*" "*string(orig)*" "*string(abs_partial)*" "*string([v for v in origops[offset+1:end-1]])) + emit_error( + B, + orig, + "Enzyme: Not yet implemented, mixed activity for jl_new_struct constants=" * + string(icvs) * + " " * + string(orig) * + " " * + string(abs_partial) * + " " * + string([v for v in origops[offset+1:end-1]]), + ) end return false @@ -310,15 +495,26 @@ end function common_newstructv_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR)::Bool needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) - - if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0) && + is_constant_inst(gutils, orig) return true end - if !newstruct_common(#=fwd=#false, #=run=#true, offset, B, orig, gutils, normalR, shadowR) - normal = (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing - shadow = (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing + if !newstruct_common(false, true, offset, B, orig, gutils, normalR, shadowR) #=run=# + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : + nothing + shadow = + (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : + nothing T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) @@ -326,8 +522,20 @@ function common_newstructv_augfwd(offset, B, orig, gutils, normalR, shadowR, tap width = get_width(gutils) - sret = generic_setup(orig, runtime_newstruct_augfwd, width == 1 ? Any : AnyArray(Int(width)), gutils, #=start=#offset, B, false; firstconst=true, endcast = false, firstconst_after_tape=true, runtime_activity=false) - + sret = generic_setup( + orig, + runtime_newstruct_augfwd, + width == 1 ? Any : AnyArray(Int(width)), + gutils, + offset, + B, + false; + firstconst = true, + endcast = false, + firstconst_after_tape = true, + runtime_activity = false, + ) #=start=# + if width == 1 shadow = sret else @@ -338,10 +546,15 @@ function common_newstructv_augfwd(offset, B, orig, gutils, normalR, shadowR, tap cal = LLVM.pointercast!(B, cal, LLVM.PointerType(llty, Derived)) ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for i in 1:width - gep = LLVM.inbounds_gep!(B, AT, cal, [LLVM.ConstantInt(0), LLVM.ConstantInt(i-1)]) + for i = 1:width + gep = LLVM.inbounds_gep!( + B, + AT, + cal, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i - 1)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) - shadow = insert_value!(B, shadow, ld, i-1) + shadow = insert_value!(B, shadow, ld, i - 1) end end unsafe_store!(shadowR, shadow.ref) @@ -359,18 +572,36 @@ function common_newstructv_rev(offset, B, orig, gutils, tape) end needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = 
API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + API.DEM_ReverseModePrimal, + ) needsPrimal = needsPrimalP[] != 0 needsShadow = needsShadowP[] != 0 - if !needsShadow - return - end + if !needsShadow + return + end - if !newstruct_common(#=fwd=#false, #=run=#false, offset, B, orig, gutils, #=normalR=#nothing, #=shadowR=#nothing) + if !newstruct_common(false, false, offset, B, orig, gutils, nothing, nothing) #=shadowR=# @assert tape !== C_NULL width = get_width(gutils) - generic_setup(orig, runtime_newstruct_rev, Nothing, gutils, #=start=#offset, B, true; firstconst=true, tape, firstconst_after_tape=true, runtime_activity=false) + generic_setup( + orig, + runtime_newstruct_rev, + Nothing, + gutils, + offset, + B, + true; + firstconst = true, + tape, + firstconst_after_tape = true, + runtime_activity = false, + ) #=start=# end return nothing @@ -383,15 +614,25 @@ end function common_f_tuple_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR)::Bool needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) - - if is_constant_value(gutils, orig) || needsShadowP[] == 0 + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + get_mode(gutils), + ) + + if is_constant_value(gutils, orig) || needsShadowP[] == 0 return true end - if !newstruct_common(#=fwd=#false, #=run=#true, offset, B, orig, gutils, normalR, shadowR) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing - shadow = (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing + if !newstruct_common(false, true, offset, B, orig, gutils, normalR, shadowR) #=run=# + normal = + (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : + nothing + shadow = + (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : + nothing T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) @@ -399,8 +640,18 @@ function common_f_tuple_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) width = get_width(gutils) - sret = generic_setup(orig, runtime_tuple_augfwd, width == 1 ? Any : AnyArray(Int(width)), gutils, #=start=#offset+1, B, false; endcast = false, runtime_activity=false) - + sret = generic_setup( + orig, + runtime_tuple_augfwd, + width == 1 ? Any : AnyArray(Int(width)), + gutils, + offset + 1, + B, + false; + endcast = false, + runtime_activity = false, + ) #=start=# + if width == 1 shadow = sret else @@ -411,10 +662,15 @@ function common_f_tuple_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) cal = LLVM.pointercast!(B, cal, LLVM.PointerType(llty, Derived)) ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for i in 1:width - gep = LLVM.inbounds_gep!(B, AT, cal, [LLVM.ConstantInt(0), LLVM.ConstantInt(i-1)]) + for i = 1:width + gep = LLVM.inbounds_gep!( + B, + AT, + cal, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i - 1)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) - shadow = insert_value!(B, shadow, ld, i-1) + shadow = insert_value!(B, shadow, ld, i - 1) end end unsafe_store!(shadowR, shadow.ref) @@ -428,7 +684,13 @@ end function common_f_tuple_rev(offset, B, orig, gutils, tape) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + API.DEM_ReverseModePrimal, + ) needsPrimal = needsPrimalP[] != 0 needsShadow = needsShadowP[] != 0 @@ -440,7 +702,7 @@ function common_f_tuple_rev(offset, B, orig, gutils, tape) return true end 
- if !newstruct_common(#=fwd=#false, #=run=#false, offset, B, orig, gutils, #=normalR=#nothing, #=shadowR=#nothing) + if !newstruct_common(false, false, offset, B, orig, gutils, nothing, nothing) #=shadowR=# @assert tape !== C_NULL width = get_width(gutils) tape2 = if width != 1 @@ -456,8 +718,13 @@ function common_f_tuple_rev(offset, B, orig, gutils, tape) cal = LLVM.pointercast!(B, cal, LLVM.PointerType(llty, Derived)) ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) - for i in 1:width - gep = LLVM.inbounds_gep!(B, AT, cal, [LLVM.ConstantInt(0), LLVM.ConstantInt(i-1)]) + for i = 1:width + gep = LLVM.inbounds_gep!( + B, + AT, + cal, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i - 1)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) push!(res, ld) end @@ -465,7 +732,17 @@ function common_f_tuple_rev(offset, B, orig, gutils, tape) else tape end - generic_setup(orig, runtime_tuple_rev, Nothing, gutils, #=start=#offset+1, B, true; tape=tape2, runtime_activity=false) + generic_setup( + orig, + runtime_tuple_rev, + Nothing, + gutils, + offset + 1, + B, + true; + tape = tape2, + runtime_activity = false, + ) #=start=# end return nothing end @@ -506,7 +783,12 @@ end @assert is_constant_value(gutils, origops[1]) if is_constant_value(gutils, origops[2]) - emit_error(B, orig, "Enzyme: Not yet implemented, mixed activity for jl_new_struct_t"*string(orig)) + emit_error( + B, + orig, + "Enzyme: Not yet implemented, mixed activity for jl_new_struct_t" * + string(orig), + ) end shadowsin = invert_pointer(gutils, origops[2], B) @@ -515,12 +797,16 @@ end shadowres = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), vals) callconv!(shadowres, callconv(orig)) else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) - for idx in 1:width - vals = [new_from_original(gutils, origops[1]), extract_value!(B, shadowsin, idx-1)] + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 
1:width + vals = [ + new_from_original(gutils, origops[1]), + extract_value!(B, shadowsin, idx - 1), + ] tmp = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), args) callconv!(tmp, callconv(orig)) - shadowres = insert_value!(B, shadowres, tmp, idx-1) + shadowres = insert_value!(B, shadowres, tmp, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) @@ -537,14 +823,24 @@ end end needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + API.DEM_ReverseModePrimal, + ) needsPrimal = needsPrimalP[] != 0 needsShadow = needsShadowP[] != 0 - if !needsShadow - return - end - emit_error(B, orig, "Enzyme: Not yet implemented reverse for jl_new_structt "*string(orig)) + if !needsShadow + return + end + emit_error( + B, + orig, + "Enzyme: Not yet implemented reverse for jl_new_structt " * string(orig), + ) return nothing end @@ -568,9 +864,13 @@ function common_jl_getfield_fwd(offset, B, orig, gutils, normalR, shadowR) shadowres = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), args) callconv!(shadowres, callconv(orig)) else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) - for idx in 1:width - args = LLVM.Value[new_from_original(gutils, origops[1]), extract_value!(B, shadowin, idx-1)] + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width + args = LLVM.Value[ + new_from_original(gutils, origops[1]), + extract_value!(B, shadowin, idx - 1), + ] for a in origops[3:end-1] push!(args, new_from_original(gutils, a)) end @@ -579,7 +879,7 @@ function common_jl_getfield_fwd(offset, B, orig, gutils, normalR, shadowR) end tmp = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), args) callconv!(tmp, callconv(orig)) - shadowres = 
insert_value!(B, shadowres, tmp, idx-1) + shadowres = insert_value!(B, shadowres, tmp, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) @@ -588,9 +888,11 @@ function common_jl_getfield_fwd(offset, B, orig, gutils, normalR, shadowR) if width == 1 shadowres = normal else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal)))) - for idx in 1:width - shadowres = insert_value!(B, shadowres, normal, idx-1) + shadowres = UndefValue( + LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal))), + ) + for idx = 1:width + shadowres = insert_value!(B, shadowres, normal, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) @@ -598,7 +900,7 @@ function common_jl_getfield_fwd(offset, B, orig, gutils, normalR, shadowR) return false end -@generated function ntuple_ref_zero(::Val{N}, ::Type{RT}, res) where {N, RT} +@generated function ntuple_ref_zero(::Val{N}, ::Type{RT}, res) where {N,RT} expr = Vector{Expr}(undef, N) fill!(expr, :(Ref{$RT}(make_zero(res)))) return quote @@ -607,9 +909,9 @@ end end end -@generated function ntuple_ref_lookup(::Val{N}, ::Type{RT}, dptrs, symname) where {N, RT} +@generated function ntuple_ref_lookup(::Val{N}, ::Type{RT}, dptrs, symname) where {N,RT} expr = Vector{Expr}(undef, N) - for i in 1:N + for i = 1:N @inbounds expr[i] = quote begin dv = dptrs[$i] @@ -625,7 +927,7 @@ end @generated function ntuple_lookup(::Val{N}, ptrs, symname) where {N} expr = Vector{Expr}(undef, N) - for i in 1:N + for i = 1:N @inbounds expr[i] = quote begin dv = ptrs[$i] @@ -639,11 +941,17 @@ end end end -function rt_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {NT, T, T2, Nargs, symname, isconst} +function rt_jl_getfield_aug( + ::Val{NT}, + dptr::T, + ::Type{Val{symname}}, + ::Val{isconst}, + dptrs::Vararg{T2,Nargs}, +) where {NT,T,T2,Nargs,symname,isconst} res = if dptr isa Base.RefValue - Base.getfield(dptr[], symname) + Base.getfield(dptr[], symname) else - 
Base.getfield(dptr, symname) + Base.getfield(dptr, symname) end RT = Core.Typeof(res) @@ -652,13 +960,16 @@ function rt_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isco if length(dptrs) == 0 return Ref{RT}(make_zero(res)) else - return NT(ntuple_ref_zero(Val(1+length(dptrs)), RT, res)) + return NT(ntuple_ref_zero(Val(1 + length(dptrs)), RT, res)) end elseif actreg == MixedState if length(dptrs) == 0 return Ref{RT}(res) else - fval = NT((Ref{RT}(res), ntuple_ref_lookup(Val(length(dptrs)), RT, dptrs, symname)...)) + fval = NT(( + Ref{RT}(res), + ntuple_ref_lookup(Val(length(dptrs)), RT, dptrs, symname)..., + )) return fval end elseif isconst @@ -678,11 +989,17 @@ function rt_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isco end end -function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {NT, T, T2, Nargs, symname, isconst} +function idx_jl_getfield_aug( + ::Val{NT}, + dptr::T, + ::Type{Val{symname}}, + ::Val{isconst}, + dptrs::Vararg{T2,Nargs}, +) where {NT,T,T2,Nargs,symname,isconst} res = if dptr isa Base.RefValue - Base.getfield(dptr[], symname+1) + Base.getfield(dptr[], symname + 1) else - Base.getfield(dptr, symname+1) + Base.getfield(dptr, symname + 1) end RT = Core.Typeof(res) actreg = active_reg_nothrow(RT, Val(nothing)) @@ -690,13 +1007,16 @@ function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isc if length(dptrs) == 0 return Ref{RT}(make_zero(res))::Any else - return NT(ntuple_ref_zero(Val(1+length(dptrs)), RT, res)) + return NT(ntuple_ref_zero(Val(1 + length(dptrs)), RT, res)) end elseif actreg == MixedState if length(dptrs) == 0 return Ref{RT}(res) else - fval = NT((Ref{RT}(res), ntuple_ref_lookup(Val(length(dptrs)), RT, dptrs, symname+1)...)) + fval = NT(( + Ref{RT}(res), + ntuple_ref_lookup(Val(length(dptrs)), RT, dptrs, symname + 1)..., + )) return fval end elseif isconst @@ -710,16 +1030,21 @@ function 
idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isc if length(dptrs) == 0 return res::Any else - fval = NT((res, ntuple_lookup(Val(length(dptrs)), dptrs, symname+1)...)) + fval = NT((res, ntuple_lookup(Val(length(dptrs)), dptrs, symname + 1)...)) return fval end end end -@generated function recursive_field_add(::Type{dRT}, vload, ::Val{symname}, dret) where {dRT, symname} +@generated function recursive_field_add( + ::Type{dRT}, + vload, + ::Val{symname}, + dret, +) where {dRT,symname} N = fieldcount(dRT) exprs = Vector{Expr}(undef, N) - for i in 1:N + for i = 1:N @inbounds exprs[i] = if fieldname(dRT, i) == symname :(recursive_add(getfield(vload, $i), dret, identity, guaranteed_nonactive)) else @@ -733,11 +1058,17 @@ end end end -function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {T, T2, Nargs, symname, isconst} +function rt_jl_getfield_rev( + dptr::T, + dret, + ::Type{Val{symname}}, + ::Val{isconst}, + dptrs::Vararg{T2,Nargs}, +) where {T,T2,Nargs,symname,isconst} cur = if dptr isa Base.RefValue - getfield(dptr[], symname) + getfield(dptr[], symname) else - getfield(dptr, symname) + getfield(dptr, symname) end RT = Core.Typeof(cur) @@ -750,7 +1081,11 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dRT = Core.Typeof(vload) dptr[] = recursive_field_add(dRT, vload, Val(symname), dret[]) else - setfield!(dptr, symname, recursive_add(cur, dret[], identity, guaranteed_nonactive)) + setfield!( + dptr, + symname, + recursive_add(cur, dret[], identity, guaranteed_nonactive), + ) end else if dptr isa Base.RefValue @@ -760,18 +1095,22 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, else setfield!(dptr, symname, recursive_add(cur, dret[1][])) end - for i in 1:length(dptrs) + for i = 1:length(dptrs) if dptrs[i] isa Base.RefValue vload = dptrs[i][] dRT = Core.Typeof(vload) dptrs[i][] = recursive_field_add(dRT, vload, 
Val(symname), dret[1+i][]) else curi = if dptr isa Base.RefValue - Base.getfield(dptrs[i][], symname) + Base.getfield(dptrs[i][], symname) else - Base.getfield(dptrs[i], symname) + Base.getfield(dptrs[i], symname) end - setfield!(dptrs[i], symname, recursive_add(curi, dret[1+i][], identity, guaranteed_nonactive)) + setfield!( + dptrs[i], + symname, + recursive_add(curi, dret[1+i][], identity, guaranteed_nonactive), + ) end end end @@ -779,10 +1118,15 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, return nothing end -@generated function recursive_index_add(::Type{dRT}, vload, ::Val{symname}, dret) where {dRT, symname} +@generated function recursive_index_add( + ::Type{dRT}, + vload, + ::Val{symname}, + dret, +) where {dRT,symname} N = fieldcount(dRT) exprs = Vector{Expr}(undef, N) - for i in 1:N + for i = 1:N @inbounds exprs[i] = if i == symname :(recursive_add(getfield(vload, $i), dret, identity, guaranteed_nonactive)) else @@ -796,11 +1140,17 @@ end end end -function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {T, T2, Nargs, symname, isconst} +function idx_jl_getfield_rev( + dptr::T, + dret, + ::Type{Val{symname}}, + ::Val{isconst}, + dptrs::Vararg{T2,Nargs}, +) where {T,T2,Nargs,symname,isconst} cur = if dptr isa Base.RefValue - Base.getfield(dptr[], symname+1) + Base.getfield(dptr[], symname + 1) else - Base.getfield(dptr, symname+1) + Base.getfield(dptr, symname + 1) end RT = Core.Typeof(cur) @@ -811,30 +1161,43 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst} if dptr isa Base.RefValue vload = dptr[] dRT = Core.Typeof(vload) - dptr[] = recursive_index_add(dRT, vload, Val(symname+1), dret[]) + dptr[] = recursive_index_add(dRT, vload, Val(symname + 1), dret[]) else - setfield!(dptr, symname+1, recursive_add(cur, dret[], identity, guaranteed_nonactive)) + setfield!( + dptr, + symname + 1, + recursive_add(cur, dret[], identity, 
guaranteed_nonactive), + ) end else if dptr isa Base.RefValue vload = dptr[] dRT = Core.Typeof(vload) - dptr[] = recursive_index_add(dRT, vload, Val(symname+1), dret[1][]) + dptr[] = recursive_index_add(dRT, vload, Val(symname + 1), dret[1][]) else - setfield!(dptr, symname+1, recursive_add(cur, dret[1][], identity, guaranteed_nonactive)) + setfield!( + dptr, + symname + 1, + recursive_add(cur, dret[1][], identity, guaranteed_nonactive), + ) end - for i in 1:length(dptrs) + for i = 1:length(dptrs) if dptrs[i] isa Base.RefValue vload = dptrs[i][] dRT = Core.Typeof(vload) - dptrs[i][] = recursive_index_add(dRT, vload, Val(symname+1), dret[1+i][]) + dptrs[i][] = + recursive_index_add(dRT, vload, Val(symname + 1), dret[1+i][]) else curi = if dptr isa Base.RefValue - Base.getfield(dptrs[i][], symname+1) + Base.getfield(dptrs[i][], symname + 1) else - Base.getfield(dptrs[i], symname+1) + Base.getfield(dptrs[i], symname + 1) end - setfield!(dptrs[i], symname+1, recursive_add(curi, dret[1+i][], identity, guaranteed_nonactive)) + setfield!( + dptrs[i], + symname + 1, + recursive_add(curi, dret[1+i][], identity, guaranteed_nonactive), + ) end end end @@ -862,8 +1225,8 @@ function common_jl_getfield_augfwd(offset, B, orig, gutils, normalR, shadowR, ta inps = [inp] else inps = LLVM.Value[] - for w in 1:width - push!(inps, extract_value!(B, inp, w-1)) + for w = 1:width + push!(inps, extract_value!(B, inp, w - 1)) end end else @@ -899,18 +1262,23 @@ function common_jl_getfield_augfwd(offset, B, orig, gutils, normalR, shadowR, ta if !is_constant_value(gutils, ops[2]) forgep = LLVM.addrspacecast!(B, forgep, LLVM.PointerType(T_jlvalue, Derived)) forgep = LLVM.pointercast!(B, forgep, LLVM.PointerType(AT, Derived)) - end + end ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for i in 1:width + for i = 1:width if !is_constant_value(gutils, ops[2]) - gep = LLVM.inbounds_gep!(B, AT, forgep, [LLVM.ConstantInt(0), LLVM.ConstantInt(i-1)]) 
+ gep = LLVM.inbounds_gep!( + B, + AT, + forgep, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i - 1)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) else ld = forgep end - shadow = insert_value!(B, shadow, ld, i-1) + shadow = insert_value!(B, shadow, ld, i - 1) end shadowres = shadow end @@ -927,7 +1295,13 @@ function common_jl_getfield_rev(offset, B, orig, gutils, tape) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + API.DEM_ReverseModePrimal, + ) if needsShadowP[] == 0 return end @@ -936,7 +1310,7 @@ function common_jl_getfield_rev(offset, B, orig, gutils, tape) width = get_width(gutils) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) - + if !is_constant_value(gutils, ops[2]) inp = invert_pointer(gutils, ops[2], B) inp = lookup_value(gutils, inp, B) @@ -944,8 +1318,8 @@ function common_jl_getfield_rev(offset, B, orig, gutils, tape) inps = [inp] else inps = LLVM.Value[] - for w in 1:width - push!(inps, extract_value!(B, inp, w-1)) + for w = 1:width + push!(inps, extract_value!(B, inp, w - 1)) end end else @@ -988,21 +1362,22 @@ end shadowin = invert_pointer(gutils, origops[1], B) if width == 1 args = LLVM.Value[ - shadowin - new_from_original(gutils, origops[2]) - ] + shadowin + new_from_original(gutils, origops[2]) + ] shadowres = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), args) callconv!(shadowres, callconv(orig)) else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) - for idx in 1:width + shadowres = + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width args = LLVM.Value[ - extract_value!(B, shadowin, idx-1) - new_from_original(gutils, origops[2]) - ] + extract_value!(B, shadowin, idx - 1) + new_from_original(gutils, origops[2]) + ] 
tmp = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), args) callconv!(tmp, callconv(orig)) - shadowres = insert_value!(B, shadowres, tmp, idx-1) + shadowres = insert_value!(B, shadowres, tmp, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) @@ -1011,9 +1386,11 @@ end if width == 1 shadowres = normal else - shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal)))) - for idx in 1:width - shadowres = insert_value!(B, shadowres, normal, idx-1) + shadowres = UndefValue( + LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal))), + ) + for idx = 1:width + shadowres = insert_value!(B, shadowres, normal, idx - 1) end end unsafe_store!(shadowR, shadowres.ref) @@ -1029,7 +1406,7 @@ end width = get_width(gutils) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) - + T_int8 = LLVM.Int8Type() T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) @@ -1040,8 +1417,8 @@ end inps = [inp] else inps = LLVM.Value[] - for w in 1:width - push!(inps, extract_value!(B, inp, w-1)) + for w = 1:width + push!(inps, extract_value!(B, inp, w - 1)) end end else @@ -1077,18 +1454,23 @@ end if !is_constant_value(gutils, ops[1]) forgep = LLVM.addrspacecast!(B, forgep, LLVM.PointerType(T_jlvalue, Derived)) forgep = LLVM.pointercast!(B, forgep, LLVM.PointerType(AT, Derived)) - end + end ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for i in 1:width + for i = 1:width if !is_constant_value(gutils, ops[1]) - gep = LLVM.inbounds_gep!(B, AT, forgep, [LLVM.ConstantInt(0), LLVM.ConstantInt(i-1)]) + gep = LLVM.inbounds_gep!( + B, + AT, + forgep, + [LLVM.ConstantInt(0), LLVM.ConstantInt(i - 1)], + ) ld = LLVM.load!(B, T_prjlvalue, gep) else ld = forgep end - shadow = insert_value!(B, shadow, ld, i-1) + shadow = insert_value!(B, shadow, ld, i - 1) end shadowres = shadow end @@ -1104,19 +1486,25 @@ end needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - 
activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) + activep = API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + orig, + needsPrimalP, + needsShadowP, + API.DEM_ReverseModePrimal, + ) needsPrimal = needsPrimalP[] != 0 needsShadow = needsShadowP[] != 0 - if !needsShadow - return - end + if !needsShadow + return + end ops = collect(operands(orig)) width = get_width(gutils) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) - + if !is_constant_value(gutils, ops[1]) inp = invert_pointer(gutils, ops[1], B) inp = lookup_value(gutils, inp, B) @@ -1124,8 +1512,8 @@ end inps = [inp] else inps = LLVM.Value[] - for w in 1:width - push!(inps, extract_value!(B, inp, w-1)) + for w = 1:width + push!(inps, extract_value!(B, inp, w - 1)) end end else @@ -1170,7 +1558,8 @@ end end function common_setfield_fwd(offset, B, orig, gutils, normalR, shadowR) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -1188,34 +1577,48 @@ function common_setfield_fwd(offset, B, orig, gutils, normalR, shadowR) shadowout = invert_pointer(gutils, origops[4], B) if width == 1 args = LLVM.Value[ - new_from_original(gutils, origops[1]) - shadowin - new_from_original(gutils, origops[3]) - shadowout - ] - valTys = API.CValueType[API.VT_Primal, API.VT_Shadow, API.VT_Primal, API.VT_Shadow] + new_from_original(gutils, origops[1]) + shadowin + new_from_original(gutils, origops[3]) + shadowout + ] + valTys = + API.CValueType[API.VT_Primal, API.VT_Shadow, API.VT_Primal, API.VT_Shadow] if offset != 1 pushfirst!(args, first(operands(orig))) pushfirst!(valTys, API.VT_Primal) end - shadowres = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, #=lookup=#false) + shadowres = + call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, false) #=lookup=# callconv!(shadowres, callconv(orig)) else - for idx in 1:width + for idx = 1:width args = LLVM.Value[ - new_from_original(gutils, origops[1]) - extract_value!(B, shadowin, idx-1) - new_from_original(gutils, origops[3]) - extract_value!(B, shadowout, idx-1) - ] - valTys = API.CValueType[API.VT_Primal, API.VT_Shadow, API.VT_Primal, API.VT_Shadow] + new_from_original(gutils, origops[1]) + extract_value!(B, shadowin, idx - 1) + new_from_original(gutils, origops[3]) + extract_value!(B, shadowout, idx - 1) + ] + valTys = API.CValueType[ + API.VT_Primal, + API.VT_Shadow, + API.VT_Primal, + API.VT_Shadow, + ] if offset != 1 pushfirst!(args, first(operands(orig))) pushfirst!(valTys, API.VT_Primal) end - tmp = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, #=lookup=#false) + tmp = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + args, + valTys, + false, + ) #=lookup=# callconv!(tmp, callconv(orig)) end @@ -1225,7 +1628,7 @@ function common_setfield_fwd(offset, B, 
orig, gutils, normalR, shadowR) end -function rt_jl_setfield_aug(dptr::T, idx, ::Val{isconst}, val, dval) where {T, isconst} +function rt_jl_setfield_aug(dptr::T, idx, ::Val{isconst}, val, dval) where {T,isconst} RT = Core.Typeof(val) if active_reg(RT) setfield!(dptr, idx, make_zero(val)) @@ -1234,7 +1637,7 @@ function rt_jl_setfield_aug(dptr::T, idx, ::Val{isconst}, val, dval) where {T, i end end -function rt_jl_setfield_rev(dptr::T, idx, ::Val{isconst}, val, dval) where {T, isconst} +function rt_jl_setfield_rev(dptr::T, idx, ::Val{isconst}, val, dval) where {T,isconst} RT = Core.Typeof(val) if active_reg(RT) && !isconst dval[] += getfield(dptr, idx) @@ -1244,7 +1647,8 @@ end function common_setfield_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -1263,15 +1667,16 @@ function common_setfield_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) - for idx in 1:width + for idx = 1:width vals = LLVM.Value[ - (width == 1) ? shadowstruct : extract_value!(B, shadowstruct, idx-1), - new_from_original(gutils, origops[3]), - unsafe_to_llvm(B, Val(is_constant_value(gutils, origops[4]))), - new_from_original(gutils, origops[4]), - is_constant_value(gutils, origops[4]) ? unsafe_to_llvm(B, nothing) : ((width == 1) ? shadowval : extract_value!(B, shadowval, idx-1)), + (width == 1) ? shadowstruct : extract_value!(B, shadowstruct, idx - 1), + new_from_original(gutils, origops[3]), + unsafe_to_llvm(B, Val(is_constant_value(gutils, origops[4]))), + new_from_original(gutils, origops[4]), + is_constant_value(gutils, origops[4]) ? unsafe_to_llvm(B, nothing) : + ((width == 1) ? 
shadowval : extract_value!(B, shadowval, idx - 1)), ] - + pushfirst!(vals, unsafe_to_llvm(B, rt_jl_setfield_aug)) cal = emit_apply_generic!(B, vals) @@ -1295,18 +1700,27 @@ function common_setfield_rev(offset, B, orig, gutils, tape) else nothing end - + mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) - for idx in 1:width + for idx = 1:width vals = LLVM.Value[ - lookup_value(gutils, (width == 1) ? shadowstruct : extract_value!(B, shadowstruct, idx-1), B), - lookup_value(gutils, new_from_original(gutils, origops[3]), B), - unsafe_to_llvm(B, Val(is_constant_value(gutils, origops[4]))), - lookup_value(gutils, new_from_original(gutils, origops[4]), B), - is_constant_value(gutils, origops[4]) ? unsafe_to_llvm(B, nothing) : lookup_value(gutils, ((width == 1) ? shadowval : extract_value!(B, shadowval, idx-1)), B), + lookup_value( + gutils, + (width == 1) ? shadowstruct : extract_value!(B, shadowstruct, idx - 1), + B, + ), + lookup_value(gutils, new_from_original(gutils, origops[3]), B), + unsafe_to_llvm(B, Val(is_constant_value(gutils, origops[4]))), + lookup_value(gutils, new_from_original(gutils, origops[4]), B), + is_constant_value(gutils, origops[4]) ? unsafe_to_llvm(B, nothing) : + lookup_value( + gutils, + ((width == 1) ? 
shadowval : extract_value!(B, shadowval, idx - 1)), + B, + ), ] - + pushfirst!(vals, unsafe_to_llvm(B, rt_jl_setfield_rev)) cal = emit_apply_generic!(B, vals) @@ -1314,7 +1728,7 @@ function common_setfield_rev(offset, B, orig, gutils, tape) debug_from_orig!(gutils, cal, orig) end end - return nothing + return nothing end @@ -1330,7 +1744,7 @@ end common_setfield_rev(1, B, orig, gutils, tape) end -function error_if_differentiable(::Type{T}) where T +function error_if_differentiable(::Type{T}) where {T} seen = () areg = active_reg_inner(T, seen, nothing) if areg != AnyState @@ -1349,13 +1763,13 @@ function common_f_svec_ref_fwd(offset, B, orig, gutils, normalR, shadowR) origmi, origh, origkey = operands(orig)[offset:end-1] shadowh = invert_pointer(gutils, origh, B) - + newvals = API.CValueType[API.VT_Primal, API.VT_Shadow, API.VT_Primal] if offset != 1 pushfirst!(newvals, API.VT_Primal) end - + mi = new_from_original(gutils, origmi) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) @@ -1365,27 +1779,50 @@ function common_f_svec_ref_fwd(offset, B, orig, gutils, normalR, shadowR) if offset != 1 pushfirst!(newops, operands(orig)[1]) end - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, false) #=lookup=# callconv!(cal, callconv(orig)) - + if is_constant_value(gutils, origh) - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, error_if_differentiable), emit_jltypeof!(B, cal)]) + emit_apply_generic!( + B, + LLVM.Value[ + unsafe_to_llvm(B, error_if_differentiable), + emit_jltypeof!(B, cal), + ], + ) end cal else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for j in 1:width - newops = LLVM.Value[mi, extract_value!(B, shadowh, j-1), new_from_original(gutils, origkey)] + for j = 1:width + newops = LLVM.Value[ + mi, + extract_value!(B, shadowh, j - 1), + new_from_original(gutils, origkey), + ] if 
offset != 1 pushfirst!(newops, operands(orig)[1]) end - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + cal = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + newops, + newvals, + false, + ) #=lookup=# callconv!(cal, callconv(orig)) if is_constant_value(gutils, origh) - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, error_if_differentiable), emit_jltypeof!(B, cal)]) + emit_apply_generic!( + B, + LLVM.Value[ + unsafe_to_llvm(B, error_if_differentiable), + emit_jltypeof!(B, cal), + ], + ) end - shadow = insert_value!(B, shadow, cal, j-1) + shadow = insert_value!(B, shadow, cal, j - 1) end shadow end @@ -1405,19 +1842,19 @@ function common_f_svec_ref_augfwd(offset, B, orig, gutils, normalR, shadowR, tap origmi, origh, origkey = operands(orig)[offset:end-1] shadowh = invert_pointer(gutils, origh, B) - + newvals = API.CValueType[API.VT_Primal, API.VT_Shadow, API.VT_Primal] if offset != 1 pushfirst!(newvals, API.VT_Primal) end - + errfn = if is_constant_value(gutils, origh) error_if_differentiable else error_if_active end - + mi = new_from_original(gutils, origmi) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) @@ -1427,24 +1864,38 @@ function common_f_svec_ref_augfwd(offset, B, orig, gutils, normalR, shadowR, tap if offset != 1 pushfirst!(newops, operands(orig)[1]) end - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, false) #=lookup=# callconv!(cal, callconv(orig)) - - + + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, errfn), emit_jltypeof!(B, cal)]) cal else ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) shadow = LLVM.UndefValue(ST) - for j in 1:width - newops = LLVM.Value[mi, extract_value!(B, shadowh, j-1), new_from_original(gutils, origkey)] + for j = 1:width + newops = LLVM.Value[ + mi, + extract_value!(B, shadowh, j - 1), + 
new_from_original(gutils, origkey), + ] if offset != 1 pushfirst!(newops, operands(orig)[1]) end - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) + cal = call_samefunc_with_inverted_bundles!( + B, + gutils, + orig, + newops, + newvals, + false, + ) #=lookup=# callconv!(cal, callconv(orig)) - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, errfn), emit_jltypeof!(B, cal)]) - shadow = insert_value!(B, shadow, cal, j-1) + emit_apply_generic!( + B, + LLVM.Value[unsafe_to_llvm(B, errfn), emit_jltypeof!(B, cal)], + ) + shadow = insert_value!(B, shadow, cal, j - 1) end shadow end @@ -1463,7 +1914,8 @@ function common_finalizer_fwd(offset, B, orig, gutils, normalR, shadowR) return true end emit_error(B, orig, "Enzyme: unhandled forward for jl_f_finalizer") - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end @@ -1475,7 +1927,8 @@ function common_finalizer_augfwd(offset, B, orig, gutils, normalR, shadowR, tape return true end emit_error(B, orig, "Enzyme: unhandled augmented forward for jl_f_finalizer") - normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + normal = + (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing if shadowR != C_NULL && normal !== nothing unsafe_store!(shadowR, normal.ref) end diff --git a/src/typeanalysis.jl b/src/typeanalysis.jl index a1d90ba81f4..a84f96f856e 100644 --- a/src/typeanalysis.jl +++ b/src/typeanalysis.jl @@ -7,7 +7,10 @@ end Base.unsafe_convert(::Type{API.EnzymeTypeAnalysisRef}, ta::TypeAnalysis) = ta.ref LLVM.dispose(ta::TypeAnalysis) = API.FreeTypeAnalysis(ta) -function TypeAnalysis(logic, typerules::Dict{String, CustomRuleType}=Dict{String,CustomRuleType}()) +function TypeAnalysis( + logic, + typerules::Dict{String,CustomRuleType} = Dict{String,CustomRuleType}(), +) rulenames = String[] rules = CustomRuleType[] for (rulename, rule) in typerules @@ -20,4 +23,4 @@ end # typedef bool (*CustomRuleType)(int /*direction*/, CTypeTree * /*return*/, # CTypeTree * /*args*/, size_t /*numArgs*/, -# LLVMValueRef)=T \ No newline at end of file +# LLVMValueRef)=T diff --git a/src/typetree.jl b/src/typetree.jl index 89e5a040f36..8ddce070b26 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -51,7 +51,7 @@ function shift!(tt::TypeTree, dl, offset, maxSize, addOffset) API.EnzymeTypeTreeShiftIndiciesEq(tt, dl, offset, maxSize, addOffset) end -function merge!(dst::TypeTree, src::TypeTree; consume=true) +function merge!(dst::TypeTree, src::TypeTree; consume = true) API.EnzymeMergeTypeTree(dst, src) if consume LLVM.dispose(src) @@ -80,28 +80,12 @@ end @static if VERSION >= v"1.11-" -const TypeTreePrimitives = ( - Char, - Float16, - Float32, - Float64, - Core.BFloat16 -) + const TypeTreePrimitives = (Char, Float16, Float32, Float64, Core.BFloat16) else -const TypeTreePrimitives = ( - Char, - Float16, - Float32, - Float64 -) + const TypeTreePrimitives = (Char, Float16, Float32, Float64) end -const TypeTreeEmptyPointers = ( - BigFloat, - Any, - Symbol, - Union{}, -) +const TypeTreeEmptyPointers = (BigFloat, Any, Symbol, Union{}) function get_offsets(@nospecialize(T::Type)) for sT in (Integer, 
TypeTreePrimitives...) @@ -109,18 +93,22 @@ function get_offsets(@nospecialize(T::Type)) return ((typetree_primitive(T), 0),) end end - for sT in (DataType, AbstractString, TypeTreeEmptyPointers...) + for sT in (DataType, AbstractString) if T <: sT return ((API.DT_Pointer, 0),) end end - -@static if VERSION < v"1.11-" - TypeTreePtrs = (Core.SimpleVector, Ptr, Core.LLVMPtr, Array) -else - TypeTreePtrs = (Core.SimpleVector, Ptr, Core.LLVMPtr, Array, GenericMemory) -end for sT in TypeTreeEmptyPointers + if T == sT + return ((API.DT_Pointer, 0),) + end + end + @static if VERSION < v"1.11-" + TypeTreePtrs = (Core.SimpleVector, Ptr, Core.LLVMPtr, Array) + else + TypeTreePtrs = (Core.SimpleVector, Ptr, Core.LLVMPtr, Array, GenericMemory) + end + for sT in TypeTreePtrs if T <: sT return ((API.DT_Pointer, 0),) end @@ -132,8 +120,8 @@ end return () end - results = Tuple{API.CConcreteType, Int}[] - for f in 1:fieldcount(T) + results = Tuple{API.CConcreteType,Int}[] + for f = 1:fieldcount(T) offset = fieldoffset(T, f) subT = fieldtype(T, f) @@ -141,9 +129,9 @@ end push!(results, (API.DT_Pointer, offset)) continue end - + for (sT, sO) in get_offsets(subT) - push!(results, (sT, sO+offset)) + push!(results, (sT, sO + offset)) end end return results @@ -173,10 +161,17 @@ function to_fullmd(@nospecialize(T::Type)) end function to_md(tt::TypeTree, ctx) - return LLVM.Metadata(LLVM.MetadataAsValue(ccall((:EnzymeTypeTreeToMD, API.libEnzyme), - LLVM.API.LLVMValueRef, - (API.CTypeTreeRef, - LLVM.API.LLVMContextRef), tt, ctx))) + return LLVM.Metadata( + LLVM.MetadataAsValue( + ccall( + (:EnzymeTypeTreeToMD, API.libEnzyme), + LLVM.API.LLVMValueRef, + (API.CTypeTreeRef, LLVM.API.LLVMContextRef), + tt, + ctx, + ), + ), + ) end const TypeTreeTable = IdDict{Any,Union{Nothing,TypeTree}} @@ -190,7 +185,7 @@ Construct a Enzyme typetree from a Julia type. 
When using a memoized lookup by providing `seen` across multiple calls to typtree the user must call `copy` on the returned value before mutating it. """ -function typetree(@nospecialize(T::Type), ctx, dl, seen=TypeTreeTable()) +function typetree(@nospecialize(T::Type), ctx, dl, seen = TypeTreeTable()) if haskey(seen, T) tree = seen[T] if tree === nothing @@ -209,7 +204,7 @@ function typetree_inner(::Type{<:Integer}, ctx, dl, seen::TypeTreeTable) end for sT in TypeTreePrimitives @eval function typetree_inner(::Type{$sT}, ctx, dl, seen::TypeTreeTable) - return TypeTree($(typetree_primitive(sT)), -1, ctx) + return TypeTree($(typetree_primitive(sT)), -1, ctx) end end @@ -221,21 +216,25 @@ function typetree_inner(::Type{<:AbstractString}, ctx, dl, seen::TypeTreeTable) end for sT in TypeTreeEmptyPointers @eval function typetree_inner(::Type{$sT}, ctx, dl, seen::TypeTreeTable) - return TypeTree() + return TypeTree() end end function typetree_inner(::Type{Core.SimpleVector}, ctx, dl, seen::TypeTreeTable) tt = TypeTree() - for i in 0:(sizeof(Csize_t) - 1) + for i = 0:(sizeof(Csize_t)-1) merge!(tt, TypeTree(API.DT_Integer, i, ctx)) end return tt end -function typetree_inner(::Type{<:Union{Ptr{T},Core.LLVMPtr{T}}}, ctx, dl, - seen::TypeTreeTable) where {T} +function typetree_inner( + ::Type{<:Union{Ptr{T},Core.LLVMPtr{T}}}, + ctx, + dl, + seen::TypeTreeTable, +) where {T} tt = copy(typetree(T, ctx, dl, seen)) merge!(tt, TypeTree(API.DT_Pointer, ctx)) only!(tt, -1) @@ -261,13 +260,18 @@ end sizeofstruct += sizeof(Csize_t) end - for i in offset:(sizeofstruct-1) + for i = offset:(sizeofstruct-1) merge!(tt, TypeTree(API.DT_Integer, i, ctx)) end return tt end else - function typetree_inner(::Type{<:GenericMemory{kind, T}}, ctx, dl, seen::TypeTreeTable) where {kind, T} + function typetree_inner( + ::Type{<:GenericMemory{kind,T}}, + ctx, + dl, + seen::TypeTreeTable, + ) where {kind,T} offset = 0 tt = copy(typetree(T, ctx, dl, seen)) if !allocatedinline(T) && Base.isconcretetype(T) 
@@ -277,7 +281,7 @@ else merge!(tt, TypeTree(API.DT_Pointer, ctx)) only!(tt, sizeof(Csize_t)) - for i in 0:(sizeof(Csize_t)-1) + for i = 0:(sizeof(Csize_t)-1) merge!(tt, TypeTree(API.DT_Integer, i, ctx)) end return tt @@ -327,7 +331,7 @@ function typetree_inner(@nospecialize(T::Type), ctx, dl, seen::TypeTreeTable) end tt = TypeTree() - for f in 1:fieldcount(T) + for f = 1:fieldcount(T) offset = fieldoffset(T, f) subT = fieldtype(T, f) subtree = copy(typetree(subT, ctx, dl, seen)) @@ -358,7 +362,10 @@ struct FnTypeInfo end Base.cconvert(::Type{API.CFnTypeInfo}, fnti::FnTypeInfo) = fnti function Base.unsafe_convert(::Type{API.CFnTypeInfo}, fnti::FnTypeInfo) - args_kv = Base.unsafe_convert(Ptr{API.IntList}, Base.cconvert(Ptr{API.IntList}, fnti.known_values)) + args_kv = Base.unsafe_convert( + Ptr{API.IntList}, + Base.cconvert(Ptr{API.IntList}, fnti.known_values), + ) rTT = Base.unsafe_convert(API.CTypeTreeRef, Base.cconvert(API.CTypeTreeRef, fnti.rTT)) tts = API.CTypeTreeRef[] @@ -366,6 +373,9 @@ function Base.unsafe_convert(::Type{API.CFnTypeInfo}, fnti::FnTypeInfo) raw_tt = Base.unsafe_convert(API.CTypeTreeRef, Base.cconvert(API.CTypeTreeRef, tt)) push!(tts, raw_tt) end - argTTs = Base.unsafe_convert(Ptr{API.CTypeTreeRef}, Base.cconvert(Ptr{API.CTypeTreeRef}, tts)) + argTTs = Base.unsafe_convert( + Ptr{API.CTypeTreeRef}, + Base.cconvert(Ptr{API.CTypeTreeRef}, tts), + ) return API.CFnTypeInfo(argTTs, rTT, args_kv) end diff --git a/src/utils.jl b/src/utils.jl index ac312e82951..d042859b899 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -5,10 +5,16 @@ Assumes that `val` is globally rooted and pointer to it can be leaked. Prefer `pointer_from_objref`. Only use inside Enzyme.jl should be for Types. 
""" -@inline unsafe_to_pointer(val::Type{T}) where T = ccall(Base.@cfunction(Base.identity, Ptr{Cvoid}, (Ptr{Cvoid},)), Ptr{Cvoid}, (Any,), val) +@inline unsafe_to_pointer(val::Type{T}) where {T} = ccall( + Base.@cfunction(Base.identity, Ptr{Cvoid}, (Ptr{Cvoid},)), + Ptr{Cvoid}, + (Any,), + val, +) export unsafe_to_pointer -@inline is_concrete_tuple(x::Type{T2}) where T2 = (T2 <: Tuple) && !(T2 === Tuple) && !(T2 isa UnionAll) +@inline is_concrete_tuple(x::Type{T2}) where {T2} = + (T2 <: Tuple) && !(T2 === Tuple) && !(T2 isa UnionAll) export is_concrete_tuple const Tracked = 10 @@ -20,11 +26,11 @@ const captured_constants = Base.IdSet{Any}() function unsafe_nothing_to_llvm(mod::LLVM.Module) globs = LLVM.globals(mod) k = "jl_nothing" - if Base.haskey(globs, "ejl_"*k) + if Base.haskey(globs, "ejl_" * k) return globs["ejl_"*k] end T_jlvalue = LLVM.StructType(LLVM.LLVMType[]) - gv = LLVM.GlobalVariable(mod, T_jlvalue, "ejl_"*k, Tracked) + gv = LLVM.GlobalVariable(mod, T_jlvalue, "ejl_" * k, Tracked) API.SetMD(gv, "enzyme_ta_norecur", LLVM.MDNode(LLVM.Metadata[])) API.SetMD(gv, "enzyme_inactive", LLVM.MDNode(LLVM.Metadata[])) @@ -56,13 +62,13 @@ function unsafe_to_llvm(B::LLVM.IRBuilder, @nospecialize(val)) if v === val mod = LLVM.parent(LLVM.parent(LLVM.position(B))) globs = LLVM.globals(mod) - if Base.haskey(globs, "ejl_"*k) + if Base.haskey(globs, "ejl_" * k) return globs["ejl_"*k] end - gv = LLVM.GlobalVariable(mod, T_jlvalue, "ejl_"*k, Tracked) + gv = LLVM.GlobalVariable(mod, T_jlvalue, "ejl_" * k, Tracked) API.SetMD(gv, "enzyme_ta_norecur", LLVM.MDNode(LLVM.Metadata[])) - legal, jTy = Compiler.abs_typeof(gv, true) + legal, jTy, byref = Compiler.abs_typeof(gv, true) if legal curent_bb = position(B) fn = LLVM.parent(curent_bb) @@ -78,12 +84,12 @@ function unsafe_to_llvm(B::LLVM.IRBuilder, @nospecialize(val)) if v === val mod = LLVM.parent(LLVM.parent(LLVM.position(B))) globs = LLVM.globals(mod) - if Base.haskey(globs, "ejl_"*k) + if Base.haskey(globs, "ejl_" * k) 
return globs["ejl_"*k] end - gv = LLVM.GlobalVariable(mod, T_jlvalue, "ejl_"*k, Tracked) + gv = LLVM.GlobalVariable(mod, T_jlvalue, "ejl_" * k, Tracked) API.SetMD(gv, "enzyme_ta_norecur", LLVM.MDNode(LLVM.Metadata[])) - legal, jTy = Compiler.abs_typeof(gv, true) + legal, jTy, byref = Compiler.abs_typeof(gv, true) if legal curent_bb = position(B) fn = LLVM.parent(curent_bb) @@ -153,7 +159,11 @@ using Base: _methods_by_ftype # on 1.10 (JuliaLang/julia#48611) the generated function knows which world it was invoked in function _generated_ex(world, source, ex) - stub = Core.GeneratedFunctionStub(identity, Core.svec(:methodinstance, :ft, :tt), Core.svec()) + stub = Core.GeneratedFunctionStub( + identity, + Core.svec(:methodinstance, :ft, :tt), + Core.svec(), + ) stub(world, source, ex) end @@ -164,23 +174,38 @@ function codegen_world_age_generator(world::UInt, source, self, ft::Type, tt::Ty tt = tt.parameters[1] # validation - ft <: Core.Builtin && error("$(GPUCompiler.unsafe_function_from_type(ft)) is not a generic function") + ft <: Core.Builtin && + error("$(GPUCompiler.unsafe_function_from_type(ft)) is not a generic function") # look up the method method_error = :(throw(MethodError(ft, tt, $world))) - sig = Tuple{ft, tt.parameters...} + sig = Tuple{ft,tt.parameters...} min_world = Ref{UInt}(typemin(UInt)) max_world = Ref{UInt}(typemax(UInt)) has_ambig = Ptr{Int32}(C_NULL) # don't care about ambiguous results - mthds = Base._methods_by_ftype(sig, #=mt=# nothing, #=lim=# -1, - world, #=ambig=# false, - min_world, max_world, has_ambig) + mthds = Base._methods_by_ftype( + sig, + nothing, + -1, #=lim=# + world, + false, #=ambig=# + min_world, + max_world, + has_ambig, + ) mthds === nothing && return _generated_ex(world, source, method_error) length(mthds) == 1 || return _generated_ex(world, source, method_error) # look up the method and code instance mtypes, msp, m = mthds[1] - mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any), m, mtypes, msp) 
+ mi = ccall( + :jl_specializations_get_linfo, + Ref{MethodInstance}, + (Any, Any, Any), + m, + mtypes, + msp, + ) ci = retrieve_code_info(mi, world)::CodeInfo # prepare a new code info @@ -222,8 +247,3 @@ end end export codegen_world_age - - - - - diff --git a/test/runtests.jl b/test/runtests.jl index 573140f2c27..92bfa475136 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2856,6 +2856,17 @@ end @test dx[3] ≈ 0 end +function unstable_fun(A0) + A = 'N' in ('H', 'h', 'S', 's') ? wrap(A0) : A0 + (@inbounds A[1])::eltype(A0) +end +@testset "Type unstable static array index" begin + inp = ones(SVector{2, Float64}) + res = Enzyme.gradient(Enzyme.Reverse, unstable_fun, inp)[1] + @test res ≈ [1.0, 0.0] + res = Enzyme.gradient(Enzyme.Forward, unstable_fun, inp)[1] + @test res ≈ [1.0, 0.0] +end function sparse_eval(x::Vector{Float64}) A = sparsevec([1, 1, 2, 3], [2.0*x[2]^3.0, 1.0-x[1], 2.0+x[3], -1.0]) diff --git a/test/typetree.jl b/test/typetree.jl index 51c284d6e94..1a869d66878 100644 --- a/test/typetree.jl +++ b/test/typetree.jl @@ -79,3 +79,16 @@ end "{[0]:Pointer, [0,4]:Float@float, [0,8]:Float@double, [4]:Integer, [8]:Pointer, [8,4]:Float@float, [8,8]:Float@double, [12]:Integer, [16]:Pointer, [16,4]:Float@float, [16,8]:Float@double, [20]:Integer, [24]:Pointer, [24,4]:Float@float, [24,8]:Float@double}" end end + +@testset "GetOffsets" begin + @test Enzyme.get_offsets(Float16) == ((Enzyme.API.DT_Half,0),) + @test Enzyme.get_offsets(Float32) == ((Enzyme.API.DT_Float,0),) + @test Enzyme.get_offsets(Float64) == ((Enzyme.API.DT_Double,0),) + @test Enzyme.get_offsets(Int) == ((Enzyme.API.DT_Integer,0),) + @test Enzyme.get_offsets(Char) == ((Enzyme.API.DT_Integer,0),) + @test Enzyme.get_offsets(Ptr) == ((Enzyme.API.DT_Pointer,0),) + @test Enzyme.get_offsets(Ptr{Char}) == ((Enzyme.API.DT_Pointer,0),) + @test Enzyme.get_offsets(Ptr{Float32}) == ((Enzyme.API.DT_Pointer,0),) + @test Enzyme.get_offsets(Vector{Float32}) == ((Enzyme.API.DT_Pointer,0),) + @test 
Enzyme.get_offsets(Tuple{Float64, Int}) == [(Enzyme.API.DT_Double,0),(Enzyme.API.DT_Integer, 8)] +end From 63921794388c1ed36343e0d0ef5c4e7ddb03f0dd Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 24 Sep 2024 11:35:55 -0500 Subject: [PATCH 48/87] Update Project.toml (#1885) --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 5a0e192de5f..3c93057f906 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.3" +version = "0.13.2" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From fe41647c22a1e939e7d5ef1ba1d9470833c420a3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 24 Sep 2024 13:32:46 -0500 Subject: [PATCH 49/87] fix array (#1884) * fix array * fix * Update absint.jl * fix * fix * Fix flake * sym --- src/absint.jl | 98 ++++++++++++++++++++++++++---------------- test/internal_rules.jl | 2 +- test/runtests.jl | 29 +++++++++++++ 3 files changed, 92 insertions(+), 37 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index 585b1625a39..03cf53cf4e9 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -158,7 +158,7 @@ function absint(arg::LLVM.Value, partial::Bool = false) end function actual_size(@nospecialize(typ2)) - if typ2 <: Array || typ2 <: AbstractString + if typ2 <: Array || typ2 <: AbstractString || typ2 <: Symbol return sizeof(Int) elseif Base.isconcretetype(typ2) return sizeof(typ2) @@ -359,52 +359,78 @@ function abs_typeof( end if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF dl = LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(arg)))) - if offset === nothing - byref = GPUCompiler.BITS_VALUE - legal = true - typ2 = typ - while actual_size(typ2) != sizeof(dl, value_type(arg)) - if fieldcount(typ2) > 0 - typ2 = fieldtype(typ, 1) - if !Base.allocatedinline(typ2) - if byref != GPUCompiler.BITS_VALUE - legal = false - break 
+ function should_recurse(typ2, arg_t) + if actual_size(typ2) != sizeof(dl, arg_t) + return true + else + if Base.isconcretetype(typ2) + if fieldcount(typ2) > 0 + if actual_size(fieldtype(typ2,1)) == actual_size(fieldtype(typ2, 1)) + return true end - byref = GPUCompiler.MUT_REF - continue end end - legal = false - break - end - if legal - return (true, typ2, byref) + return false end - else + end + + byref = GPUCompiler.BITS_VALUE + legal = true + + while offset !== nothing && legal @assert Base.isconcretetype(typ) + seen = false + lasti = 1 for i = 1:fieldcount(typ) + fo = fieldoffset(typ, i) if fieldoffset(typ, i) == offset - subT = fieldtype(typ, i) - fsize = if i == fieldcount(typ) - sizeof(typ) - else - fieldoffset(typ, i + 1) - end - offset - if fsize == sizeof(dl, value_type(arg)) - if Base.isconcretetype(subT) && - is_concrete_tuple(subT) && - length(subT.parameters) == 1 - subT = subT.parameters[1] - end - if Base.allocatedinline(subT) - return (true, subT, GPUCompiler.BITS_VALUE) - else - return (true, subT, GPUCompiler.MUT_REF) + offset = nothing + typ = fieldtype(typ, i) + if !Base.allocatedinline(typ) + if byref != GPUCompiler.BITS_VALUE + legal = false end + byref = GPUCompiler.MUT_REF + end + seen = true + break + elseif fieldoffset(typ, i) > offset + offset = offset - fieldoffset(typ, lasti) + typ = fieldtype(typ, lasti) + if !Base.allocatedinline(typ) + legal = false + end + seen = true + break + end + + if fo != 0 && fo != fieldoffset(typ, i-1) + lasti = i + end + end + if !seen + legal = false + end + end + + typ2 = typ + while should_recurse(typ2, value_type(arg)) + if fieldcount(typ2) > 0 + typ2 = fieldtype(typ2, 1) + if !Base.allocatedinline(typ2) + if byref != GPUCompiler.BITS_VALUE + legal = false + break end + byref = GPUCompiler.MUT_REF + continue end end + legal = false + break + end + if legal + return (true, typ2, byref) end end elseif legal && if typ <: Ptr && Base.isconcretetype(typ) diff --git a/test/internal_rules.jl 
b/test/internal_rules.jl index 32a206c62e6..0d5bbdae017 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -591,7 +591,7 @@ end TM in (Const, Duplicated, BatchDuplicated), TB in (Const, Duplicated, BatchDuplicated) are_activities_compatible(Const, TY, TM, TB) || continue - test_reverse(f!, TY, (Y, TY), (M, TM), (B, TB), (_A, Const)) + test_reverse(f!, TY, (Y, TY), (M, TM), (B, TB), (_A, Const); atol = 1.0e-5, rtol = 1.0e-5) end end @testset "test through `Adjoint` wrapper (regression test for #1306)" begin diff --git a/test/runtests.jl b/test/runtests.jl index 92bfa475136..69e6d51cd5e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3959,6 +3959,35 @@ function harmonic_f!(inter_list, coords, inters) return si end +function invwsumsq(w::AbstractVector, a::AbstractVector) + s = zero(zero(eltype(a)) / zero(eltype(w))) + for i in eachindex(w) + s += abs2(a[i]) / w[i] + end + return s +end + +_logpdf(d, x) = invwsumsq(d.Σ.diag, x .- d.μ) + +function demo_func(x::Any=transpose([1.5 2.0;]);) + m = [-0.30725218207431315, 0.5492115788562757] + d = (; Σ = LinearAlgebra.Diagonal([1.0, 1.0]), μ = m) + logp = _logpdf(d, reshape(x, (2,))) + return logp +end + +demof(x) = demo_func() + +@testset "Type checks" begin + x = [0.0, 0.0] + Enzyme.autodiff( + Enzyme.Reverse, + Enzyme.Const(demof), + Enzyme.Active, + Enzyme.Duplicated(x, zero(x)), + ) +end + @testset "Decay preservation" begin inters = [HarmonicAngle(1.0, 0.1), HarmonicAngle(2.0, 0.3)] inter_list = [1, 3] From 9f6663311d18fa72b3de7e68eca3287e0aa31cc3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 25 Sep 2024 12:15:24 -0500 Subject: [PATCH 50/87] Fix abs cstring (#1888) --- Project.toml | 2 +- src/absint.jl | 54 +++++++++++++++++++++++++-------------------------- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/Project.toml b/Project.toml index 3c93057f906..5a0e192de5f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = 
"7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.2" +version = "0.13.3" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" diff --git a/src/absint.jl b/src/absint.jl index 03cf53cf4e9..0739b7ded5a 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -489,30 +489,30 @@ function abs_typeof( end return (false, nothing, nothing) end -# -# function abs_cstring(arg::LLVM.Value)::Tuple{Bool,String} -# if isa(arg, ConstantExpr) -# ce = arg -# while isa(ce, ConstantExpr) -# if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || opcode(ce) == LLVM.API.LLVMBitCast || opcode(ce) == LLVM.API.LLVMIntToPtr -# ce = operands(ce)[1] -# elseif opcode(ce) == LLVM.API.LLVMGetElementPtr -# if all(x -> isa(x, LLVM.ConstantInt) && convert(UInt, x) == 0, operands(ce)[2:end]) -# ce = operands(ce)[1] -# else -# break -# end -# else -# break -# end -# end -# if isa(ce, LLVM.GlobalVariable) -# ce = LLVM.initializer(ce) -# if (isa(ce, LLVM.ConstantArray) || isa(ce, LLVM.ConstantDataArray)) && eltype(value_type(ce)) == LLVM.IntType(8) -# return (true, String(map((x)->convert(UInt8, x), collect(ce)[1:(end-1)]))) -# end -# -# end -# end -# return (false, "") -# end + +function abs_cstring(arg::LLVM.Value)::Tuple{Bool,String} + if isa(arg, ConstantExpr) + ce = arg + while isa(ce, ConstantExpr) + if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || opcode(ce) == LLVM.API.LLVMBitCast || opcode(ce) == LLVM.API.LLVMIntToPtr + ce = operands(ce)[1] + elseif opcode(ce) == LLVM.API.LLVMGetElementPtr + if all(x -> isa(x, LLVM.ConstantInt) && convert(UInt, x) == 0, operands(ce)[2:end]) + ce = operands(ce)[1] + else + break + end + else + break + end + end + if isa(ce, LLVM.GlobalVariable) + ce = LLVM.initializer(ce) + if (isa(ce, LLVM.ConstantArray) || isa(ce, LLVM.ConstantDataArray)) && eltype(value_type(ce)) == LLVM.IntType(8) + return (true, String(map((x)->convert(UInt8, x), collect(ce)[1:(end-1)]))) + end + + end + end + return (false, "") +end From 
519b693c7c8530414352c6a5a3e7f6c5c7fa4be1 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 25 Sep 2024 12:15:42 -0500 Subject: [PATCH 51/87] recfix (#1886) --- src/absint.jl | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index 0739b7ded5a..27f66d2c069 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -359,18 +359,24 @@ function abs_typeof( end if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF dl = LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(arg)))) - function should_recurse(typ2, arg_t) - if actual_size(typ2) != sizeof(dl, arg_t) - return true + function should_recurse(typ2, arg_t, byref) + sz = sizeof(dl, arg_t) + if byref != GPUCompiler.BITS_VALUE + @assert sz == sizeof(Int) + return false else - if Base.isconcretetype(typ2) - if fieldcount(typ2) > 0 - if actual_size(fieldtype(typ2,1)) == actual_size(fieldtype(typ2, 1)) - return true + if actual_size(typ2) != sz + return true + else + if Base.isconcretetype(typ2) + if fieldcount(typ2) > 0 + if actual_size(fieldtype(typ2,1)) == sz + return true + end end end + return false end - return false end end @@ -414,7 +420,7 @@ function abs_typeof( end typ2 = typ - while should_recurse(typ2, value_type(arg)) + while should_recurse(typ2, value_type(arg), byref) if fieldcount(typ2) > 0 typ2 = fieldtype(typ2, 1) if !Base.allocatedinline(typ2) From a72efb1ba163158c199a4811d4839b2640e4dd91 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 25 Sep 2024 22:26:13 -0500 Subject: [PATCH 52/87] Concrete type assertion (#1890) * Concrete type assertion * fix * fix * fix --- src/absint.jl | 72 ++++++++++++++++++++++++++++++------------------- src/compiler.jl | 4 ++- 2 files changed, 48 insertions(+), 28 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index 27f66d2c069..9eec24bcf34 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -167,6 +167,43 @@ function actual_size(@nospecialize(typ2)) end end +@inline function 
first_non_ghost(@nospecialize(typ2)) + fc = fieldcount(typ2) + for i in 1:fc + if i == fc + return (i, sizeof(typ2)) + else + fo = fieldoffset(typ2, i+1) + if fo != 0 + return (i, fo) + end + end + end + return (-1, 0) +end + +function should_recurse(@nospecialize(typ2), arg_t, byref, dl) + sz = sizeof(dl, arg_t) + if byref != GPUCompiler.BITS_VALUE + @assert sz == sizeof(Int) + return false + else + if actual_size(typ2) != sz + return true + else + if Base.isconcretetype(typ2) + idx, sz2 = first_non_ghost(typ2) + if idx != -1 + if sz2 == sz + return true + end + end + end + return false + end + end +end + function abs_typeof( arg::LLVM.Value, partial::Bool = false, @@ -346,7 +383,7 @@ function abs_typeof( if !error legal, typ, byref = abs_typeof(larg) - if legal && (byref == GPUCompiler.MUT_REF || byref == GPUCompiler.BITS_REF) + if legal && (byref == GPUCompiler.MUT_REF || byref == GPUCompiler.BITS_REF) && Base.isconcretetype(typ) @static if VERSION < v"1.11-" if typ <: Array && Base.isconcretetype(typ) T = eltype(typ) @@ -359,31 +396,11 @@ function abs_typeof( end if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF dl = LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(arg)))) - function should_recurse(typ2, arg_t, byref) - sz = sizeof(dl, arg_t) - if byref != GPUCompiler.BITS_VALUE - @assert sz == sizeof(Int) - return false - else - if actual_size(typ2) != sz - return true - else - if Base.isconcretetype(typ2) - if fieldcount(typ2) > 0 - if actual_size(fieldtype(typ2,1)) == sz - return true - end - end - end - return false - end - end - end byref = GPUCompiler.BITS_VALUE legal = true - while offset !== nothing && legal + while (offset !== nothing && offset != 0) && legal @assert Base.isconcretetype(typ) seen = false lasti = 1 @@ -403,6 +420,7 @@ function abs_typeof( elseif fieldoffset(typ, i) > offset offset = offset - fieldoffset(typ, lasti) typ = fieldtype(typ, lasti) + @assert Base.isconcretetype(typ) if !Base.allocatedinline(typ) legal = 
false end @@ -420,9 +438,10 @@ function abs_typeof( end typ2 = typ - while should_recurse(typ2, value_type(arg), byref) - if fieldcount(typ2) > 0 - typ2 = fieldtype(typ2, 1) + while should_recurse(typ2, value_type(arg), byref, dl) + idx, _ = first_non_ghost(typ2) + if idx != -1 + typ2 = fieldtype(typ2, idx) if !Base.allocatedinline(typ2) if byref != GPUCompiler.BITS_VALUE legal = false @@ -439,10 +458,9 @@ function abs_typeof( return (true, typ2, byref) end end - elseif legal && if typ <: Ptr && Base.isconcretetype(typ) + elseif legal && typ <: Ptr && Base.isconcretetype(typ) return (true, eltype(typ), GPUCompiler.BITS_VALUE) end - end end end diff --git a/src/compiler.jl b/src/compiler.jl index f3680cc4c05..32dd293ecea 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -7936,8 +7936,10 @@ function GPUCompiler.codegen( elseif byref == GPUCompiler.MUT_REF || byref == GPUCompiler.BITS_REF Ptr{source_typ} else - println(string(mod)) + # println(string(mod)) + println(string(f)) @show legal, source_typ, byref, llvm_source_typ, codegen_typ, string(inst) + @show enzyme_custom_extract_mi(f) @assert false end else From 31c60beca9b422adbd2f7d86e32802e65eaad31b Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Thu, 26 Sep 2024 10:56:35 +0200 Subject: [PATCH 53/87] remove spurious checkin --- test/ext/.chainrulescore.jl.swp | Bin 12288 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 test/ext/.chainrulescore.jl.swp diff --git a/test/ext/.chainrulescore.jl.swp b/test/ext/.chainrulescore.jl.swp deleted file mode 100644 index 94b31875e128106c0c92d4e28639ebc06d4e3451..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12288 zcmeI2&u<$=6vrpMP@ts<4i&dYe85{A{|b?4R99(3f{;oMsNqtoXl&2A3-+$NJ67x- zhy;H@4^W`M0d55$aYLe9;D)$!;m83YAt51gtoY8%I=gn_G;J?LdMkZ)*R$`vdGEV# zY(<%7_088_r?cg%!11&YU*B5xfBJjp*r^k9V!Ib5DlrcZKAhb`Pqvy-^iDsH>g70+ zy>bw06^F7r_I}qHRyu*Mtc5p5Jym|YThS`f6*bhA)-@S~@t~`cRur2@V?VyK0<6Fz 
zDlpLg!pkQ&7wa{<)C=>^)3eWPK4MUIX9ZXRR)7^?1y})AfE8c`Sb?LYfbP$U9VC1# zP5NAVotwI*ANj%xumY?AE5Hh{0;~WlzzVPetN<&(3a|o4PyyK%;^t#Qe0&_q6AN##eJ+nEJH$IoX27 zKI7{leT^9ov$Jr^kK%0^w~I*>>k?g#70$&W?uMQxtQk%PdjnGpQxH<|l|jE7w4_$; z5?z$KbuAmndsz>J+~bp(Z$ukPwKJ#OoD=z7l!+?4B&Q~H`L0G`q9;;5q`OPtL4#xk z?C&{i^33m`d1rY~Rh|zq^(d#rtYG`6l8qiB#MGotimAavgvsV~5wg#4&Nakjo1328 zJ7Pv>Rs1mB?fOcXu;wVxL-cgInRg}V^|8vfe6xv{^r){Qzqs4i9?^wxBgH4}SFesAS(_D8WP+tN-!$uw+Pf=qo|>wIMe3BZcI3AM-!G16;XL_^MH_bH znb00J8Vxf=7Bn>;rpEX#fPKz*DK$Q*Du2;4)mV_W4ZdpjFTY{-FRz;Y%PUDsc4fR} z#XwS+X9Nw3OJ|}`6^EE`%NCnc{5y`kS=&4X+8#d;le6@bvWj?b) zYv5O%1WUe?;J~Wt9<6sHsTb?_uc(e1^5vw(*kN0V4N;(^uUdg0v?vKWq11y|nL_p{ zrS)aj!ib- Op22j$L`rX6Xz?G6bcs*^ From 0f7b3557e0c1791105718ae5417e30aae72cb27d Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 26 Sep 2024 11:11:56 -0500 Subject: [PATCH 54/87] CustomRules: fix body check (#1896) --- src/rules/customrules.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 1985283da3e..cb6c60d98db 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -977,7 +977,7 @@ end !(aug_RT === Union{}) TapeT = EnzymeRules.tape_type(aug_RT) elseif (aug_RT isa UnionAll) && - (aug_RT <: EnzymeRules.AugmentedReturn) && + (aug_RT <: EnzymeRules.AugmentedReturn) && hasfield(typeof(aug_RT.body), :name) && aug_RT.body.name == EnzymeCore.EnzymeRules.AugmentedReturn.body.body.body.name if aug_RT.body.parameters[3] isa TypeVar TapeT = aug_RT.body.parameters[3].ub @@ -985,7 +985,7 @@ end TapeT = Any end elseif (aug_RT isa UnionAll) && - (aug_RT <: EnzymeRules.AugmentedReturnFlexShadow) && + (aug_RT <: EnzymeRules.AugmentedReturnFlexShadow) && hasfield(typeof(aug_RT.body), :name) && aug_RT.body.name == EnzymeCore.EnzymeRules.AugmentedReturnFlexShadow.body.body.body.name if aug_RT.body.parameters[3] isa TypeVar From 9495698d2d1fa163413ed6a3e113a29a02292ea4 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 26 Sep 2024 
14:32:17 -0500 Subject: [PATCH 55/87] Sparsearrays ext (#1891) * Sparsearrays ext * fix --- Project.toml | 1 + src/compiler.jl | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 5a0e192de5f..f2b99062a01 100644 --- a/Project.toml +++ b/Project.toml @@ -15,6 +15,7 @@ ObjectFile = "d8793406-e978-5875-9003-1fc021f44a92" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [weakdeps] BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" diff --git a/src/compiler.jl b/src/compiler.jl index 32dd293ecea..1417379a83e 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -38,7 +38,7 @@ import Enzyme_jll import GPUCompiler: CompilerJob, codegen, safe_name using LLVM.Interop import LLVM: Target, TargetMachine - +import SparseArrays using Printf using Preferences @@ -522,6 +522,7 @@ end @inline ptreltype(::Type{Tuple{Vararg{T}}}) where {T} = T @inline ptreltype(::Type{IdDict{K,V}}) where {K,V} = V @inline ptreltype(::Type{IdDict{K,V} where K}) where {V} = V +@inline ptreltype(::Type{SparseArrays.CHOLMOD.Dense{T}}) where T = T @inline is_arrayorvararg_ty(::Type) = false @inline is_arrayorvararg_ty(::Type{Array{T,N}}) where {T,N} = true @@ -533,6 +534,7 @@ end @inline is_arrayorvararg_ty(::Type{Base.RefValue{T}}) where {T} = true @inline is_arrayorvararg_ty(::Type{IdDict{K,V}}) where {K,V} = true @inline is_arrayorvararg_ty(::Type{IdDict{K,V} where K}) where {V} = true +@inline is_arrayorvararg_ty(::Type{SparseArrays.CHOLMOD.Dense{T}}) where T = true @inline function datatype_fieldcount(t::Type{T}) where {T} return Base.datatype_fieldcount(t) From f91eabb764d6d0e0d24b6e929a2aa0ffc86aec9b Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Thu, 26 Sep 2024 17:16:26 -0400 Subject: [PATCH 56/87] Add WithPrimal and NoPrimal function (#1898) * Add WithPrimal and NoPrimal function * version 
bumps --- Project.toml | 2 +- lib/EnzymeCore/Project.toml | 2 +- lib/EnzymeCore/src/EnzymeCore.jl | 23 +++++++++++++++++++++++ src/Enzyme.jl | 6 +++++- test/runtests.jl | 19 +++++++++++++++++++ 5 files changed, 49 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index f2b99062a01..4cffef367c8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.3" +version = "0.13.4" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index 37ddaf64576..3a871b930cc 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.8.2" +version = "0.8.3" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index f51c742f5d4..3231674de54 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -244,6 +244,21 @@ const ReverseHolomorphicWithPrimal = ReverseMode{true,false,DefaultABI, true, fa @inline set_runtime_activity(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, rt::Bool) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,rt,ABI,Holomorphic,ErrIfFuncWritten}() @inline clear_runtime_activity(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,false,ABI,Holomorphic,ErrIfFuncWritten}() +""" + WithPrimal(::Enzyme.Mode) + +Modifies the mode to include the primal value. 
+""" +@inline WithPrimal(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{true,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}() + +""" + NoPrimal(::Enzyme.Mode) + +Modifies the mode to exclude the primal value. +""" +@inline NoPrimal(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{false,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}() + + """ struct ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI} <: Mode{ABI,ErrIfFuncWritten,RuntimeActivity} @@ -267,6 +282,10 @@ const ReverseSplitWithPrimal = ReverseModeSplit{true, true, false, 0, true,Defau @inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}, rt::Bool) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,rt,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() @inline clear_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,false,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() +@inline WithPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{true,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() +@inline NoPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where 
{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{false,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() + + """ struct Forward{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity} @@ -286,6 +305,10 @@ const ForwardWithPrimal = ForwardMode{true, DefaultABI, false, false}() @inline set_runtime_activity(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}, rt::Bool) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,rt}() @inline clear_runtime_activity(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,false}() +@inline WithPrimal(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{true,ABI,ErrIfFuncWritten,RuntimeActivity}() +@inline NoPrimal(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{false,ABI,ErrIfFuncWritten,RuntimeActivity}() + + function autodiff end function autodiff_deferred end function autodiff_thunk end diff --git a/src/Enzyme.jl b/src/Enzyme.jl index c99114e038f..b49c3738f6b 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -46,7 +46,9 @@ import EnzymeCore: set_abi, set_runtime_activity, clear_runtime_activity, - within_autodiff + within_autodiff, + WithPrimal, + NoPrimal export Annotation, Const, Active, @@ -63,6 +65,8 @@ export Annotation, set_abi, set_runtime_activity, clear_runtime_activity, + WithPrimal, + NoPrimal, within_autodiff import EnzymeCore: BatchDuplicatedFunc diff --git a/test/runtests.jl b/test/runtests.jl index 69e6d51cd5e..d499febd77e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4066,6 +4066,25 @@ end @test res[2][6] ≈ 6.0 end 
+@testset "WithPrimal" begin + @test WithPrimal(Reverse) === ReverseWithPrimal + @test NoPrimal(Reverse) === Reverse + @test WithPrimal(ReverseWithPrimal) === ReverseWithPrimal + @test NoPrimal(ReverseWithPrimal) === Reverse + + @test WithPrimal(set_runtime_activity(Reverse)) === set_runtime_activity(ReverseWithPrimal) + + @test WithPrimal(Forward) === ForwardWithPrimal + @test NoPrimal(Forward) === Forward + @test WithPrimal(ForwardWithPrimal) === ForwardWithPrimal + @test NoPrimal(ForwardWithPrimal) === Forward + + @test WithPrimal(ReverseSplitNoPrimal) === ReverseSplitWithPrimal + @test NoPrimal(ReverseSplitNoPrimal) === ReverseSplitNoPrimal + @test WithPrimal(ReverseSplitWithPrimal) === ReverseSplitWithPrimal + @test NoPrimal(ReverseSplitWithPrimal) === ReverseSplitNoPrimal +end + # TEST EXTENSIONS using SpecialFunctions @testset "SpecialFunctions ext" begin From 3565c573d5b92330e33b3440623b9604ed7ebfbc Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 26 Sep 2024 16:45:35 -0500 Subject: [PATCH 57/87] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 4cffef367c8..fd0882e97ec 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,7 @@ EnzymeStaticArraysExt = "StaticArrays" BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" -EnzymeCore = "0.8" +EnzymeCore = "0.8.3" Enzyme_jll = "0.0.150" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8, 9" From 17a0c7f9b6dbea81d09e5266a032bb6bfed3b4f3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 26 Sep 2024 17:12:32 -0500 Subject: [PATCH 58/87] Skip japi1 activity rule (#1899) --- src/rules/activityrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules/activityrules.jl b/src/rules/activityrules.jl index 7a940259fab..f84b6befb8b 100644 --- a/src/rules/activityrules.jl +++ b/src/rules/activityrules.jl @@ -1,6 +1,6 @@ function 
julia_activity_rule(f::LLVM.Function) - if startswith(LLVM.name(f)) == "japi3" + if startswith(LLVM.name(f)) == "japi3" || startswith(LLVM.name(f)) == "japi1" return end mi, RT = enzyme_custom_extract_mi(f) From bbafecf3f3f5f05b1bc0f794652309ee28e2b108 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 27 Sep 2024 15:29:29 -0500 Subject: [PATCH 59/87] Support fillarray return (#1901) --- src/compiler.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 1417379a83e..36d899ac291 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -8778,6 +8778,8 @@ end function add_one_in_place(x) if x isa Base.RefValue x[] = recursive_add(x[], default_adjoint(eltype(Core.Typeof(x)))) + elseif x isa (Array{T,0} where T) + x[] = recursive_add(x[], default_adjoint(eltype(Core.Typeof(x)))) else error( "Enzyme Mutability Error: Cannot add one in place to immutable value " * From 23c2fde4bb3ed4d3465b72aac78d7052797943a6 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 27 Sep 2024 16:22:44 -0500 Subject: [PATCH 60/87] More info for dupnoneed (#1904) * More info for dupnoneed * Update lib/EnzymeCore/src/EnzymeCore.jl Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> * Update EnzymeCore.jl --------- Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> --- lib/EnzymeCore/src/EnzymeCore.jl | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 3231674de54..f76b3023021 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -79,7 +79,27 @@ end DuplicatedNoNeed(x, ∂f_∂x) Like [`Duplicated`](@ref), except also specifies that Enzyme may avoid computing -the original result and only compute the derivative values. +the original result and only compute the derivative values. This creates opportunities +for improved performance. 
+ +```julia + +function square_byref(out, v) + out[] = v * v + nothing +end + +out = Ref(0.0) +dout = Ref(1.0) +Enzyme.autodiff(Reverse, square_byref, DuplicatedNoNeed(out, dout), Active(1.0)) +dout[] + +# output +0.0 +``` + +For example, marking the out variable as `DuplicatedNoNeed` instead of `Duplicated` allows +Enzyme to avoid computing `v * v` (while still computing its derivative). This should only be used if `x` is a write-only variable. Otherwise, if the differentiated function stores values in `x` and reads them back in subsequent computations, using From d092d4ab20cf8a01f489bc2822f0e6e6538df549 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 27 Sep 2024 16:50:20 -0500 Subject: [PATCH 61/87] Abs int end of load (#1905) --- src/absint.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/absint.jl b/src/absint.jl index 9eec24bcf34..041b3bd1cc3 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -432,13 +432,22 @@ function abs_typeof( lasti = i end end + if !seen && fieldcount(typ) > 0 + offset = offset - fieldoffset(typ, lasti) + typ = fieldtype(typ, lasti) + @assert Base.isconcretetype(typ) + if !Base.allocatedinline(typ) + legal = false + end + seen = true + end if !seen legal = false end end typ2 = typ - while should_recurse(typ2, value_type(arg), byref, dl) + while legal && should_recurse(typ2, value_type(arg), byref, dl) idx, _ = first_non_ghost(typ2) if idx != -1 typ2 = fieldtype(typ2, idx) From 126127910baa58642f106f78b47dc1d9e05108f2 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Fri, 27 Sep 2024 18:02:48 -0400 Subject: [PATCH 62/87] Add reverse rule for Sparse dense matmul/vec (#1792) * Add sparse array internal rule * Add sparsearray extension for mul! * Add more testing * Add BatchDuplicated (still broken) * Remove BatchMode since it isn't applicable? 
* Add sparse array testing * Don't support batchmode for now * Revert project to old style * Add sparse array compat bound * reenable batch mode for bug hunting * Turn on BatchDuplicated stuff again * Remove Q comment * Incorporate BatchDuplicated into testing properly * Consider constant fp in runtime activity (#1797) * Consider constant fp in runtime activity * fix * Suggest workaround in error for overwritten active by ref (#1791) * Fix custom active reverse mode check (#1798) * Remove Q comment * Incorporate BatchDuplicated into testing properly * Look for more writebarrier opportunities (#1800) * Look for more writebarrier opportunities * Update compiler.jl * Restrict version to 1.10+ (#1809) * Restrict version to 1.10+ * fix * fixup * Update CI.yml * Update Project.toml * Update Project.toml * Update Project.toml * Fix MixedDuplicated ABI error on primalerror (#1815) * Update test * Move new SparseArrays Cholmod into extension * Make LinearAlgebra.mul! explicit * Make sparse arrays not an extension * Fix rules for 0.13 * Remove sparse arrays ext file * Update compiler --------- Co-authored-by: William Moses Co-authored-by: Daniel Wennberg --- Project.toml | 1 + src/Enzyme.jl | 1 + src/internal_rules.jl | 129 ++++++++++++++++++++++++++++++++++++++++- test/internal_rules.jl | 42 ++++++++++++++ test/runtests.jl | 2 +- 5 files changed, 173 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index fd0882e97ec..9623428c29a 100644 --- a/Project.toml +++ b/Project.toml @@ -42,6 +42,7 @@ LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" ObjectFile = "0.4" Preferences = "1.4" +SparseArrays = "1" SpecialFunctions = "1, 2" StaticArrays = "1" julia = "1.10" diff --git a/src/Enzyme.jl b/src/Enzyme.jl index b49c3738f6b..091021cb8bb 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -98,6 +98,7 @@ export jacobian, gradient, gradient!, hvp, hvp!, hvp_and_gradient!
export markType, batch_size, onehot, chunkedonehot using LinearAlgebra +import SparseArrays import EnzymeCore: ReverseMode, ReverseModeSplit, ForwardMode, Mode import EnzymeCore: EnzymeRules diff --git a/src/internal_rules.jl b/src/internal_rules.jl index f8c6e730bb0..b6d081d57d6 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -724,6 +724,133 @@ function EnzymeRules.reverse( return (nothing, nothing) end + +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, + func::Const{typeof(LinearAlgebra.mul!)}, + ::Type{RT}, + C::Annotation{<:StridedVecOrMat}, + A::Const{<:SparseArrays.SparseMatrixCSCUnion}, + B::Annotation{<:StridedVecOrMat}, + α::Annotation{<:Number}, + β::Annotation{<:Number} + ) where {RT} + + cache_C = !(isa(β, Const)) ? copy(C.val) : nothing + # Always need to do forward pass otherwise primal may not be correct + func.val(C.val, A.val, B.val, α.val, β.val) + + primal = if EnzymeRules.needs_primal(config) + C.val + else + nothing + end + + shadow = if EnzymeRules.needs_shadow(config) + C.dval + else + nothing + end + + # Check if A is overwritten and B is active (and thus required) + cache_A = ( EnzymeRules.overwritten(config)[5] + && !(typeof(B) <: Const) + && !(typeof(C) <: Const) + ) ? copy(A.val) : nothing + + # cache_B = ( EnzymeRules.overwritten(config)[6]) ? copy(B.val) : nothing + + if !isa(α, Const) + cache_α = A.val*B.val + else + cache_α = nothing + end + + cache = (cache_C, cache_A, cache_α) + + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse(config::EnzymeRules.RevConfig, + func::Const{typeof(LinearAlgebra.mul!)}, + ::Type{RT}, cache, + C::Annotation{<:StridedVecOrMat}, + A::Const{<:SparseArrays.SparseMatrixCSCUnion}, + B::Annotation{<:StridedVecOrMat}, + α::Annotation{<:Number}, + β::Annotation{<:Number} + ) where {RT} + + cache_C, cache_A, cache_α = cache + Cval = !isnothing(cache_C) ? cache_C : C.val + Aval = !isnothing(cache_A) ? 
cache_A : A.val + # Bval = !isnothing(cache_B) ? cache_B : B.val + + N = EnzymeRules.width(config) + if !isa(C, Const) + dCs = C.dval + dBs = isa(B, Const) ? dCs : B.dval + + dα = if !isa(α, Const) + if N == 1 + LinearAlgebra.dot(C.dval, cache_α) + else + ntuple(Val(N)) do i + Base.@_inline_meta + LinearAlgebra.dot(C.dval[i], cache_α) + end + end + else + nothing + end + + dβ = if !isa(β, Const) + if N == 1 + LinearAlgebra.dot(C.dval, Cval) + else + ntuple(Val(N)) do i + Base.@_inline_meta + LinearAlgebra.dot(C.dval[i], Cval) + end + end + else + nothing + end + + for i in 1:N + # This rule is incorrect since you need to project dA to have the same + # sparsity pattern as A. + # if !isa(A, Const) + # dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[b] + # #dA .+= α*dC*B' + # mul!(dA, dC, Bval', α.val, true) + # end + + if !isa(B, Const) + #dB .+= α*A'*dC + if N ==1 + func.val(dBs, Aval', dCs, α.val, true) + else + func.val(dBs[i], Aval', dCs[i], α.val, true) + end + end + + if N==1 + dCs .*= β.val + else + dCs[i] .*= β.val + end + end + end + + return (nothing, nothing, nothing, dα, dβ) +end + + + + + + + function EnzymeRules.forward( config::EnzymeRules.FwdConfig, ::Const{typeof(sort!)}, @@ -1269,4 +1396,4 @@ function EnzymeRules.reverse( smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, ) where {rngty<:Union{TaskLocalRNG,Xoshiro},FT<:Union{Float32,Float64}} return (nothing, nothing, nothing) -end +end \ No newline at end of file diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 0d5bbdae017..246929272b9 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -677,4 +677,46 @@ end # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f4(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((0.0,0.0)),) end +@testset "SparseArrays spmatvec reverse rule" begin + C = zeros(18) + M = sprand(18, 9, 0.1) + v = randn(9) + α = 2.0 + β = 1.0 + + for Tret in (Duplicated, BatchDuplicated), Tv in 
(Const, Duplicated, BatchDuplicated), + Tα in (Const, Active), Tβ in (Const, Active) + + are_activities_compatible(Tret, Tret, Tv, Tα, Tβ) || continue + test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (α, Tα), (β, Tβ)) + + end + + + for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false) + are_activities_compatible(Tret, Tret, Tv) || continue + test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const)) + end +end + +@testset "SparseArrays spmatmat reverse rule" begin + C = zeros(18, 11) + M = sprand(18, 9, 0.1) + v = randn(9, 11) + α = 2.0 + β = 1.0 + + for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), + Tα in (Const, Active), Tβ in (Const, Active) + + are_activities_compatible(Tret, Tv, Tα, Tβ) || continue + test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (α, Tα), (β, Tβ)) + end + + for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false) + are_activities_compatible(Tret, Tv) || continue + test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const)) + end +end + end # InternalRules diff --git a/test/runtests.jl b/test/runtests.jl index d499febd77e..bd1c7dd90dc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4101,4 +4101,4 @@ include("ext/logexpfunctions.jl") @testset "BFloat16s ext" begin include("ext/bfloat16s.jl") -end +end \ No newline at end of file From 8e4c50a174fccc9bba81f1a78a16acdf3042cdb3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 27 Sep 2024 18:18:23 -0500 Subject: [PATCH 63/87] Fix japi1 (#1907) --- src/rules/activityrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules/activityrules.jl b/src/rules/activityrules.jl index f84b6befb8b..13bacb06a52 100644 --- a/src/rules/activityrules.jl +++ 
b/src/rules/activityrules.jl @@ -1,6 +1,6 @@ function julia_activity_rule(f::LLVM.Function) - if startswith(LLVM.name(f)) == "japi3" || startswith(LLVM.name(f)) == "japi1" + if startswith(LLVM.name(f), "japi3") || startswith(LLVM.name(f), "japi1") return end mi, RT = enzyme_custom_extract_mi(f) From 5fe7d91d82a5e1c3465836ab09504a8fb1a6464b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 27 Sep 2024 20:18:20 -0500 Subject: [PATCH 64/87] CompatHelper: bump compat for Enzyme_jll to 0.0.151, (keep existing compat) (#1908) * CompatHelper: bump compat for Enzyme_jll to 0.0.151, (keep existing compat) * Add symv and bump jll --------- Co-authored-by: CompatHelper Julia Co-authored-by: William Moses --- Project.toml | 2 +- src/compiler.jl | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 9623428c29a..0a03500ac95 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8.3" -Enzyme_jll = "0.0.150" +Enzyme_jll = "0.0.151" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" diff --git a/src/compiler.jl b/src/compiler.jl index 36d899ac291..08dc5f05c95 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -7106,7 +7106,7 @@ function GPUCompiler.codegen( disableFallback = String[] ForwardModeDerivatives = - ("nrm2", "dot", "gemm", "gemv", "axpy", "copy", "scal", "symm", "syrk", "potrf") + ("nrm2", "dot", "gemm", "gemv", "axpy", "copy", "scal", "symv", "symm", "syrk", "potrf") ReverseModeDerivatives = ( "nrm2", "dot", @@ -7115,6 +7115,7 @@ function GPUCompiler.codegen( "axpy", "copy", "scal", + "symv", "symm", "trmv", "syrk", From 7c4a31aa28752138b4ded2c46833c2ad2c9ed63e Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 27 Sep 2024 20:18:47 -0500 Subject: [PATCH 65/87] Fix randn (#1906) * Fix randn * Update 
internal_rules.jl --- src/internal_rules.jl | 71 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 69 insertions(+), 2 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index b6d081d57d6..438fcc8ecbc 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -74,7 +74,7 @@ function EnzymeRules.inactive( ) return nothing end -function EnzymeRules.inactive(::typeof(Random.randn!), args...) +function EnzymeRules.inactive(::typeof(Random.randn!), ::Random.AbstractRNG, ::AbstractArray) return nothing end function EnzymeRules.inactive(::typeof(Random.default_rng), args...) @@ -1396,4 +1396,71 @@ function EnzymeRules.reverse( smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, ) where {rngty<:Union{TaskLocalRNG,Xoshiro},FT<:Union{Float32,Float64}} return (nothing, nothing, nothing) -end \ No newline at end of file +end + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + Ty::Const{typeof(Random.randn!)}, + RT::Type, + rng::Annotation{<:Random.AbstractRNG}, + dst::Annotation{<:AbstractArray}) + + Ty.val(rng.val, dst.val) + + if !(dst isa Const) + if EnzymeRules.width(config) == 1 + make_zero!(dst.dval) + else + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + make_zero!(dst.dval[i]) + nothing + end + end + end + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + dst + elseif EnzymeRules.needs_shadow(config) + dst.dval + elseif EnzymeRules.needs_primal(config) + dst.val + else + nothing + end +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfig, + Ty::Const{typeof(Random.randn!)}, + RT::Type, + rng::Annotation{<:Random.AbstractRNG}, + dst::Annotation{<:AbstractArray} +) + Ty.val(rng.val, dst.val) + if RT <: Duplicated || RT <: DuplicatedNoNeed + make_zero!(dst.dval) + dst.dval + elseif RT <: BatchDuplicated || RT <: BatchDuplicatedNoNeed + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + make_zero!(dst.dval[i]) + nothing 
+ end + end + return EnzymeRules.AugmentedReturn( + EnzymeRules.needs_primal(config) ? dst.val : nothing, + EnzymeRules.needs_shadow(config) ? dst.dval : nothing, + nothing, + ) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + Ty::Const{typeof(Random.randn!)}, + RT::Type, + tape, + rng::Annotation{<:Random.AbstractRNG}, + dst::Annotation{<:AbstractArray}) + return (nothing, nothing) +end From a5c6fee3e6470a17aea20342b9166feee1a0ffa3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 27 Sep 2024 20:53:43 -0500 Subject: [PATCH 66/87] Fix deferred any active return (#1909) * Fix deferred any active return * fix * fix --- Project.toml | 2 +- lib/EnzymeCore/Project.toml | 2 +- lib/EnzymeCore/src/EnzymeCore.jl | 24 +-- src/Enzyme.jl | 245 +++++++++++++++++++------------ src/internal_rules.jl | 4 + test/abi.jl | 14 ++ test/runtests.jl | 4 +- 7 files changed, 188 insertions(+), 107 deletions(-) diff --git a/Project.toml b/Project.toml index 0a03500ac95..3b8ddd20602 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,7 @@ EnzymeStaticArraysExt = "StaticArrays" BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" -EnzymeCore = "0.8.3" +EnzymeCore = "0.8.4" Enzyme_jll = "0.0.151" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8, 9" diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index 3a871b930cc..2e45d2c2f60 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.8.3" +version = "0.8.4" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index f76b3023021..a536a664aa5 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -289,21 +289,21 @@ Reverse mode differentiation. 
- `Width`: Batch Size (0 if to be automatically derived) - `ModifiedBetween`: Tuple of each argument's modified between state (true if to be automatically derived). """ -struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,RuntimeActivity,ModifiedBetween,ABI, ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten,RuntimeActivity} end -const ReverseSplitNoPrimal = ReverseModeSplit{false, true, false, 0, true,DefaultABI, false}() -const ReverseSplitWithPrimal = ReverseModeSplit{true, true, false, 0, true,DefaultABI, false}() -@inline ReverseSplitModified(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, Width, MBO, ABI, ErrIfFuncWritten}, ::Val{MB}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,MBO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,MB,ABI, ErrIfFuncWritten}() -@inline ReverseSplitWidth(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, WidthO, MB, ABI, ErrIfFuncWritten}, ::Val{Width}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,WidthO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,ABI, ErrIfFuncWritten}() +struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,RuntimeActivity,ModifiedBetween,ABI,Holomorphic,ErrIfFuncWritten,ShadowInit} <: Mode{ABI, ErrIfFuncWritten,RuntimeActivity} end +const ReverseSplitNoPrimal = ReverseModeSplit{false, true, false, 0, true,DefaultABI, false, false, false}() +const ReverseSplitWithPrimal = ReverseModeSplit{true, true, false, 0, true,DefaultABI, false, false, false}() +@inline ReverseSplitModified(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, Width, MBO, ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}, ::Val{MB}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,MBO,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,MB,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() +@inline 
ReverseSplitWidth(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, WidthO, MB, ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}, ::Val{Width}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,WidthO,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() -@inline set_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, true}() -@inline clear_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, false}() +@inline set_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, true, ShadowInit}() +@inline clear_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, false, ShadowInit}() -@inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where
{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,true,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() -@inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}, rt::Bool) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,rt,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() -@inline clear_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,false,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() +@inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,true,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() +@inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}, rt::Bool) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,rt,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() +@inline clear_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = 
ReverseModeSplit{ReturnPrimal,ReturnShadow,false,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() -@inline WithPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{true,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() -@inline NoPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{false,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() +@inline WithPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{true,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() +@inline NoPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{false,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() """ diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 091021cb8bb..aa018ea23b8 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -350,18 +350,13 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) throw(ErrorException("Cannot differentiate with a batch size of 0")) end - ModifiedBetween = Val(falses_from_args(Nargs + 1)) + ModifiedBetweenT = falses_from_args(Nargs + 1) + ModifiedBetween = Val(ModifiedBetweenT) tt = Tuple{map(T -> 
eltype(Core.Typeof(T)), args)...} FTy = Core.Typeof(f.val) - opt_mi = if RABI <: NonGenABI - Compiler.fspec(eltype(FA), tt′) - else - Val(codegen_world_age(FTy, tt)) - end - rt = if A isa UnionAll Compiler.primal_return_type(rmode, Val(codegen_world_age(FTy, tt)), FTy, tt) else @@ -370,20 +365,22 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) if A <: Active if (!allocatedinline(rt) || rt isa Union) && rt != Union{} - forward, adjoint = Enzyme.Compiler.thunk( - opt_mi, + forward, adjoint = autodiff_thunk( + ReverseModeSplit{ + ReturnPrimal, + #=ReturnShadow=#false, + RuntimeActivity, + width, + ModifiedBetweenT, + RABI, + Holomorphic, + ErrIfFuncWritten, + #=ShadowInit=#true + }(), FA, Duplicated{rt}, - tt′, - Val(API.DEM_ReverseModeGradient), - Val(width), - ModifiedBetween, - Val(ReturnPrimal), - Val(true), - RABI, - Val(ErrIfFuncWritten), - Val(RuntimeActivity), - ) #=ShadowInit=# + (tt′).parameters... + ) res = forward(f, args...) tape = res[1] if ReturnPrimal @@ -400,6 +397,12 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) throw(ErrorException("Duplicated Returns not yet handled")) end + opt_mi = if RABI <: NonGenABI + Compiler.fspec(eltype(FA), tt′) + else + Val(codegen_world_age(FTy, tt)) + end + if (A <: Active && rt <: Complex) && rt != Union{} if Holomorphic seen = IdDict() @@ -651,7 +654,7 @@ Same as [`autodiff`](@ref) but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. """ @inline function autodiff_deferred( - ::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, + rmode::ReverseMode{ReturnPrimal,RuntimeActivity,RABI,Holomorphic,ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation,Nargs}, @@ -660,7 +663,7 @@ code, as well as high-order differentiation. A<:Annotation, ReturnPrimal, Nargs, - ABI, + RABI<:ABI, Holomorphic, ErrIfFuncWritten, RuntimeActivity, @@ -672,27 +675,85 @@ code, as well as high-order differentiation. 
end tt = Tuple{map(T -> eltype(Core.Typeof(T)), args)...} - world = codegen_world_age(Core.Typeof(f.val), tt) + FTy = Core.Typeof(f.val) + world = codegen_world_age(FTy, tt) + + A2 = A if A isa UnionAll - rt = Core.Compiler.return_type(f.val, tt) - rt = A{rt} + rt = Compiler.primal_return_type(rmode, Val(world), FTy, tt) + A2 = A{rt} else @assert A isa DataType rt = A end - if eltype(rt) == Union{} + if rt == Union{} error("Return type inferred to be Union{}. Giving up.") end - ModifiedBetween = Val(falses_from_args(Nargs + 1)) + ModifiedBetweenT = falses_from_args(Nargs + 1) + ModifiedBetween = Val(ModifiedBetweenT) + + if A <: Active + if (!allocatedinline(rt) || rt isa Union) && rt != Union{} + rs = ReverseModeSplit{ + ReturnPrimal, + #=ReturnShadow=#false, + RuntimeActivity, + width, + ModifiedBetweenT, + RABI, + Holomorphic, + ErrIfFuncWritten, + #=ShadowInit=#true + }() + TapeType = tape_type(rs, FA, Duplicated{rt}, + (tt′).parameters...) + forward, adjoint = autodiff_deferred_thunk( + rs, + TapeType, + FA, + Duplicated{rt}, + (tt′).parameters... + ) + res = forward(f, args...) + tape = res[1] + if ReturnPrimal + return (adjoint(f, args..., tape)[1], res[2]) + else + return adjoint(f, args..., tape) + end + end + elseif A <: Duplicated || + A <: DuplicatedNoNeed || + A <: BatchDuplicated || + A <: BatchDuplicatedNoNeed || + A <: BatchDuplicatedFunc + throw(ErrorException("Duplicated Returns not yet handled")) + end + + if (A <: Active && rt <: Complex) && rt != Union{} + if Holomorphic + throw( + ErrorException( + "Reverse-mode Active Holomorphic is not yet implemented in deferred codegen", + ), + ) + end + + throw( + ErrorException( + "Reverse-mode Active Complex return is ambiguous and requires more information to specify the desired result. 
See https://enzyme.mit.edu/julia/stable/faq/#Complex-numbers for more details.", + ), + ) + end adjoint_ptr = Compiler.deferred_codegen( Val(world), FA, Val(tt′), - Val(rt), + Val(A), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, @@ -704,9 +765,9 @@ code, as well as high-order differentiation. ) #=ShadowInit=# thunk = - Compiler.CombinedAdjointThunk{Ptr{Cvoid},FA,rt,tt′,width,ReturnPrimal}(adjoint_ptr) - if rt <: Active - args = (args..., Compiler.default_adjoint(eltype(rt))) + Compiler.CombinedAdjointThunk{Ptr{Cvoid},FA,A2,tt′,width,ReturnPrimal}(adjoint_ptr) + if A <: Active + args = (args..., Compiler.default_adjoint(rt)) elseif A <: Duplicated || A <: DuplicatedNoNeed || A <: BatchDuplicated || @@ -723,7 +784,7 @@ Same as `autodiff(::ForwardMode, f, Activity, args...)` but uses deferred compil code, as well as high-order differentiation. """ @inline function autodiff_deferred( - ::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}, + ::ForwardMode{ReturnPrimal,RABI,ErrIfFuncWritten,RuntimeActivity}, f::FA, ::Type{A}, args::Vararg{Annotation,Nargs}, @@ -732,7 +793,7 @@ code, as well as high-order differentiation. 
FA<:Annotation, A<:Annotation, Nargs, - ABI, + RABI<:ABI, ErrIfFuncWritten, RuntimeActivity, } @@ -857,7 +918,9 @@ result, ∂v, ∂A Width, ModifiedBetweenT, RABI, + #=Holomorphic=#false, ErrIfFuncWritten, + ShadowInit }, ::Type{FA}, ::Type{A}, @@ -872,6 +935,7 @@ result, ∂v, ∂A RABI<:ABI, Nargs, ErrIfFuncWritten, + ShadowInit, RuntimeActivity, } width = if Width == 0 @@ -892,9 +956,6 @@ result, ∂v, ∂A tt = Tuple{map(eltype, args)...} - if !(A <: Const) - @assert ReturnShadow - end tt′ = Tuple{args...} opt_mi = if RABI <: NonGenABI Compiler.fspec(eltype(FA), tt′) @@ -910,7 +971,7 @@ result, ∂v, ∂A Val(width), ModifiedBetween, Val(ReturnPrimal), - Val(false), + Val(ShadowInit), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity), @@ -1055,7 +1116,9 @@ end Width, ModifiedBetweenT, RABI, + #=Holomorphic=#false, ErrIfFuncWritten, + ShadowInit, }, ::Type{FA}, ::Type{A}, @@ -1071,6 +1134,7 @@ end Nargs, ErrIfFuncWritten, RuntimeActivity, + ShadowInit, } width = if Width == 0 w = same_or_one(1, args...) @@ -1088,7 +1152,6 @@ end ModifiedBetween = Val(ModifiedBetweenT) end - @assert ReturnShadow TT = Tuple{args...} primal_tt = Tuple{map(eltype, args)...} @@ -1106,7 +1169,7 @@ end Val(width), ModifiedBetween, Val(ReturnPrimal), - Val(false), + Val(ShadowInit), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity), @@ -1134,6 +1197,9 @@ import .Compiler: fspec, remove_innerty, UnknownTapeType Width, ModifiedBetweenT, RABI, + #=Holomorphic=#false, + #=ErrIfFuncWritten=#false, + #=ShadowInit=#false, }, ::Type{FA}, ::Type{A}, @@ -1215,7 +1281,7 @@ import .Compiler: fspec, remove_innerty, UnknownTapeType end """ - autodiff_deferred_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Type{<:Annotation}...) + autodiff_deferred_thunk(::ReverseModeSplit, TapeType::Type, ftype::Type{<:Annotation}, Activity::Type{<:Annotation}, argtypes::Type{<:Annotation}...) 
Provide the split forward and reverse pass functions for annotated function type ftype when called with args of type `argtypes` when using reverse mode. @@ -1266,7 +1332,9 @@ result, ∂v, ∂A Width, ModifiedBetweenT, RABI, + #=Holomorphic=#false, ErrIfFuncWritten, + ShadowInit, }, tt::Type{TapeType}, fa::Type{FA}, @@ -1284,6 +1352,7 @@ result, ∂v, ∂A Nargs, ErrIfFuncWritten, RuntimeActivity, + ShadowInit } @assert RABI == FFIABI width = if Width == 0 @@ -1302,7 +1371,6 @@ result, ∂v, ∂A ModifiedBetween = Val(ModifiedBetweenT) end - @assert ReturnShadow TT = Tuple{args...} primal_tt = Tuple{map(eltype, args)...} @@ -1317,7 +1385,7 @@ result, ∂v, ∂A Val(width), ModifiedBetween, Val(ReturnPrimal), - Val(false), + Val(ShadowInit), TapeType, Val(ErrIfFuncWritten), Val(RuntimeActivity), @@ -2054,7 +2122,6 @@ this function will retun an AbstractArray of shape `size(output)` of values of t jac end else - @assert !Holomorphic n_out_val = if length(Compiler.element(n_outs)) == 0 0 else @@ -2074,32 +2141,27 @@ this function will retun an AbstractArray of shape `size(output)` of values of t Core.Compiler.return_type(f, tt) end - ModifiedBetween = Val((false, false)) + ModifiedBetweenT = (false, false) FRT = Core.Typeof(f) FA = Const{FRT} - opt_mi = if RABI <: NonGenABI - Compiler.fspec(FRT, tt′) - else - Val(codegen_world_age(FRT, tt)) - end - if chunk == Val(1) || chunk == nothing - tt′ = MD ? Tuple{MixedDuplicated{XT}} : Tuple{Duplicated{XT}} - primal, adjoint = Enzyme.Compiler.thunk( - opt_mi, + primal, adjoint = autodiff_thunk( + ReverseModeSplit{ + #=ReturnPrimal=#false, + #=ReturnShadow=#true, + RuntimeActivity, + #=width=#1, + ModifiedBetweenT, + RABI, + Holomorphic, + ErrIfFuncWritten, + #=ShadowInit=#false + }(), FA, DuplicatedNoNeed{rt}, - tt′, - Val(API.DEM_ReverseModeGradient), - Val(1), - ModifiedBetween, - Val(false), - Val(false), - RABI, - Val(ErrIfFuncWritten), - Val(RuntimeActivity), - ) #=ShadowInit=# + MD ? 
MixedDuplicated{XT} : Duplicated{XT} + ) tmp = ntuple(Val(n_out_val)) do i Base.@_inline_meta z = make_zero(x) @@ -2115,23 +2177,22 @@ this function will retun an AbstractArray of shape `size(output)` of values of t rows, outshape else chunksize = Compiler.element(chunk) - tt′ = - MD ? Tuple{BatchMixedDuplicated{XT,chunksize}} : - Tuple{BatchDuplicated{XT,chunksize}} - primal, adjoint = Enzyme.Compiler.thunk( - opt_mi, + primal, adjoint = autodiff_thunk( + ReverseModeSplit{ + #=ReturnPrimal=#false, + #=ReturnShadow=#true, + RuntimeActivity, + chunksize, + ModifiedBetweenT, + RABI, + Holomorphic, + ErrIfFuncWritten, + #=ShadowInit=#false + }(), FA, - BatchDuplicatedNoNeed{rt}, - tt′, - Val(API.DEM_ReverseModeGradient), - chunk, - ModifiedBetween, - Val(false), - Val(false), - RABI, - Val(ErrIfFuncWritten), - Val(RuntimeActivity), - ) #=ShadowInit=# + BatchDuplicatedNoNeed{rt, chunksize}, + MD ? BatchMixedDuplicated{XT, chunksize} : BatchDuplicated{XT, chunksize} + ) num = ((n_out_val + chunksize - 1) ÷ chunksize) @@ -2141,20 +2202,22 @@ this function will retun an AbstractArray of shape `size(output)` of values of t else last_size = n_out_val - (num - 1) * chunksize tt′ = Tuple{BatchDuplicated{Core.Typeof(x),last_size}} - primal2, adjoint2 = Enzyme.Compiler.thunk( - opt_mi, + primal2, adjoint2 = autodiff_thunk( + ReverseModeSplit{ + #=ReturnPrimal=#false, + #=ReturnShadow=#true, + RuntimeActivity, + last_size, + ModifiedBetweenT, + RABI, + Holomorphic, + ErrIfFuncWritten, + #=ShadowInit=#false + }(), FA, - BatchDuplicatedNoNeed{rt}, - tt′, - Val(API.DEM_ReverseModeGradient), - Val(last_size), - ModifiedBetween, - Val(false), - Val(false), - RABI, - Val(ErrIfFuncWritten), - Val(RuntimeActivity), - ) #=ShadowInit=# + BatchDuplicatedNoNeed{rt, last_size}, + MD ? 
BatchMixedDuplicated{XT, last_size} : BatchDuplicated{XT, last_size} + ) end tmp = ntuple(num) do i diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 438fcc8ecbc..53ca1f92833 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -351,6 +351,8 @@ function EnzymeRules.augmented_primal( EnzymeRules.overwritten(config)[2:end], InlineABI, false, + false, + false }() fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) @@ -405,6 +407,8 @@ function EnzymeRules.reverse( EnzymeRules.overwritten(config)[2:end], InlineABI, false, + false, + false }() fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) diff --git a/test/abi.jl b/test/abi.jl index cbd467c1555..5acb30e04fb 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -300,6 +300,20 @@ using Test # returns: sret, const/ghost, !deserve_retbox end +unstable_load(x) = Base.inferencebarrier(x)[1] + +@testset "Any Return" begin + x = [2.7] + dx = [0.0] + Enzyme.autodiff(Reverse, Const(unstable_load), Active, Duplicated(x, dx)) + @test dx ≈ [1.0] + + x = [2.7] + dx = [0.0] + Enzyme.autodiff_deferred(Reverse, Const(unstable_load), Active, Duplicated(x, dx)) + @test dx ≈ [1.0] +end + @testset "Mutable Struct ABI" begin mutable struct MStruct val::Float32 diff --git a/test/runtests.jl b/test/runtests.jl index bd1c7dd90dc..8b496b071f0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -513,8 +513,8 @@ end mul3(z) = Base.inferencebarrier(2 * z) - @test_throws ErrorException autodiff(ReverseHolomorphic, mul3, Active, Active(z)) - @test_throws ErrorException autodiff(ReverseHolomorphic, mul3, Active{Complex}, Active(z)) + @test_throws MethodError autodiff(ReverseHolomorphic, mul3, Active, Active(z)) + @test_throws MethodError autodiff(ReverseHolomorphic, mul3, Active{Complex}, Active(z)) vals = Complex{Float64}[3.4 + 2.7im] dvals = Complex{Float64}[0.0] From 327558b3c31ac8714f5468409d8752b0eec1df0e Mon Sep 17 00:00:00 
2001 From: William Moses Date: Fri, 27 Sep 2024 20:53:59 -0500 Subject: [PATCH 67/87] Handle type unstable getglobal (#1910) --- src/rules/jitrules.jl | 105 +++++++++++++++++++++++------------------- 1 file changed, 57 insertions(+), 48 deletions(-) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 75bc4156544..28ecb7afea3 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -456,26 +456,32 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) else annotation0 end - world = codegen_world_age(FT, tt) - opt_mi = Val(world) - forward, adjoint = thunk( - opt_mi, - dupClosure0 ? $dupty : Const{FT}, - annotationA, - Tuple{$(Types...)}, - Val(API.DEM_ReverseModePrimal), - width, - ModifiedBetween, - Val(true), - Val(false), - FFIABI, - Val(false), - runtimeActivity, - ) #=erriffuncwritten=# + internal_tape, origRet, initShadow, annotation = if f isa typeof(Core.getglobal) + gv = Core.getglobal(args[1].val, args[2].val) + @assert sizeof(gv) == 0 + (nothing, f, nothing, Const) + else + world = codegen_world_age(FT, tt) - internal_tape, origRet, initShadow = forward(dupClosure0 ? $dup : Const(f), args...) - annotation = annotationA + opt_mi = Val(world) + forward, adjoint = thunk( + opt_mi, + dupClosure0 ? $dupty : Const{FT}, + annotationA, + Tuple{$(Types...)}, + Val(API.DEM_ReverseModePrimal), + width, + ModifiedBetween, + Val(true), + Val(false), + FFIABI, + Val(false), + runtimeActivity, + ) #=erriffuncwritten=# + + (forward(dupClosure0 ? $dup : Const(f), args...)..., annotationA) + end resT = typeof(origRet) if annotation <: Const @@ -649,39 +655,42 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act annotation0 end - world = codegen_world_age(FT, tt) + if f isa typeof(Core.getglobal) + else + world = codegen_world_age(FT, tt) - opt_mi = Val(world) - _, adjoint = thunk( - opt_mi, - dupClosure0 ? 
$dupty : Const{FT}, - annotation, - Tuple{$(Types...)}, - Val(API.DEM_ReverseModePrimal), - width, - ModifiedBetween, - Val(true), - Val(false), - FFIABI, - Val(false), - runtimeActivity, - ) #=erriffuncwritten=# + opt_mi = Val(world) + _, adjoint = thunk( + opt_mi, + dupClosure0 ? $dupty : Const{FT}, + annotation, + Tuple{$(Types...)}, + Val(API.DEM_ReverseModePrimal), + width, + ModifiedBetween, + Val(true), + Val(false), + FFIABI, + Val(false), + runtimeActivity, + ) #=erriffuncwritten=# - tup = - if annotation0 <: Active || - annotation0 <: MixedDuplicated || - annotation0 <: BatchMixedDuplicated - adjoint( - dupClosure0 ? $dup : Const(f), - args..., - $shadowret, - tape.internal_tape, - )[1] - else - adjoint(dupClosure0 ? $dup : Const(f), args..., tape.internal_tape)[1] - end + tup = + if annotation0 <: Active || + annotation0 <: MixedDuplicated || + annotation0 <: BatchMixedDuplicated + adjoint( + dupClosure0 ? $dup : Const(f), + args..., + $shadowret, + tape.internal_tape, + )[1] + else + adjoint(dupClosure0 ? $dup : Const(f), args..., tape.internal_tape)[1] + end - $(outs...) + $(outs...) 
+ end return nothing end From 259f16132157709f9a94020854439975db92b953 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 27 Sep 2024 22:08:28 -0500 Subject: [PATCH 68/87] Optimize active only rev grad (#1911) * Optimize active only rev grad * Update Project.toml * add makezero s/marray --- ext/EnzymeStaticArraysExt.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index bcaa3ec6cbb..b751c336a2a 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -23,4 +23,11 @@ end end end +@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:SArray} + return Base.zero(x) +end +@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:MArray} + return Base.zero(x) +end + end From 4ff5e44de2c3403818dd6f5c2b10d66bbc6e359d Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 27 Sep 2024 23:00:42 -0500 Subject: [PATCH 69/87] Fix getglobal value (#1912) --- src/rules/jitrules.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 28ecb7afea3..f169ab2c4bf 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -458,9 +458,9 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) end internal_tape, origRet, initShadow, annotation = if f isa typeof(Core.getglobal) - gv = Core.getglobal(args[1].val, args[2].val) + gv = Core.getglobal(map(x->x.val, args)...) 
@assert sizeof(gv) == 0 - (nothing, f, nothing, Const) + (nothing, gv, nothing, Const) else world = codegen_world_age(FT, tt) From edd00954a7e274064bcd9097485bcdf19f3624cc Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 28 Sep 2024 18:26:41 +0200 Subject: [PATCH 70/87] Improve documentation of modes (#1895) * Improve documentation of modes * Alignment * Add comment on setter functions * List helper functions * More cases for set_runtime_activity * Merge remote-tracking branch 'upstream/main' into gd/modes_doc * Cleaner diff * Smaller diff --- lib/EnzymeCore/src/EnzymeCore.jl | 251 +++++++++++++++++++++++++++---- 1 file changed, 220 insertions(+), 31 deletions(-) diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index a536a664aa5..394cd00a5fc 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -209,50 +209,123 @@ end abstract type ABI Abstract type for what ABI will be used. + +# Subtypes + +- [`FFIABI`](@ref) (the default) +- [`InlineABI`](@ref) +- [`NonGenABI`](@ref) """ abstract type ABI end """ struct FFIABI <: ABI -Foreign function call ABI. JIT the differentiated function, then inttoptr call the address. +Foreign function call [`ABI`](@ref). JIT the differentiated function, then inttoptr call the address. """ struct FFIABI <: ABI end + """ struct InlineABI <: ABI -Inlining function call ABI. +Inlining function call [`ABI`](@ref). """ struct InlineABI <: ABI end + """ struct NonGenABI <: ABI -Non-generated function ABI. +Non-generated function [`ABI`](@ref). """ struct NonGenABI <: ABI end + const DefaultABI = FFIABI """ - abstract type Mode + abstract type Mode{ABI,ErrIfFuncWritten,RuntimeActivity} + +Abstract type for which differentiation mode will be used. -Abstract type for what differentiation mode will be used. 
+# Subtypes + +- [`ForwardMode`](@ref) +- [`ReverseMode`](@ref) +- [`ReverseModeSplit`](@ref) + +# Type parameters + +- `ABI`: what runtime [`ABI`](@ref) to use +- `ErrIfFuncWritten`: whether to error when the function differentiated is a closure and written to. +- `RuntimeActivity`: whether to enable runtime activity (default off) + +!!! warning + The type parameters of `Mode` are not part of the public API and can change without notice. + You can modify them with the following helper functions: + - [`WithPrimal`](@ref) / [`NoPrimal`](@ref) + - [`set_err_if_func_written`](@ref) / [`clear_err_if_func_written`](@ref) + - [`set_runtime_activity`](@ref) / [`clear_runtime_activity`](@ref) + - [`set_abi`](@ref) """ abstract type Mode{ABI, ErrIfFuncWritten, RuntimeActivity} end """ - struct ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity} + struct ReverseMode{ + ReturnPrimal, + RuntimeActivity, + ABI, + Holomorphic, + ErrIfFuncWritten + } <: Mode{ABI,ErrIfFuncWritten,RuntimeActivity} + +Subtype of [`Mode`](@ref) for reverse mode differentiation. -Reverse mode differentiation. -- `ReturnPrimal`: Should Enzyme return the primal return value from the augmented-forward. -- `RuntimeActivity`: Should Enzyme enable runtime activity (default off) -- `ABI`: What runtime ABI to use -- `Holomorphic`: Whether the complex result function is holomorphic and we should compute d/dz -- `ErrIfFuncWritten`: Should Enzyme err if the function differentiated is a closure and written to. +# Type parameters + +- `ReturnPrimal`: whether to return the primal return value from the augmented-forward pass. +- `Holomorphic`: Whether the complex result function is holomorphic and we should compute `d/dz` +- other parameters: see [`Mode`](@ref) + +!!! warning + The type parameters of `ReverseMode` are not part of the public API and can change without notice. 
+ Please use one of the following concrete instantiations instead: + - [`Reverse`](@ref) + - [`ReverseWithPrimal`](@ref) + - [`ReverseHolomorphic`](@ref) + - [`ReverseHolomorphicWithPrimal`](@ref) + You can modify them with the following helper functions: + - [`WithPrimal`](@ref) / [`NoPrimal`](@ref) + - [`set_err_if_func_written`](@ref) / [`clear_err_if_func_written`](@ref) + - [`set_runtime_activity`](@ref) / [`clear_runtime_activity`](@ref) + - [`set_abi`](@ref) """ struct ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten,RuntimeActivity} end + +""" + const Reverse + +Default instance of [`ReverseMode`](@ref) that doesn't return the primal +""" const Reverse = ReverseMode{false,false,DefaultABI, false, false}() + +""" + const ReverseWithPrimal + +Default instance of [`ReverseMode`](@ref) that also returns the primal. +""" const ReverseWithPrimal = ReverseMode{true,false,DefaultABI, false, false}() + +""" + const ReverseHolomorphic + +Holomorphic instance of [`ReverseMode`](@ref) that doesn't return the primal +""" const ReverseHolomorphic = ReverseMode{false,false,DefaultABI, true, false}() + +""" + const ReverseHolomorphicWithPrimal + +Holomorphic instance of [`ReverseMode`](@ref) that also returns the primal +""" const ReverseHolomorphicWithPrimal = ReverseMode{true,false,DefaultABI, true, false}() @inline set_err_if_func_written(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,true}() @@ -265,34 +338,80 @@ const ReverseHolomorphicWithPrimal = ReverseMode{true,false,DefaultABI, true, fa @inline clear_runtime_activity(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,false,ABI,Holomorphic,ErrIfFuncWritten}() """ - 
WithPrimal(::Enzyme.Mode) + WithPrimal(::Mode) -Modifies the mode to include the primal value. +Return a new mode which includes the primal value. """ @inline WithPrimal(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{true,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}() """ - NoPrimal(::Enzyme.Mode) + NoPrimal(::Mode) -Modifies the mode to exclude the primal value. +Return a new mode which excludes the primal value. """ @inline NoPrimal(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{false,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}() - """ - struct ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI} <: Mode{ABI,ErrIfFuncWritten,RuntimeActivity} + struct ReverseModeSplit{ + ReturnPrimal, + ReturnShadow, + Width, + RuntimeActivity, + ModifiedBetween, + ABI, + ErrFuncIfWritten + } <: Mode{ABI,ErrIfFuncWritten,RuntimeActivity} + WithPrimal(::Enzyme.Mode) + +Subtype of [`Mode`](@ref) for split reverse mode differentiation, to use in [`autodiff_thunk`](@ref) and variants. + +# Type parameters -Reverse mode differentiation. -- `ReturnPrimal`: Should Enzyme return the primal return value from the augmented-forward. -- `ReturnShadow`: Should Enzyme return the shadow return value from the augmented-forward. -- `RuntimeActivity`: Should Enzyme differentiate with runtime activity on (default off). -- `Width`: Batch Size (0 if to be automatically derived) -- `ModifiedBetween`: Tuple of each argument's modified between state (true if to be automatically derived). +- `ReturnShadow`: whether to return the shadow return value from the augmented-forward. 
+- `Width`: batch size (pick `0` to derive it automatically) +- `ModifiedBetween`: `Tuple` of each argument's "modified between" state (pick `true` to derive it automatically). +- other parameters: see [`ReverseMode`](@ref) + +!!! warning + The type parameters of `ReverseModeSplit` are not part of the public API and can change without notice. + Please use one of the following concrete instantiations instead: + - [`ReverseSplitNoPrimal`](@ref) + - [`ReverseSplitWithPrimal`](@ref) + You can modify them with the following helper functions: + - [`WithPrimal`](@ref) / [`NoPrimal`](@ref) + - [`set_err_if_func_written`](@ref) / [`clear_err_if_func_written`](@ref) + - [`set_runtime_activity`](@ref) / [`clear_runtime_activity`](@ref) + - [`set_abi`](@ref) + - [`ReverseSplitModified`](@ref), [`ReverseSplitWidth`](@ref) """ struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,RuntimeActivity,ModifiedBetween,ABI,Holomorphic,ErrIfFuncWritten,ShadowInit} <: Mode{ABI, ErrIfFuncWritten,RuntimeActivity} end + +""" + const ReverseSplitNoPrimal + +Default instance of [`ReverseModeSplit`](@ref) that doesn't return the primal +""" const ReverseSplitNoPrimal = ReverseModeSplit{false, true, false, 0, true,DefaultABI, false, false, false}() + +""" + const ReverseSplitWithPrimal + +Default instance of [`ReverseModeSplit`](@ref) that also returns the primal +""" const ReverseSplitWithPrimal = ReverseModeSplit{true, true, false, 0, true,DefaultABI, false, false, false}() + +""" + ReverseSplitModified(::ReverseModeSplit, ::Val{MB}) + +Return a new instance of [`ReverseModeSplit`](@ref) mode where `ModifiedBetween` is set to `MB`. 
+""" @inline ReverseSplitModified(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, Width, MBO, ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}, ::Val{MB}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,MBO,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,MB,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() + +""" + ReverseSplitWidth(::ReverseModeSplit, ::Val{W}) + +Return a new instance of [`ReverseModeSplit`](@ref) mode where `Width` is set to `W`. +""" @inline ReverseSplitWidth(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, WidthO, MB, ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}, ::Val{Width}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,WidthO,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() @inline set_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, true, ShadowInit}() @@ -307,13 +426,46 @@ const ReverseSplitWithPrimal = ReverseModeSplit{true, true, false, 0, true,Defau """ - struct Forward{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity} + struct ForwardMode{ + ReturnPrimal, + ABI, + ErrIfFuncWritten, + RuntimeActivity + } <: Mode{ABI,ErrIfFuncWritten,RuntimeActivity} + +Subtype of [`Mode`](@ref) for forward mode differentiation. -Forward mode differentiation +# Type parameters + +- `ReturnPrimal`: whether to return the primal return value from the augmented-forward. +- other parameters: see [`Mode`](@ref) + +!!! 
warning + The type parameters of `ForwardMode` are not part of the public API and can change without notice. + Please use one of the following concrete instantiations instead: + - [`Forward`](@ref) + - [`ForwardWithPrimal`](@ref) + You can modify them with the following helper functions: + - [`WithPrimal`](@ref) / [`NoPrimal`](@ref) + - [`set_err_if_func_written`](@ref) / [`clear_err_if_func_written`](@ref) + - [`set_runtime_activity`](@ref) / [`clear_runtime_activity`](@ref) + - [`set_abi`](@ref) """ struct ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity} end + +""" + const Forward + +Default instance of [`ForwardMode`](@ref) that doesn't return the primal +""" const Forward = ForwardMode{false, DefaultABI, false, false}() + +""" + const ForwardWithPrimal + +Default instance of [`ForwardMode`](@ref) that also returns the primal +""" const ForwardWithPrimal = ForwardMode{true, DefaultABI, false, false}() @inline set_err_if_func_written(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ReturnPrimal,ABI,true,RuntimeActivity}() @@ -337,22 +489,22 @@ function autodiff_deferred_thunk end """ make_zero(::Type{T}, seen::IdDict, prev::T, ::Val{copy_if_inactive}=Val(false))::T - Recursively make a zero'd copy of the value `prev` of type `T`. The argument `copy_if_inactive` specifies - what to do if the type `T` is guaranteed to be inactive, use the primal (the default) or still copy the value. +Recursively make a zero'd copy of the value `prev` of type `T`. The argument `copy_if_inactive` specifies +what to do if the type `T` is guaranteed to be inactive, use the primal (the default) or still copy the value. """ function make_zero end """ make_zero!(val::T, seen::IdSet{Any}=IdSet())::Nothing - Recursively set a variables differentiable fields to zero. Only applicable for mutable types `T`. 
+Recursively set a variables differentiable fields to zero. Only applicable for mutable types `T`. """ function make_zero! end """ make_zero(prev::T) - Helper function to recursively make zero. +Helper function to recursively make zero. """ @inline function make_zero(prev::T, ::Val{copy_if_inactive}=Val(false)) where {T, copy_if_inactive} make_zero(Core.Typeof(prev), IdDict(), prev, Val(copy_if_inactive)) @@ -383,10 +535,47 @@ if !isdefined(Base, :get_extension) end """ - within_autodiff() + within_autodiff() Returns true if within autodiff, otherwise false. """ function within_autodiff end +""" + set_err_if_func_written(::Mode) + +Return a new mode which throws an error for any attempt to write into an unannotated function object. +""" +function set_err_if_func_written end + +""" + clear_err_if_func_written(::Mode) + +Return a new mode which doesn't throw an error for attempts to write into an unannotated function object. +""" +function clear_err_if_func_written end + +""" + set_runtime_activity(::Mode) + set_runtime_activity(::Mode, activitiy::Bool) + set_runtime_activity(::Mode, config::Union{FwdConfig,RevConfig}) + +Return a new mode where runtime activity analysis is activated / set to the desired value. +""" +function set_runtime_activity end + +""" + clear_runtime_activity(::Mode) + +Return a new mode where runtime activity analysis is deactivated. +""" +function clear_runtime_activity end + +""" + set_abi(::Mode, ::Type{ABI}) + +Return a new mode with its [`ABI`](@ref) set to the chosen type. +""" +function set_abi end + end # module EnzymeCore From 467b4f7a7cd368ba2189b59464d5904cb3259394 Mon Sep 17 00:00:00 2001 From: ExpandingMan Date: Sat, 28 Sep 2024 12:27:23 -0400 Subject: [PATCH 71/87] fix some cases of gradient/jacobian with StaticArrays (#1875) * fix some cases of gradient/jacobian with StaticArrays * add tests * Update EnzymeStaticArraysExt.jl * jacobian is exported, wtf? 
--------- Co-authored-by: William Moses --- ext/EnzymeStaticArraysExt.jl | 9 ++++++- test/runtests.jl | 47 +++++++++++++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index b751c336a2a..af31d405d79 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -3,7 +3,14 @@ module EnzymeStaticArraysExt using StaticArrays using Enzyme -@inline Enzyme.tupstack(rows::(NTuple{N, <:StaticArrays.SArray} where N), inshape, outshape) = reshape(cat(rows..., dims=length(inshape)), (inshape..., outshape...)) +@inline function Base.convert(::Type{SArray}, tpa::Enzyme.TupleArray{T,S,L,N}) where {T,S,L,N} + SArray{Tuple{S...},T,N,L}(tpa.data) +end +@inline Base.convert(::Type{StaticArray}, tpa::Enzyme.TupleArray) = convert(SArray, tpa) + +@inline function Enzyme.tupstack(rows::(NTuple{N, <:StaticArrays.SArray} where N), inshape, outshape) + reshape(reduce(hcat, map(vec, rows)), Size(inshape..., outshape...)) +end @inline function Enzyme.onehot(x::StaticArrays.SArray{S, T, N, L}) where {S, T, N, L} ntuple(Val(L)) do i diff --git a/test/runtests.jl b/test/runtests.jl index 8b496b071f0..eb9dfccb6cb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2854,6 +2854,51 @@ end @test dx[1] ≈ 0 @test dx[2] ≈ 30 @test dx[3] ≈ 0 + + f0 = x -> sum(2*x) + f1 = x -> @SVector Float64[x[2], 2*x[2]] + f2 = x -> @SMatrix Float64[x[2] x[1]; 2*x[2] 2*x[1]] + + x = @SVector Float64[1, 2] + + dx = gradient(Forward, f0, x)[1] + @test dx isa Enzyme.TupleArray + @test convert(SArray, dx) == [2.0, 2.0] # test to make sure conversion works + @test gradient(Forward, f1, x)[1] isa SMatrix + @test gradient(Forward, f1, x)[1] == [0 1.0; 0 2.0] + @test Enzyme.jacobian(Forward, f2, x)[1] isa SArray + @test Enzyme.jacobian(Forward, f2, x)[1] == reshape(Float64[0,0,1,2,1,2,0,0], (2,2,2)) + + x = @SMatrix Float64[1 2; 3 4] + + dx = gradient(Forward, f0, x)[1] + @test dx isa Enzyme.TupleArray + 
@test convert(SArray, dx) == fill(2.0, (2,2)) + @test gradient(Forward, f1, x)[1] isa SArray + @test gradient(Forward, f1, x)[1] == reshape(Float64[0,0,1,2,0,0,0,0], (2,2,2)) + @test Enzyme.jacobian(Forward, f2, x)[1] isa SArray + @test Enzyme.jacobian(Forward, f2, x)[1] == reshape( + Float64[0,0,1,2,1,2,0,0,0,0,0,0,0,0,0,0], (2,2,2,2), + ) + + x = @SVector Float64[1, 2] + + dx = gradient(Reverse, f0, x)[1] + @test dx isa SVector + @test convert(SArray, dx) == [2.0, 2.0] # test to make sure conversion works + @test_broken gradient(Reverse, f1, x)[1] isa SMatrix + @test_broken gradient(Reverse, f1, x)[1] == [0 1.0; 0 2.0] + @test_broken Enzyme.jacobian(Reverse, f2, x)[1] isa SArray + @test_broken Enzyme.jacobian(Reverse, f2, x)[1] == reshape(Float64[0,0,1,2,1,2,0,0], (2,2,2)) + + x = @SMatrix Float64[1 2; 3 4] + + @test_broken gradient(Reverse, f1, x)[1] isa SArray + @test_broken gradient(Reverse, f1, x)[1] == reshape(Float64[0,0,1,2,0,0,0,0], (2,2,2)) + @test_broken Enzyme.jacobian(Reverse, f2, x)[1] isa SArray + @test_broken Enzyme.jacobian(Reverse, f2, x)[1] == reshape( + Float64[0,0,1,2,1,2,0,0,0,0,0,0,0,0,0,0], (2,2,2,2), + ) end function unstable_fun(A0) @@ -4101,4 +4146,4 @@ include("ext/logexpfunctions.jl") @testset "BFloat16s ext" begin include("ext/bfloat16s.jl") -end \ No newline at end of file +end From b999e5a00301a11dddfa1f9da416c6736eaf3bd9 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 28 Sep 2024 13:14:28 -0500 Subject: [PATCH 72/87] Use actual_size instead of sizeof (#1915) * Use actual_size instead of sizeof * Better error str --- src/rules/llvmrules.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index a41912cf829..3c4b95d8eea 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -1519,7 +1519,7 @@ end found, arty, byref = abs_typeof(origops[1]) anti = shadowin elSize = if found - LLVM.ConstantInt(Csize_t(sizeof(eltype(arty)))) + 
LLVM.ConstantInt(Csize_t(actual_size(eltype(arty)))) else elSize = LLVM.zext!( B, @@ -1534,7 +1534,12 @@ end length = LLVM.mul!(B, len, elSize) if !found && !(eltype(arty) <: Base.IEEEFloat) - GPUCompiler.@safe_warn "TODO reverse jl_array_del_end zero-set used memset rather than runtime type of $((found, arty)) in $(string(origops[1]))" + bt = GPUCompiler.backtrace(orig) + btstr = sprint() do io + print(io, "\nCaused by:") + Base.show_backtrace(io, bt) + end + GPUCompiler.@safe_warn "TODO reverse jl_array_del_end zero-set used memset rather than runtime type of $((found, arty)) in $(string(origops[1])) $btstr" end toset = get_array_data(B, anti) toset = gep!(B, i8, toset, LLVM.Value[length]) From 55582f8ba60f41411d162b3d3d73155d0545199c Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 28 Sep 2024 14:00:41 -0500 Subject: [PATCH 73/87] Even more indexed typeinfo (#1916) --- src/absint.jl | 61 ++++++++++++++++++++++++++----------------------- src/compiler.jl | 8 ++++--- src/typetree.jl | 25 +++++++++++++++++--- 3 files changed, 59 insertions(+), 35 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index 041b3bd1cc3..77ce2b6a7e5 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -204,6 +204,34 @@ function should_recurse(@nospecialize(typ2), arg_t, byref, dl) end end +function get_base_and_offset(larg::LLVM.Value)::Tuple{LLVM.Value, Int, Bool} + offset = 0 + error = false + while true + if isa(larg, LLVM.BitCastInst) || isa(larg, LLVM.AddrSpaceCastInst) + larg = operands(larg)[1] + continue + end + if isa(larg, LLVM.GetElementPtrInst) && + all(x -> isa(x, LLVM.ConstantInt), operands(larg)[2:end]) + b = LLVM.IRBuilder() + position!(b, larg) + offty = LLVM.IntType(8 * sizeof(Int)) + offset2 = API.EnzymeComputeByteOffsetOfGEP(b, larg, offty) + @assert isa(offset2, LLVM.ConstantInt) + offset += convert(Int, offset2) + larg = operands(larg)[1] + continue + end + if isa(larg, LLVM.Argument) + break + end + error = true + break + end + return larg, offset, error 
+end + function abs_typeof( arg::LLVM.Value, partial::Bool = false, @@ -354,32 +382,7 @@ function abs_typeof( end if isa(arg, LLVM.LoadInst) - larg = operands(arg)[1] - offset = nothing - error = false - while true - if isa(larg, LLVM.BitCastInst) || isa(larg, LLVM.AddrSpaceCastInst) - larg = operands(larg)[1] - continue - end - if offset === nothing && - isa(larg, LLVM.GetElementPtrInst) && - all(x -> isa(x, LLVM.ConstantInt), operands(larg)[2:end]) - b = LLVM.IRBuilder() - position!(b, larg) - offty = LLVM.IntType(8 * sizeof(Int)) - offset = API.EnzymeComputeByteOffsetOfGEP(b, larg, offty) - @assert isa(offset, LLVM.ConstantInt) - offset = convert(Int, offset) - larg = operands(larg)[1] - continue - end - if isa(larg, LLVM.Argument) - break - end - error = true - break - end + larg, offset, error = get_base_and_offset(operands(arg)[1]) if !error legal, typ, byref = abs_typeof(larg) @@ -387,7 +390,7 @@ function abs_typeof( @static if VERSION < v"1.11-" if typ <: Array && Base.isconcretetype(typ) T = eltype(typ) - if offset === nothing || offset == 0 + if offset == 0 return (true, Ptr{T}, GPUCompiler.BITS_VALUE) else return (true, Int, GPUCompiler.BITS_VALUE) @@ -400,14 +403,14 @@ function abs_typeof( byref = GPUCompiler.BITS_VALUE legal = true - while (offset !== nothing && offset != 0) && legal + while offset != 0 && legal @assert Base.isconcretetype(typ) seen = false lasti = 1 for i = 1:fieldcount(typ) fo = fieldoffset(typ, i) if fieldoffset(typ, i) == offset - offset = nothing + offset = 0 typ = fieldtype(typ, i) if !Base.allocatedinline(typ) if byref != GPUCompiler.BITS_VALUE diff --git a/src/compiler.jl b/src/compiler.jl index 08dc5f05c95..e890bd998da 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -7988,7 +7988,8 @@ function GPUCompiler.codegen( if intr == LLVM.Intrinsic("llvm.memcpy").id || intr == LLVM.Intrinsic("llvm.memmove").id || intr == LLVM.Intrinsic("llvm.memset").id - legal, jTy, byref = abs_typeof(operands(inst)[1]) + base, offset, _ = 
get_base_and_offset(operands(inst)[1]) + legal, jTy, byref = abs_typeof(base) sz = if intr == LLVM.Intrinsic("llvm.memcpy").id || intr == LLVM.Intrinsic("llvm.memmove").id @@ -8007,8 +8008,9 @@ function GPUCompiler.codegen( any(T2 isa Core.TypeofVararg for T2 in jTy.parameters) ) ) - if isa(sz, LLVM.ConstantInt) && sizeof(jTy) == convert(Int, sz) - md = to_fullmd(jTy) + if offset < sizeof(jTy) && isa(sz, LLVM.ConstantInt) && sizeof(jTy) - offset >= convert(Int, sz) + lim = convert(Int, sz) + md = to_fullmd(jTy, offset, lim) @assert byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF metadata(inst)["enzyme_truetype"] = md diff --git a/src/typetree.jl b/src/typetree.jl index 8ddce070b26..c96d41fb2b1 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -137,9 +137,28 @@ function get_offsets(@nospecialize(T::Type)) return results end -function to_fullmd(@nospecialize(T::Type)) +function to_fullmd(@nospecialize(T::Type), offset::Int, lim::Int) mds = LLVM.Metadata[] - for (sT, sO) in get_offsets(T) + offs = get_offsets(T) + + minoff = -1 + for (sT, sO) in offs + if sO >= offset + if sO == offset + minOff = sO + end + else + minoff = max(minoff, sO) + end + end + + for (sT, sO) in offs + if sO != minoff && (sO < offset) + continue + end + if sO >= lim + continue + end if sT == API.DT_Pointer push!(mds, LLVM.MDString("Pointer")) elseif sT == API.DT_Integer @@ -155,7 +174,7 @@ function to_fullmd(@nospecialize(T::Type)) else @assert false end - push!(mds, LLVM.Metadata(LLVM.ConstantInt(sO))) + push!(mds, LLVM.Metadata(LLVM.ConstantInt(min(0, sO - offset)))) end return LLVM.MDNode(mds) end From b739cbfb241b6305ae4ca73683d590d23f141ed8 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 28 Sep 2024 15:45:48 -0500 Subject: [PATCH 74/87] Correct offset to use max (#1918) * Correct offset to use max * fix --- src/typetree.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/typetree.jl b/src/typetree.jl index c96d41fb2b1..89fe522670f 100644 --- 
a/src/typetree.jl +++ b/src/typetree.jl @@ -174,7 +174,7 @@ function to_fullmd(@nospecialize(T::Type), offset::Int, lim::Int) else @assert false end - push!(mds, LLVM.Metadata(LLVM.ConstantInt(min(0, sO - offset)))) + push!(mds, LLVM.Metadata(LLVM.ConstantInt(max(0, sO - offset)))) end return LLVM.MDNode(mds) end From 287b847b2c7c7382da0503c7a6bd810f933def16 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 28 Sep 2024 15:46:04 -0500 Subject: [PATCH 75/87] Fix active reg inner of literal type (#1917) --- src/compiler.jl | 2 +- test/runtests.jl | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index e890bd998da..560ef0fcb31 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -663,7 +663,7 @@ end return AnyState end - if isghostty(T) || Core.Compiler.isconstType(T) + if isghostty(T) || Core.Compiler.isconstType(T) || T <: Type return AnyState end diff --git a/test/runtests.jl b/test/runtests.jl index eb9dfccb6cb..f3b5be421f2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -116,6 +116,7 @@ mutable struct MInts{A, B} end @testset "Internal tests" begin + @assert Enzyme.Compiler.active_reg_inner(Type{Array}, (), nothing) == Enzyme.Compiler.AnyState @assert Enzyme.Compiler.active_reg_inner(Ints{<:Any, Integer}, (), nothing) == Enzyme.Compiler.AnyState @assert Enzyme.Compiler.active_reg_inner(Ints{<:Any, Float64}, (), nothing) == Enzyme.Compiler.DupState @assert Enzyme.Compiler.active_reg_inner(Ints{Integer, <:Any}, (), nothing) == Enzyme.Compiler.DupState From d91151bae770bceb1aa639e211f54e16c913c642 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 28 Sep 2024 16:09:11 -0500 Subject: [PATCH 76/87] Fix limit being relative (#1919) * Fix limit being relative * fix --- src/typetree.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/typetree.jl b/src/typetree.jl index 89fe522670f..c886c683ced 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -145,7 +145,7 @@ function 
to_fullmd(@nospecialize(T::Type), offset::Int, lim::Int) for (sT, sO) in offs if sO >= offset if sO == offset - minOff = sO + minoff = sO end else minoff = max(minoff, sO) @@ -156,7 +156,7 @@ function to_fullmd(@nospecialize(T::Type), offset::Int, lim::Int) if sO != minoff && (sO < offset) continue end - if sO >= lim + if sO >= lim + offset continue end if sT == API.DT_Pointer From 4ab422baa2f76064219636273135c8854f94e48b Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 28 Sep 2024 20:39:03 -0500 Subject: [PATCH 77/87] Fix memory of float (#1920) --- src/compiler.jl | 8 ++++++++ test/runtests.jl | 4 ++++ 2 files changed, 12 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 560ef0fcb31..1b11da719ec 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -523,6 +523,10 @@ end @inline ptreltype(::Type{IdDict{K,V}}) where {K,V} = V @inline ptreltype(::Type{IdDict{K,V} where K}) where {V} = V @inline ptreltype(::Type{SparseArrays.CHOLMOD.Dense{T}}) where T = T +@static if VERSION < v"1.11-" +else +@inline ptreltype(::Type{Memory{T}}) where T = T +end @inline is_arrayorvararg_ty(::Type) = false @inline is_arrayorvararg_ty(::Type{Array{T,N}}) where {T,N} = true @@ -535,6 +539,10 @@ end @inline is_arrayorvararg_ty(::Type{IdDict{K,V}}) where {K,V} = true @inline is_arrayorvararg_ty(::Type{IdDict{K,V} where K}) where {V} = true @inline is_arrayorvararg_ty(::Type{SparseArrays.CHOLMOD.Dense{T}}) where T = true +@static if VERSION < v"1.11-" +else +@inline is_arrayorvararg_ty(::Type{Memory{T}}) where T = true +end @inline function datatype_fieldcount(t::Type{T}) where {T} return Base.datatype_fieldcount(t) diff --git a/test/runtests.jl b/test/runtests.jl index f3b5be421f2..b4aab36752e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -116,6 +116,10 @@ mutable struct MInts{A, B} end @testset "Internal tests" begin + @static if VERSION < v"1.11-" + else + @assert Enzyme.Compiler.active_reg_inner(Memory{Float64}, (), nothing) == Enzyme.Compiler.DupState 
+ end @assert Enzyme.Compiler.active_reg_inner(Type{Array}, (), nothing) == Enzyme.Compiler.AnyState @assert Enzyme.Compiler.active_reg_inner(Ints{<:Any, Integer}, (), nothing) == Enzyme.Compiler.AnyState @assert Enzyme.Compiler.active_reg_inner(Ints{<:Any, Float64}, (), nothing) == Enzyme.Compiler.DupState From b968cfe5a54bf5b9285217fd476605eaae54ed4a Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 28 Sep 2024 23:52:01 -0500 Subject: [PATCH 78/87] Attempt to fix apple (#1834) --- test/runtests.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index b4aab36752e..e720ba46c41 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -364,6 +364,9 @@ make3() = (1.0, 2.0, 3.0) @test autodiff(Forward, tanh, Duplicated(1.0f0, 1.0f0))[1] ≈ Float32(0.41997434161402606939) for T in (Float64, Float32, Float16) + if T == Float16 && Sys.isapple() + continue + end res = autodiff(Reverse, tanh, Active, Active(T(1)))[1][1] @test res isa T cmp = if T == Float64 From a16f41a68fc8a19e9d70a2b57d349bfd53252063 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 29 Sep 2024 01:30:54 -0500 Subject: [PATCH 79/87] Bump jll (#1921) --- Project.toml | 2 +- test/runtests.jl | 30 ++++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 3b8ddd20602..21d0cab215a 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8.4" -Enzyme_jll = "0.0.151" +Enzyme_jll = "0.0.152" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" diff --git a/test/runtests.jl b/test/runtests.jl index e720ba46c41..ba4462bc235 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3557,6 +3557,36 @@ end @test din[2, 1] ≈ 1.0 end +@testset "View Vars" begin + + x = [Float32(0.25)] + dx = [Float32(0.0)] + rng = Base.UnitRange{Int64}(1, 0) + + f = 
Const(Base.SubArray{T, N, P, I, L} where L where I where P where N where T) + a1 = Const(Base.IndexLinear()) + a2 = Duplicated(x, dx) + a3 = Const((rng,)) + a4 = Const((true,)) + + fwd, rev = autodiff_thunk(ReverseSplitWithPrimal, + typeof(f), + Duplicated, + typeof(a1), + typeof(a2), + typeof(a3), + typeof(a4) + ) + + res = fwd(f,a1,a2,a3,a4) + @test res[2].indices == (rng,) + @test res[3].indices == (rng,) + @test res[2].offset1 == 0 + @test res[3].offset1 == 0 + @test res[2].stride1 == 1 + @test res[3].stride1 == 1 +end + @testset "Uncached batch sizes" begin genericsin(x) = Base.invokelatest(sin, x) res = Enzyme.autodiff(Forward, genericsin, BatchDuplicated(2.0, NTuple{10,Float64}((Float64(i) for i in 1:10))))[1] From f14ad34dfde2323d50b0339bec2e88a2bb729aee Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 29 Sep 2024 13:21:51 -0500 Subject: [PATCH 80/87] Bump jll (#1922) --- Project.toml | 4 ++-- test/runtests.jl | 27 +++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 21d0cab215a..87ab945958f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.4" +version = "0.13.5" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -36,7 +36,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8.4" -Enzyme_jll = "0.0.152" +Enzyme_jll = "0.0.153" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" diff --git a/test/runtests.jl b/test/runtests.jl index ba4462bc235..902b9e4f652 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3585,6 +3585,33 @@ end @test res[3].offset1 == 0 @test res[2].stride1 == 1 @test res[3].stride1 == 1 + + x = [Float32(0.25)] + dx = [Float32(0.0)] + rng = Base.UnitRange{Int64}(1, 0) + + f = Const(Base.SubArray{T, N, P, I, L} where L 
where I where P where N where T) + a1 = Const(Base.IndexLinear()) + a2 = Duplicated(x, dx) + a3 = Const((rng,)) + a4 = Const((true,)) + + fwd, rev = autodiff_thunk(set_runtime_activity(ReverseSplitWithPrimal), + typeof(f), + Duplicated, + typeof(a1), + typeof(a2), + typeof(a3), + typeof(a4) + ) + + res = fwd(f,a1,a2,a3,a4) + @test res[2].indices == (rng,) + @test res[3].indices == (rng,) + @test res[2].offset1 == 0 + @test res[3].offset1 == 0 + @test res[2].stride1 == 1 + @test res[3].stride1 == 1 end @testset "Uncached batch sizes" begin From c9eae5b83f9053186111573da57f8e7ef3ffc947 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 29 Sep 2024 18:56:50 -0500 Subject: [PATCH 81/87] Fix make_zero on constant fields (#1926) * Fix make_zero on constant fields * type --- src/compiler.jl | 400 +--------------------------------------------- src/make_zero.jl | 404 +++++++++++++++++++++++++++++++++++++++++++++++ test/abi.jl | 13 ++ 3 files changed, 418 insertions(+), 399 deletions(-) create mode 100644 src/make_zero.jl diff --git a/src/compiler.jl b/src/compiler.jl index 1b11da719ec..16d54481cf8 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1803,405 +1803,7 @@ function allocate_sret!(gutils::API.EnzymeGradientUtilsRef, N) allocate_sret!(B, N) end -@inline function EnzymeCore.make_zero(x::FT)::FT where {FT<:AbstractFloat} - return Base.zero(x) -end -@inline function EnzymeCore.make_zero(x::Complex{FT})::Complex{FT} where {FT<:AbstractFloat} - return Base.zero(x) -end -@inline function EnzymeCore.make_zero( - x::Array{FT,N}, -)::Array{FT,N} where {FT<:AbstractFloat,N} - return Base.zero(x) -end -@inline function EnzymeCore.make_zero( - x::Array{Complex{FT},N}, -)::Array{Complex{FT},N} where {FT<:AbstractFloat,N} - return Base.zero(x) -end - -@inline function EnzymeCore.make_zero( - ::Type{Array{FT,N}}, - seen::IdDict, - prev::Array{FT,N}, - ::Val{copy_if_inactive} = Val(false), -)::Array{FT,N} where {copy_if_inactive,FT<:AbstractFloat,N} - if haskey(seen, 
prev) - return seen[prev] - end - newa = Base.zero(prev) - seen[prev] = newa - return newa -end -@inline function EnzymeCore.make_zero( - ::Type{Array{Complex{FT},N}}, - seen::IdDict, - prev::Array{Complex{FT},N}, - ::Val{copy_if_inactive} = Val(false), -)::Array{Complex{FT},N} where {copy_if_inactive,FT<:AbstractFloat,N} - if haskey(seen, prev) - return seen[prev] - end - newa = Base.zero(prev) - seen[prev] = newa - return newa -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT<:AbstractFloat} - return RT(0) -end - -@inline function EnzymeCore.make_zero( - ::Type{Complex{RT}}, - seen::IdDict, - prev::Complex{RT}, - ::Val{copy_if_inactive} = Val(false), -)::Complex{RT} where {copy_if_inactive,RT<:AbstractFloat} - return RT(0) -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT<:Array} - if haskey(seen, prev) - return seen[prev] - end - if guaranteed_const_nongen(RT, nothing) - return copy_if_inactive ? 
Base.deepcopy_internal(prev, seen) : prev - end - newa = RT(undef, size(prev)) - seen[prev] = newa - for I in eachindex(prev) - if isassigned(prev, I) - pv = prev[I] - innerty = Core.Typeof(pv) - @inbounds newa[I] = - EnzymeCore.make_zero(innerty, seen, pv, Val(copy_if_inactive)) - end - end - return newa -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT<:Tuple} - return ntuple(length(prev)) do i - Base.@_inline_meta - EnzymeCore.make_zero(RT.parameters[i], seen, prev[i], Val(copy_if_inactive)) - end -end - -@inline function EnzymeCore.make_zero( - ::Type{NamedTuple{A,RT}}, - seen::IdDict, - prev::NamedTuple{A,RT}, - ::Val{copy_if_inactive} = Val(false), -)::NamedTuple{A,RT} where {copy_if_inactive,A,RT} - return NamedTuple{A,RT}(EnzymeCore.make_zero(RT, seen, RT(prev), Val(copy_if_inactive))) -end - -@inline function EnzymeCore.make_zero( - ::Type{Core.Box}, - seen::IdDict, - prev::Core.Box, - ::Val{copy_if_inactive} = Val(false), -) where {copy_if_inactive} - if haskey(seen, prev) - return seen[prev] - end - prev2 = prev.contents - res = Core.Box() - seen[prev] = res - res.contents = Base.Ref( - EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive)), - ) - return res -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT} - if guaranteed_const_nongen(RT, nothing) - return copy_if_inactive ? 
Base.deepcopy_internal(prev, seen) : prev - end - if haskey(seen, prev) - return seen[prev] - end - @assert !Base.isabstracttype(RT) - @assert Base.isconcretetype(RT) - nf = fieldcount(RT) - - if ismutable(prev) - y = ccall(:jl_new_struct_uninit, Any, (Any,), RT) - seen[prev] = y - for i = 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - T = Core.Typeof(xi) - xi = EnzymeCore.make_zero(T, seen, xi, Val(copy_if_inactive)) - setfield!(y, i, xi) - end - end - return y - end - - if nf == 0 - return prev - end - - flds = Vector{Any}(undef, nf) - for i = 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - xi = EnzymeCore.make_zero(Core.Typeof(xi), seen, xi, Val(copy_if_inactive)) - flds[i] = xi - else - nf = i - 1 # rest of tail must be undefined values - break - end - end - y = ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf) - seen[prev] = y - return y -end - -function make_zero_immutable!(prev::T, seen::S)::T where {T<:AbstractFloat,S} - zero(T) -end - -function make_zero_immutable!( - prev::Complex{T}, - seen::S, -)::Complex{T} where {T<:AbstractFloat,S} - zero(T) -end - -function make_zero_immutable!(prev::T, seen::S)::T where {T<:Tuple,S} - ntuple(Val(length(T.parameters))) do i - Base.@_inline_meta - make_zero_immutable!(prev[i], seen) - end -end - -function make_zero_immutable!(prev::NamedTuple{a,b}, seen::S)::NamedTuple{a,b} where {a,b,S} - NamedTuple{a,b}(ntuple(Val(length(T.parameters))) do i - Base.@_inline_meta - make_zero_immutable!(prev[a[i]], seen) - end) -end - - -function make_zero_immutable!(prev::T, seen::S)::T where {T,S} - if guaranteed_const_nongen(T, nothing) - return prev - end - @assert !ismutable(prev) - - RT = Core.Typeof(prev) - @assert !Base.isabstracttype(RT) - @assert Base.isconcretetype(RT) - nf = fieldcount(RT) - - flds = Vector{Any}(undef, nf) - for i = 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - ST = Core.Typeof(xi) - flds[i] = if active_reg_inner(ST, (), nothing, Val(true)) == ActiveState 
#=justActive=# - make_zero_immutable!(xi, seen) - else - EnzymeCore.make_zero!(xi, seen) - xi - end - else - nf = i - 1 # rest of tail must be undefined values - break - end - end - ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf)::T -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{T}, - seen::ST, -)::Nothing where {T<:AbstractFloat,ST} - T[] = zero(T) - nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{Complex{T}}, - seen::ST, -)::Nothing where {T<:AbstractFloat,ST} - T[] = zero(Complex{T}) - nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Array{T,N}, - seen::ST, -)::Nothing where {T<:AbstractFloat,N,ST} - fill!(prev, zero(T)) - nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Array{Complex{T},N}, - seen::ST, -)::Nothing where {T<:AbstractFloat,N,ST} - fill!(prev, zero(Complex{T})) - nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{T}, -)::Nothing where {T<:AbstractFloat} - EnzymeCore.make_zero!(prev, nothing) - nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{Complex{T}}, -)::Nothing where {T<:AbstractFloat} - EnzymeCore.make_zero!(prev, nothing) - nothing -end - -@inline function EnzymeCore.make_zero!(prev::Array{T,N})::Nothing where {T<:AbstractFloat,N} - EnzymeCore.make_zero!(prev, nothing) - nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Array{Complex{T},N}, -)::Nothing where {T<:AbstractFloat,N} - EnzymeCore.make_zero!(prev, nothing) - nothing -end - -@inline function EnzymeCore.make_zero!(prev::Array{T,N}, seen::ST)::Nothing where {T,N,ST} - if guaranteed_const_nongen(T, nothing) - return - end - if in(seen, prev) - return - end - push!(seen, prev) - - for I in eachindex(prev) - if isassigned(prev, I) - pv = prev[I] - SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# - @inbounds prev[I] = make_zero_immutable!(pv, seen) - nothing - else 
- EnzymeCore.make_zero!(pv, seen) - nothing - end - end - end - nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{T}, - seen::ST, -)::Nothing where {T,ST} - if guaranteed_const_nongen(T, nothing) - return - end - if in(seen, prev) - return - end - push!(seen, prev) - - pv = prev[] - SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# - prev[] = make_zero_immutable!(pv, seen) - nothing - else - EnzymeCore.make_zero!(pv, seen) - nothing - end - nothing -end - -@inline function EnzymeCore.make_zero!(prev::Core.Box, seen::ST)::Nothing where {ST} - pv = prev.contents - T = Core.Typeof(pv) - if guaranteed_const_nongen(T, nothing) - return - end - if in(seen, prev) - return - end - push!(seen, prev) - SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# - prev.contents = EnzymeCore.make_zero_immutable!(pv, seen) - nothing - else - EnzymeCore.make_zero!(pv, seen) - nothing - end - nothing -end - -@inline function EnzymeCore.make_zero!( - prev::T, - seen::S = Base.IdSet{Any}(), -)::Nothing where {T,S} - if guaranteed_const_nongen(T, nothing) - return - end - if in(prev, seen) - return - end - @assert !Base.isabstracttype(T) - @assert Base.isconcretetype(T) - nf = fieldcount(T) - - - if nf == 0 - return - end - - push!(seen, prev) - - for i = 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - SBT = Core.Typeof(xi) - if guaranteed_const_nongen(SBT, nothing) - continue - end - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# - setfield!(prev, i, make_zero_immutable!(xi, seen)) - nothing - else - EnzymeCore.make_zero!(xi, seen) - nothing - end - end - end - return -end +include("make_zero.jl") function emit_error(B::LLVM.IRBuilder, orig, string, errty = EnzymeRuntimeException) curent_bb = position(B) diff --git a/src/make_zero.jl b/src/make_zero.jl new file mode 100644 index 00000000000..4f627581ea6 --- /dev/null +++ 
b/src/make_zero.jl @@ -0,0 +1,404 @@ + +@inline function EnzymeCore.make_zero(x::FT)::FT where {FT<:AbstractFloat} + return Base.zero(x) +end +@inline function EnzymeCore.make_zero(x::Complex{FT})::Complex{FT} where {FT<:AbstractFloat} + return Base.zero(x) +end +@inline function EnzymeCore.make_zero( + x::Array{FT,N}, +)::Array{FT,N} where {FT<:AbstractFloat,N} + return Base.zero(x) +end +@inline function EnzymeCore.make_zero( + x::Array{Complex{FT},N}, +)::Array{Complex{FT},N} where {FT<:AbstractFloat,N} + return Base.zero(x) +end + +@inline function EnzymeCore.make_zero( + ::Type{Array{FT,N}}, + seen::IdDict, + prev::Array{FT,N}, + ::Val{copy_if_inactive} = Val(false), +)::Array{FT,N} where {copy_if_inactive,FT<:AbstractFloat,N} + if haskey(seen, prev) + return seen[prev] + end + newa = Base.zero(prev) + seen[prev] = newa + return newa +end +@inline function EnzymeCore.make_zero( + ::Type{Array{Complex{FT},N}}, + seen::IdDict, + prev::Array{Complex{FT},N}, + ::Val{copy_if_inactive} = Val(false), +)::Array{Complex{FT},N} where {copy_if_inactive,FT<:AbstractFloat,N} + if haskey(seen, prev) + return seen[prev] + end + newa = Base.zero(prev) + seen[prev] = newa + return newa +end + +@inline function EnzymeCore.make_zero( + ::Type{RT}, + seen::IdDict, + prev::RT, + ::Val{copy_if_inactive} = Val(false), +)::RT where {copy_if_inactive,RT<:AbstractFloat} + return RT(0) +end + +@inline function EnzymeCore.make_zero( + ::Type{Complex{RT}}, + seen::IdDict, + prev::Complex{RT}, + ::Val{copy_if_inactive} = Val(false), +)::Complex{RT} where {copy_if_inactive,RT<:AbstractFloat} + return RT(0) +end + +@inline function EnzymeCore.make_zero( + ::Type{RT}, + seen::IdDict, + prev::RT, + ::Val{copy_if_inactive} = Val(false), +)::RT where {copy_if_inactive,RT<:Array} + if haskey(seen, prev) + return seen[prev] + end + if guaranteed_const_nongen(RT, nothing) + return copy_if_inactive ? 
Base.deepcopy_internal(prev, seen) : prev + end + newa = RT(undef, size(prev)) + seen[prev] = newa + for I in eachindex(prev) + if isassigned(prev, I) + pv = prev[I] + innerty = Core.Typeof(pv) + @inbounds newa[I] = + EnzymeCore.make_zero(innerty, seen, pv, Val(copy_if_inactive)) + end + end + return newa +end + +@inline function EnzymeCore.make_zero( + ::Type{RT}, + seen::IdDict, + prev::RT, + ::Val{copy_if_inactive} = Val(false), +)::RT where {copy_if_inactive,RT<:Tuple} + return ntuple(length(prev)) do i + Base.@_inline_meta + EnzymeCore.make_zero(RT.parameters[i], seen, prev[i], Val(copy_if_inactive)) + end +end + +@inline function EnzymeCore.make_zero( + ::Type{NamedTuple{A,RT}}, + seen::IdDict, + prev::NamedTuple{A,RT}, + ::Val{copy_if_inactive} = Val(false), +)::NamedTuple{A,RT} where {copy_if_inactive,A,RT} + return NamedTuple{A,RT}(EnzymeCore.make_zero(RT, seen, RT(prev), Val(copy_if_inactive))) +end + +@inline function EnzymeCore.make_zero( + ::Type{Core.Box}, + seen::IdDict, + prev::Core.Box, + ::Val{copy_if_inactive} = Val(false), +) where {copy_if_inactive} + if haskey(seen, prev) + return seen[prev] + end + prev2 = prev.contents + res = Core.Box() + seen[prev] = res + res.contents = Base.Ref( + EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive)), + ) + return res +end + +@inline function EnzymeCore.make_zero( + ::Type{RT}, + seen::IdDict, + prev::RT, + ::Val{copy_if_inactive} = Val(false), +)::RT where {copy_if_inactive,RT} + if guaranteed_const_nongen(RT, nothing) + return copy_if_inactive ? 
Base.deepcopy_internal(prev, seen) : prev + end + if haskey(seen, prev) + return seen[prev] + end + @assert !Base.isabstracttype(RT) + @assert Base.isconcretetype(RT) + nf = fieldcount(RT) + + if ismutable(prev) + y = ccall(:jl_new_struct_uninit, Any, (Any,), RT)::RT + seen[prev] = y + for i = 1:nf + if isdefined(prev, i) + xi = getfield(prev, i) + T = Core.Typeof(xi) + xi = EnzymeCore.make_zero(T, seen, xi, Val(copy_if_inactive)) + if Base.isconst(RT, i) + ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i-1, xi) + else + setfield!(y, i, xi) + end + end + end + return y + end + + if nf == 0 + return prev + end + + flds = Vector{Any}(undef, nf) + for i = 1:nf + if isdefined(prev, i) + xi = getfield(prev, i) + xi = EnzymeCore.make_zero(Core.Typeof(xi), seen, xi, Val(copy_if_inactive)) + flds[i] = xi + else + nf = i - 1 # rest of tail must be undefined values + break + end + end + y = ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf) + seen[prev] = y + return y +end + +function make_zero_immutable!(prev::T, seen::S)::T where {T<:AbstractFloat,S} + zero(T) +end + +function make_zero_immutable!( + prev::Complex{T}, + seen::S, +)::Complex{T} where {T<:AbstractFloat,S} + zero(T) +end + +function make_zero_immutable!(prev::T, seen::S)::T where {T<:Tuple,S} + ntuple(Val(length(T.parameters))) do i + Base.@_inline_meta + make_zero_immutable!(prev[i], seen) + end +end + +function make_zero_immutable!(prev::NamedTuple{a,b}, seen::S)::NamedTuple{a,b} where {a,b,S} + NamedTuple{a,b}(ntuple(Val(length(T.parameters))) do i + Base.@_inline_meta + make_zero_immutable!(prev[a[i]], seen) + end) +end + + +function make_zero_immutable!(prev::T, seen::S)::T where {T,S} + if guaranteed_const_nongen(T, nothing) + return prev + end + @assert !ismutable(prev) + + RT = Core.Typeof(prev) + @assert !Base.isabstracttype(RT) + @assert Base.isconcretetype(RT) + nf = fieldcount(RT) + + flds = Vector{Any}(undef, nf) + for i = 1:nf + if isdefined(prev, i) + xi = 
getfield(prev, i) + ST = Core.Typeof(xi) + flds[i] = if active_reg_inner(ST, (), nothing, Val(true)) == ActiveState #=justActive=# + make_zero_immutable!(xi, seen) + else + EnzymeCore.make_zero!(xi, seen) + xi + end + else + nf = i - 1 # rest of tail must be undefined values + break + end + end + ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf)::T +end + +@inline function EnzymeCore.make_zero!( + prev::Base.RefValue{T}, + seen::ST, +)::Nothing where {T<:AbstractFloat,ST} + T[] = zero(T) + nothing +end + +@inline function EnzymeCore.make_zero!( + prev::Base.RefValue{Complex{T}}, + seen::ST, +)::Nothing where {T<:AbstractFloat,ST} + T[] = zero(Complex{T}) + nothing +end + +@inline function EnzymeCore.make_zero!( + prev::Array{T,N}, + seen::ST, +)::Nothing where {T<:AbstractFloat,N,ST} + fill!(prev, zero(T)) + nothing +end + +@inline function EnzymeCore.make_zero!( + prev::Array{Complex{T},N}, + seen::ST, +)::Nothing where {T<:AbstractFloat,N,ST} + fill!(prev, zero(Complex{T})) + nothing +end + +@inline function EnzymeCore.make_zero!( + prev::Base.RefValue{T}, +)::Nothing where {T<:AbstractFloat} + EnzymeCore.make_zero!(prev, nothing) + nothing +end + +@inline function EnzymeCore.make_zero!( + prev::Base.RefValue{Complex{T}}, +)::Nothing where {T<:AbstractFloat} + EnzymeCore.make_zero!(prev, nothing) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Array{T,N})::Nothing where {T<:AbstractFloat,N} + EnzymeCore.make_zero!(prev, nothing) + nothing +end + +@inline function EnzymeCore.make_zero!( + prev::Array{Complex{T},N}, +)::Nothing where {T<:AbstractFloat,N} + EnzymeCore.make_zero!(prev, nothing) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Array{T,N}, seen::ST)::Nothing where {T,N,ST} + if guaranteed_const_nongen(T, nothing) + return + end + if in(seen, prev) + return + end + push!(seen, prev) + + for I in eachindex(prev) + if isassigned(prev, I) + pv = prev[I] + SBT = Core.Typeof(pv) + if active_reg_inner(SBT, (), 
nothing, Val(true)) == ActiveState #=justActive=# + @inbounds prev[I] = make_zero_immutable!(pv, seen) + nothing + else + EnzymeCore.make_zero!(pv, seen) + nothing + end + end + end + nothing +end + +@inline function EnzymeCore.make_zero!( + prev::Base.RefValue{T}, + seen::ST, +)::Nothing where {T,ST} + if guaranteed_const_nongen(T, nothing) + return + end + if in(seen, prev) + return + end + push!(seen, prev) + + pv = prev[] + SBT = Core.Typeof(pv) + if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + prev[] = make_zero_immutable!(pv, seen) + nothing + else + EnzymeCore.make_zero!(pv, seen) + nothing + end + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Core.Box, seen::ST)::Nothing where {ST} + pv = prev.contents + T = Core.Typeof(pv) + if guaranteed_const_nongen(T, nothing) + return + end + if in(seen, prev) + return + end + push!(seen, prev) + SBT = Core.Typeof(pv) + if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + prev.contents = EnzymeCore.make_zero_immutable!(pv, seen) + nothing + else + EnzymeCore.make_zero!(pv, seen) + nothing + end + nothing +end + +@inline function EnzymeCore.make_zero!( + prev::T, + seen::S = Base.IdSet{Any}(), +)::Nothing where {T,S} + if guaranteed_const_nongen(T, nothing) + return + end + if in(prev, seen) + return + end + @assert !Base.isabstracttype(T) + @assert Base.isconcretetype(T) + nf = fieldcount(T) + + + if nf == 0 + return + end + + push!(seen, prev) + + for i = 1:nf + if isdefined(prev, i) + xi = getfield(prev, i) + SBT = Core.Typeof(xi) + if guaranteed_const_nongen(SBT, nothing) + continue + end + if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + setfield!(prev, i, make_zero_immutable!(xi, seen)) + nothing + else + EnzymeCore.make_zero!(xi, seen) + nothing + end + end + end + return +end diff --git a/test/abi.jl b/test/abi.jl index 5acb30e04fb..7a7917553f1 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -480,6 +480,19 
@@ mulsin(x) = sin(x[1] * x[2]) @test Enzyme.autodiff(ForwardWithPrimal, () -> Enzyme.within_autodiff())[1] end +mutable struct ConstVal + x::Float64 + const y::Float64 +end + +@testset "Make Zero" begin + v = ConstVal(2.0, 3.0) + dv = make_zero(v) + @test dv isa ConstVal + @test dv.x ≈ 0.0 + @test dv.y ≈ 0.0 +end + @testset "Type inference" begin x = ones(10) @inferred autodiff(Enzyme.Reverse, abssum, Duplicated(x,x)) From 288a419d464bd423218e05e6996463d8c98bce42 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 29 Sep 2024 19:45:09 -0500 Subject: [PATCH 82/87] Fix pass manager bug which allows functions to be deleted and replaced (#1924) * Fix pass manager bug which allows functions to be deleted and replaced * fix --- Project.toml | 2 +- src/compiler/optimize.jl | 41 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 87ab945958f..b70795fca1a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.5" +version = "0.13.6" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index d11daaa0b3e..cc143ce4f29 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -2410,8 +2410,35 @@ function optimize!(mod::LLVM.Module, tm) mem_cpy_opt!(pm) always_inliner!(pm) alloc_opt_tm!(pm, tm) + LLVM.run!(pm, mod) + end + + # Globalopt is separated as it can delete functions, which invalidates the Julia hardcoded pointers to + # known functions + ModulePassManager() do pm + + add_library_info!(pm, triple(mod)) + add_transform_info!(pm, tm) + + scoped_no_alias_aa!(pm) + type_based_alias_analysis!(pm) + basic_alias_analysis!(pm) + cpu_features_tm!(pm, tm) + LLVM.API.LLVMAddGlobalOptimizerPass(pm) # Extra gvn!(pm) # Extra + LLVM.run!(pm, mod) + end + + ModulePassManager() do pm + 
add_library_info!(pm, triple(mod)) + add_transform_info!(pm, tm) + + scoped_no_alias_aa!(pm) + type_based_alias_analysis!(pm) + basic_alias_analysis!(pm) + cpu_features_tm!(pm, tm) + instruction_combining!(pm) jl_inst_simplify!(pm) cfgsimplification!(pm) @@ -2473,6 +2500,20 @@ function optimize!(mod::LLVM.Module, tm) cfgsimplification!(pm) instruction_combining!(pm) # Extra for Enzyme jl_inst_simplify!(pm) + LLVM.run!(pm, mod) + end + + # Globalopt is separated as it can delete functions, which invalidates the Julia hardcoded pointers to + # known functions + ModulePassManager() do pm + add_library_info!(pm, triple(mod)) + add_transform_info!(pm, tm) + + scoped_no_alias_aa!(pm) + type_based_alias_analysis!(pm) + basic_alias_analysis!(pm) + cpu_features_tm!(pm, tm) + LLVM.API.LLVMAddGlobalOptimizerPass(pm) # Exxtra gvn!(pm) # Exxtra LLVM.run!(pm, mod) From dad67bfc3913f4eb66126d7c186a59b0c1f18586 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 29 Sep 2024 22:39:20 -0500 Subject: [PATCH 83/87] Union member type info (#1927) * Union member type info * fix * fix --- src/typetree.jl | 7 ++++++- test/typetree.jl | 35 +++++++++++++++++++++++++---------- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/src/typetree.jl b/src/typetree.jl index c886c683ced..61d700acb8c 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -353,12 +353,17 @@ function typetree_inner(@nospecialize(T::Type), ctx, dl, seen::TypeTreeTable) for f = 1:fieldcount(T) offset = fieldoffset(T, f) subT = fieldtype(T, f) - subtree = copy(typetree(subT, ctx, dl, seen)) if subT isa UnionAll || subT isa Union || subT == Union{} + if !allocatedinline(subT) + subtree = TypeTree(API.DT_Pointer, offset, ctx) + merge!(tt, subtree) + end # FIXME: Handle union continue end + + subtree = copy(typetree(subT, ctx, dl, seen)) # Allocated inline so adjust first path if allocatedinline(subT) diff --git a/test/typetree.jl b/test/typetree.jl index 1a869d66878..3b47161f621 100644 --- a/test/typetree.jl 
+++ b/test/typetree.jl @@ -37,6 +37,12 @@ struct Sibling2{T} b::T end +struct UnionMember + a::Float32 + b::Union{Function, Number} + c::Bool +end + @testset "TypeTree" begin @test tt(Float16) == "{[-1]:Float@half}" @test tt(Float32) == "{[-1]:Float@float}" @@ -55,28 +61,31 @@ end @test at2.z == 0.0 @test at2.type == 4 + if Sys.WORD_SIZE == 64 - @test tt(LList2{Float64}) == "{[8]:Float@double}" - @test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,8]:Float@double}" + @test tt(UnionMember) == "{[0]:Float@float, [8]:Pointer, [16]:Integer}" + @test tt(LList2{Float64}) == "{[0]:Pointer, [8]:Float@double}" + @test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,0]:Pointer, [-1,8]:Float@double}" @test tt(Sibling2{LList2{Float64}}) == - "{[0]:Pointer, [0,8]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@double}" + "{[0]:Pointer, [0,0]:Pointer, [0,8]:Float@double, [8]:Integer, [16]:Pointer, [16,0]:Pointer, [16,8]:Float@double}" @test tt(Sibling{Tuple{Int,Float64}}) == "{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Integer, [5]:Integer, [6]:Integer, [7]:Integer, [8]:Float@double, [16]:Integer, [17]:Integer, [18]:Integer, [19]:Integer, [20]:Integer, [21]:Integer, [22]:Integer, [23]:Integer, [24]:Float@double}" @test tt(Sibling{LList2{Tuple{Int,Float64}}}) == - "{[-1]:Pointer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Float@double}" + "{[-1]:Pointer, [-1,0]:Pointer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Float@double}" @test tt(Sibling2{Sibling2{LList2{Tuple{Float32,Float64}}}}) == - "{[0]:Pointer, [0,8]:Float@float, [0,16]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@float, [16,16]:Float@double, [24]:Integer, [32]:Pointer, [32,8]:Float@float, [32,16]:Float@double, [40]:Integer, [48]:Pointer, [48,8]:Float@float, [48,16]:Float@double}" + 
"{[0]:Pointer, [0,0]:Pointer, [0,8]:Float@float, [0,16]:Float@double, [8]:Integer, [16]:Pointer, [16,0]:Pointer, [16,8]:Float@float, [16,16]:Float@double, [24]:Integer, [32]:Pointer, [32,0]:Pointer, [32,8]:Float@float, [32,16]:Float@double, [40]:Integer, [48]:Pointer, [48,0]:Pointer, [48,8]:Float@float, [48,16]:Float@double}" else - @test tt(LList2{Float64}) == "{[4]:Float@double}" - @test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,4]:Float@double}" + @test tt(UnionMember) == "{[0]:Float@float, [4]:Pointer, [8]:Integer}" + @test tt(LList2{Float64}) == "{[0]:Pointer, [4]:Float@double}" + @test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,0]:Pointer, [-1,4]:Float@double}" @test tt(Sibling2{LList2{Float64}}) == - "{[0]:Pointer, [0,4]:Float@double, [4]:Integer, [8]:Pointer, [8,4]:Float@double}" + "{[0]:Pointer, [0,0]:Pointer, [0,4]:Float@double, [4]:Integer, [8]:Pointer, [8,0]:Pointer, [8,4]:Float@double}" @test tt(Sibling{Tuple{Int,Float64}}) == "{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Float@double, [12]:Integer, [13]:Integer, [14]:Integer, [15]:Integer, [16]:Float@double}" @test tt(Sibling{LList2{Tuple{Int,Float64}}}) == - "{[-1]:Pointer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Float@double}" + "{[-1]:Pointer, [-1,0]:Pointer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Float@double}" @test tt(Sibling2{Sibling2{LList2{Tuple{Float32,Float64}}}}) == - "{[0]:Pointer, [0,4]:Float@float, [0,8]:Float@double, [4]:Integer, [8]:Pointer, [8,4]:Float@float, [8,8]:Float@double, [12]:Integer, [16]:Pointer, [16,4]:Float@float, [16,8]:Float@double, [20]:Integer, [24]:Pointer, [24,4]:Float@float, [24,8]:Float@double}" + "{[0]:Pointer, [0,0]:Pointer, [0,4]:Float@float, [0,8]:Float@double, [4]:Integer, [8]:Pointer, [8,0]:Pointer, [8,4]:Float@float, [8,8]:Float@double, [12]:Integer, [16]:Pointer, [16,0]:Pointer, [16,4]:Float@float, [16,8]:Float@double, [20]:Integer, [24]:Pointer, [24,0]:Pointer, 
[24,4]:Float@float, [24,8]:Float@double}" end end @@ -91,4 +100,10 @@ end @test Enzyme.get_offsets(Ptr{Float32}) == ((Enzyme.API.DT_Pointer,0),) @test Enzyme.get_offsets(Vector{Float32}) == ((Enzyme.API.DT_Pointer,0),) @test Enzyme.get_offsets(Tuple{Float64, Int}) == [(Enzyme.API.DT_Double,0),(Enzyme.API.DT_Integer, 8)] + + if Sys.WORD_SIZE == 64 + @test Enzyme.get_offsets(UnionMember) == [(Enzyme.API.DT_Float,0),(Enzyme.API.DT_Pointer, 8), (Enzyme.API.DT_Integer, 16)] + else + @test Enzyme.get_offsets(UnionMember) == [(Enzyme.API.DT_Float, 0), (Enzyme.API.DT_Pointer, 4), (Enzyme.API.DT_Integer, 8)] + end end From d97bb83cdfdf55bf0abdbcbdfc6cf61d7f062d01 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 29 Sep 2024 22:39:31 -0500 Subject: [PATCH 84/87] Stabilize global (#1928) --- src/rules/jitrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index f169ab2c4bf..bf98aaf8854 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -458,7 +458,7 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) end internal_tape, origRet, initShadow, annotation = if f isa typeof(Core.getglobal) - gv = Core.getglobal(map(x->x.val, args)...) + gv = Core.getglobal(args[1].val, args[2].val) @assert sizeof(gv) == 0 (nothing, gv, nothing, Const) else From 66ef0f3566fae94f6af8da6f91dae7e36179cba6 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 30 Sep 2024 13:37:50 -0500 Subject: [PATCH 85/87] Fix error exception (#1931) * Fix error exception * fix --------- Co-authored-by: William Moses --- src/Enzyme.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index aa018ea23b8..17a7c6ff5d1 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -682,15 +682,19 @@ code, as well as high-order differentiation. 
if A isa UnionAll rt = Compiler.primal_return_type(rmode, Val(world), FTy, tt) - A2 = A{rt} + rt = Core.Compiler.return_type(f.val, tt) + A2 = A{rt} + if rt == Union{} + throw(ErrorException("Return type inferred to be Union{}. Giving up.")) + end else @assert A isa DataType rt = A + if rt == Union{} + throw(ErrorException("Return type inferred to be Union{}. Giving up.")) + end end - if rt == Union{} - error("Return type inferred to be Union{}. Giving up.") - end ModifiedBetweenT = falses_from_args(Nargs + 1) ModifiedBetween = Val(ModifiedBetweenT) From 1bc2ce18f0999740afc1b8f409ff370bc1b34dc4 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 30 Sep 2024 13:46:14 -0500 Subject: [PATCH 86/87] Update Project.toml (#1932) --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b70795fca1a..f4a342a6d5a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.6" +version = "0.13.7" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From 3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e Mon Sep 17 00:00:00 2001 From: Vaibhav Kumar Dixit Date: Wed, 2 Oct 2024 13:40:25 -0400 Subject: [PATCH 87/87] Update DuplicatedNoNeed error message (#1933) * Update DuplicatedNoNeed error message * Update src/Enzyme.jl --- src/Enzyme.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 17a7c6ff5d1..2e7789d6608 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -600,7 +600,7 @@ f(x) = x*x if A <: DuplicatedNoNeed || A <: BatchDuplicatedNoNeed throw( ErrorException( - "Return activity `DuplicatedNoNeed` is no longer now returning or avoiding the primal is passed in for Forward Mode AD.\nPlease use autodiff(Forward, ...) 
or autodiff(ForwardWithPrimal, ...)", + "`DuplicatedNoNeed` passed in as return activity for Forward Mode AD is no longer returning or avoiding the primal.\nPlease use autodiff(Forward, ...) or autodiff(ForwardWithPrimal, ...)", ), ) end