Skip to content

Commit

Permalink
Merge pull request #292 from SciML/ap/reactant
Browse files Browse the repository at this point in the history
feat: reactant support
  • Loading branch information
ChrisRackauckas authored Jan 6, 2025
2 parents 105eeaf + 21890c3 commit dd0e4c0
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 7 deletions.
13 changes: 8 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ComponentArrays"
uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"]
version = "0.15.20"
version = "0.15.21"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -18,6 +18,7 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Expand All @@ -28,6 +29,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ComponentArraysGPUArraysExt = "GPUArrays"
ComponentArraysKernelAbstractionsExt = "KernelAbstractions"
ComponentArraysOptimisersExt = "Optimisers"
ComponentArraysReactantExt = "Reactant"
ComponentArraysRecursiveArrayToolsExt = "RecursiveArrayTools"
ComponentArraysReverseDiffExt = "ReverseDiff"
ComponentArraysSciMLBaseExt = "SciMLBase"
Expand All @@ -36,20 +38,21 @@ ComponentArraysZygoteExt = "Zygote"

[compat]
Adapt = "4.1"
ArrayInterface = "7.10"
ChainRulesCore = "1.24"
ArrayInterface = "7.17.1"
ChainRulesCore = "1.25"
ConstructionBase = "1"
ForwardDiff = "0.10.36"
Functors = "0.4.12, 0.5"
GPUArrays = "10, 11"
GPUArrays = "10.3.1, 11"
KernelAbstractions = "0.9.29"
LinearAlgebra = "1.10"
Optimisers = "0.3, 0.4"
Reactant = "0.2.15"
RecursiveArrayTools = "3.8"
ReverseDiff = "1.15"
SciMLBase = "2"
StaticArrayInterface = "1"
StaticArraysCore = "1.4"
Tracker = "0.2.34"
Tracker = "0.2.37"
Zygote = "0.6.70, 0.7"
julia = "1.10"
26 changes: 26 additions & 0 deletions ext/ComponentArraysReactantExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
module ComponentArraysReactantExt

using ArrayInterface: ArrayInterface
using ComponentArrays, Reactant

const TracedComponentVector{T} = ComponentVector{
Reactant.TracedRNumber{T},<:Reactant.TracedRArray{T}
} where {T}

# Reactant is good at memory management but not great at handling wrapped types. So we avoid
# wrapping types into SubArrays and let Reactant optimize out intermediate allocations.

@inline function Base.getproperty(x::TracedComponentVector{T}, s::Symbol) where {T}
return getproperty(x, Val(s))
end

@inline function Base.getproperty(x::TracedComponentVector{T}, v::Val) where {T}
return ComponentArrays._getindex(Base.getindex, x, v)
end

function ArrayInterface.restructure(x::ComponentVector, y::TracedComponentVector)
getaxes(x) == getaxes(y) || error("Axes must match")
return y
end

end
2 changes: 2 additions & 0 deletions src/componentarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ end
Adapt.adapt_storage(::Type{ComponentArray{T,N,A,Ax}}, xs::AT) where {T,N,A,Ax,AT<:AbstractArray} =
Adapt.adapt_storage(A, xs)

Adapt.parent_type(::Type{ComponentArray{T,N,A,Ax}}) where {T,N,A,Ax} = A

# Entry from NamedTuple, Dict, or kwargs
ComponentArray{T}(nt::NamedTuple) where T = ComponentArray(make_carray_args(T, nt)...)
ComponentArray{T}(::NamedTuple{(), Tuple{}}) where T = ComponentArray(T[], (FlatAxis(),))
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
4 changes: 2 additions & 2 deletions test/gpu_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using JLArrays
using JLArrays, LinearAlgebra

JLArrays.allowscalar(false)

Expand All @@ -11,7 +11,7 @@ jlca = ComponentArray(jla, Axis(a=1:2, b=3:4))
@test getdata(map(identity, jlca)) isa JLArray
@test all(==(0), map(-, jlca, jla))
@test all(map(-, jlca, jlca) .== 0)
@test all(==(0), map(-, jla, jlca))
@test all(==(0), map(-, jla, jlca)) broken=(pkgversion(JLArrays.GPUArrays) v"11")

@test any(==(1), jlca)
@test count(>(2), jlca) == 2
Expand Down
8 changes: 8 additions & 0 deletions test/reactant_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using Reactant, ComponentArrays

x = ComponentArray(; a = rand(4), b = rand(2))
x_ra = Reactant.to_rarray(x)

fn(x) = x.a .+ sum(abs2, x.b) .+ 1

@test @jit(fn(x_ra)) fn(x)
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -732,3 +732,7 @@ end
@testset "GPU" begin
include("gpu_tests.jl")
end

@testset "Reactant" begin
include("reactant_tests.jl")
end

0 comments on commit dd0e4c0

Please sign in to comment.