diff --git a/Project.toml b/Project.toml index 9f17d977..7e27b3ae 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ComponentArrays" uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"] -version = "0.15.20" +version = "0.15.21" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -18,6 +18,7 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" @@ -28,6 +29,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" ComponentArraysGPUArraysExt = "GPUArrays" ComponentArraysKernelAbstractionsExt = "KernelAbstractions" ComponentArraysOptimisersExt = "Optimisers" +ComponentArraysReactantExt = "Reactant" ComponentArraysRecursiveArrayToolsExt = "RecursiveArrayTools" ComponentArraysReverseDiffExt = "ReverseDiff" ComponentArraysSciMLBaseExt = "SciMLBase" @@ -36,20 +38,21 @@ ComponentArraysZygoteExt = "Zygote" [compat] Adapt = "4.1" -ArrayInterface = "7.10" -ChainRulesCore = "1.24" +ArrayInterface = "7.17.1" +ChainRulesCore = "1.25" ConstructionBase = "1" ForwardDiff = "0.10.36" Functors = "0.4.12, 0.5" -GPUArrays = "10, 11" +GPUArrays = "10.3.1, 11" KernelAbstractions = "0.9.29" LinearAlgebra = "1.10" Optimisers = "0.3, 0.4" +Reactant = "0.2.15" RecursiveArrayTools = "3.8" ReverseDiff = "1.15" SciMLBase = "2" StaticArrayInterface = "1" StaticArraysCore = "1.4" -Tracker = "0.2.34" +Tracker = "0.2.37" Zygote = "0.6.70, 0.7" julia = "1.10" diff --git a/ext/ComponentArraysReactantExt.jl b/ext/ComponentArraysReactantExt.jl new file mode 100644 index 00000000..dbc55e10 --- /dev/null +++ b/ext/ComponentArraysReactantExt.jl @@ -0,0 +1,26 @@ +module ComponentArraysReactantExt + +using ArrayInterface: ArrayInterface +using ComponentArrays, Reactant + +const TracedComponentVector{T} = ComponentVector{ + Reactant.TracedRNumber{T},<:Reactant.TracedRArray{T} +} where {T} + +# Reactant is good at memory management but not great at handling wrapped types. So we avoid +# wrapping types into SubArrays and let Reactant optimize out intermediate allocations. + +@inline function Base.getproperty(x::TracedComponentVector{T}, s::Symbol) where {T} + return getproperty(x, Val(s)) +end + +@inline function Base.getproperty(x::TracedComponentVector{T}, v::Val) where {T} + return ComponentArrays._getindex(Base.getindex, x, v) +end + +function ArrayInterface.restructure(x::ComponentVector, y::TracedComponentVector) + getaxes(x) == getaxes(y) || error("Axes must match") + return y +end + +end diff --git a/src/componentarray.jl b/src/componentarray.jl index a3d3c6b1..41cda164 100644 --- a/src/componentarray.jl +++ b/src/componentarray.jl @@ -66,6 +66,8 @@ end Adapt.adapt_storage(::Type{ComponentArray{T,N,A,Ax}}, xs::AT) where {T,N,A,Ax,AT<:AbstractArray} = Adapt.adapt_storage(A, xs) +Adapt.parent_type(::Type{ComponentArray{T,N,A,Ax}}) where {T,N,A,Ax} = A + # Entry from NamedTuple, Dict, or kwargs ComponentArray{T}(nt::NamedTuple) where T = ComponentArray(make_carray_args(T, nt)...) ComponentArray{T}(::NamedTuple{(), Tuple{}}) where T = ComponentArray(T[], (FlatAxis(),)) diff --git a/test/Project.toml b/test/Project.toml index 33e5bf17..f5a2aa73 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -11,6 +11,7 @@ LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/gpu_tests.jl b/test/gpu_tests.jl index 4a345c47..498b9e3e 100644 --- a/test/gpu_tests.jl +++ b/test/gpu_tests.jl @@ -1,4 +1,4 @@ -using JLArrays +using JLArrays, LinearAlgebra JLArrays.allowscalar(false) @@ -11,7 +11,7 @@ jlca = ComponentArray(jla, Axis(a=1:2, b=3:4)) @test getdata(map(identity, jlca)) isa JLArray @test all(==(0), map(-, jlca, jla)) @test all(map(-, jlca, jlca) .== 0) - @test all(==(0), map(-, jla, jlca)) + @test all(==(0), map(-, jla, jlca)) broken=(pkgversion(JLArrays.GPUArrays) ≥ v"11") @test any(==(1), jlca) @test count(>(2), jlca) == 2 diff --git a/test/reactant_tests.jl b/test/reactant_tests.jl new file mode 100644 index 00000000..62581225 --- /dev/null +++ b/test/reactant_tests.jl @@ -0,0 +1,8 @@ +using Reactant, ComponentArrays + +x = ComponentArray(; a = rand(4), b = rand(2)) +x_ra = Reactant.to_rarray(x) + +fn(x) = x.a .+ sum(abs2, x.b) .+ 1 + +@test @jit(fn(x_ra)) ≈ fn(x) diff --git a/test/runtests.jl b/test/runtests.jl index b4c6fc0e..a46931de 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -732,3 +732,7 @@ end @testset "GPU" begin include("gpu_tests.jl") end + +@testset "Reactant" begin + include("reactant_tests.jl") +end