diff --git a/Project.toml b/Project.toml index 7e27b3ae..cab778b8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ComponentArrays" uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"] -version = "0.15.21" +version = "0.15.22" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/compat/chainrulescore.jl b/src/compat/chainrulescore.jl index 9f9211bd..bd07814f 100644 --- a/src/compat/chainrulescore.jl +++ b/src/compat/chainrulescore.jl @@ -1,5 +1,5 @@ function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Union{Symbol,Val}) - return getproperty(x, s), Δ -> getproperty_adjoint(Δ, x, s) + return getproperty(x, s), Δ -> getproperty_adjoint(ChainRulesCore.unthunk(Δ), x, s) end function getproperty_adjoint(Δ, x, s) @@ -28,9 +28,9 @@ function ChainRulesCore.rrule(cfg::ChainRulesCore.RuleConfig{>:ChainRulesCore.Ha return y_, pb_f end -ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ -> (ChainRulesCore.NoTangent(), ComponentArray(Δ, getaxes(x))) +ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ -> (ChainRulesCore.NoTangent(), ComponentArray(ChainRulesCore.unthunk(Δ), getaxes(x))) -ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ -> (ChainRulesCore.NoTangent(), getdata(Δ), ChainRulesCore.NoTangent()) +ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ -> (ChainRulesCore.NoTangent(), getdata(ChainRulesCore.unthunk(Δ)), ChainRulesCore.NoTangent()) function ChainRulesCore.ProjectTo(ca::ComponentArray) return ChainRulesCore.ProjectTo{ComponentArray}(; project=ChainRulesCore.ProjectTo(getdata(ca)), axes=getaxes(ca)) @@ -49,6 +49,8 @@ end function ChainRulesCore.rrule(::Type{CA}, nt::NamedTuple) where {CA<:ComponentArray} y = CA(nt) + ∇NamedTupleToComponentArray(Δ) = ∇NamedTupleToComponentArray(ChainRulesCore.unthunk(Δ)) + function ∇NamedTupleToComponentArray(Δ::AbstractArray) if length(Δ) == length(y) return ∇NamedTupleToComponentArray(ComponentArray(vec(Δ), getaxes(y)))