Skip to content

Commit

Permalink
feat: reactant support
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 5, 2025
1 parent 882df4d commit 8cae035
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 1 deletion.
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ComponentArrays"
uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"]
version = "0.15.20"
version = "0.15.21"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -18,6 +18,7 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Expand All @@ -28,6 +29,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ComponentArraysGPUArraysExt = "GPUArrays"
ComponentArraysKernelAbstractionsExt = "KernelAbstractions"
ComponentArraysOptimisersExt = "Optimisers"
ComponentArraysReactantExt = "Reactant"
ComponentArraysRecursiveArrayToolsExt = "RecursiveArrayTools"
ComponentArraysReverseDiffExt = "ReverseDiff"
ComponentArraysSciMLBaseExt = "SciMLBase"
Expand All @@ -45,6 +47,7 @@ GPUArrays = "10, 11"
KernelAbstractions = "0.9.29"
LinearAlgebra = "1.10"
Optimisers = "0.3, 0.4"
Reactant = "0.2.14"
RecursiveArrayTools = "3.8"
ReverseDiff = "1.15"
SciMLBase = "2"
Expand Down
26 changes: 26 additions & 0 deletions ext/ComponentArraysReactantExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
module ComponentArraysReactantExt

using ArrayInterface: ArrayInterface
using ComponentArrays, Reactant

const TracedComponentVector{T} = ComponentVector{
Reactant.TracedRNumber{T},<:Reactant.TracedRArray{T}
} where {T}

# Reactant is good at memory management but not great at handling wrapped types. So we avoid
# wrapping types into SubArrays and let Reactant optimize out intermediate allocations.

@inline function Base.getproperty(x::TracedComponentVector{T}, s::Symbol) where {T}
return getproperty(x, Val(s))
end

@inline function Base.getproperty(x::TracedComponentVector{T}, v::Val) where {T}
return ComponentArrays._getindex(Base.getindex, x, v)
end

function ArrayInterface.restructure(x::ComponentVector, y::TracedComponentVector)
getaxes(x) == getaxes(y) || error("Axes must match")
return y
end

end
2 changes: 2 additions & 0 deletions src/componentarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ end
Adapt.adapt_storage(::Type{ComponentArray{T,N,A,Ax}}, xs::AT) where {T,N,A,Ax,AT<:AbstractArray} =
Adapt.adapt_storage(A, xs)

Adapt.parent_type(::Type{ComponentArray{T,N,A,Ax}}) where {T,N,A,Ax} = A

# Entry from NamedTuple, Dict, or kwargs
ComponentArray{T}(nt::NamedTuple) where T = ComponentArray(make_carray_args(T, nt)...)
ComponentArray{T}(::NamedTuple{(), Tuple{}}) where T = ComponentArray(T[], (FlatAxis(),))
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
8 changes: 8 additions & 0 deletions test/reactant_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using Reactant, ComponentArrays

x = ComponentArray(; a = rand(4), b = rand(2))
x_ra = Reactant.to_rarray(x)

fn(x) = x.a .+ sum(abs2, x.b) .+ 1

@test @jit(fn(x_ra)) fn(x)
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -732,3 +732,7 @@ end
@testset "GPU" begin
include("gpu_tests.jl")
end

@testset "Reactant" begin
include("reactant_tests.jl")
end

0 comments on commit 8cae035

Please sign in to comment.