Skip to content

Commit

Permalink
fix: force unthunk in rrule
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 11, 2025
1 parent dd0e4c0 commit 506efe9
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ComponentArrays"
uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"]
version = "0.15.21"
version = "0.15.22"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
8 changes: 5 additions & 3 deletions src/compat/chainrulescore.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Union{Symbol,Val})
return getproperty(x, s), Δ -> getproperty_adjoint(Δ, x, s)
return getproperty(x, s), Δ -> getproperty_adjoint(ChainRulesCore.unthunk(Δ), x, s)
end

function getproperty_adjoint(Δ, x, s)
Expand Down Expand Up @@ -28,9 +28,9 @@ function ChainRulesCore.rrule(cfg::ChainRulesCore.RuleConfig{>:ChainRulesCore.Ha
return y_, pb_f
end

ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ -> (ChainRulesCore.NoTangent(), ComponentArray(Δ, getaxes(x)))
ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ -> (ChainRulesCore.NoTangent(), ComponentArray(ChainRulesCore.unthunk(Δ), getaxes(x)))

ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ -> (ChainRulesCore.NoTangent(), getdata(Δ), ChainRulesCore.NoTangent())
ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ -> (ChainRulesCore.NoTangent(), getdata(ChainRulesCore.unthunk(Δ)), ChainRulesCore.NoTangent())

function ChainRulesCore.ProjectTo(ca::ComponentArray)
return ChainRulesCore.ProjectTo{ComponentArray}(; project=ChainRulesCore.ProjectTo(getdata(ca)), axes=getaxes(ca))
Expand All @@ -49,6 +49,8 @@ end
function ChainRulesCore.rrule(::Type{CA}, nt::NamedTuple) where {CA<:ComponentArray}
y = CA(nt)

∇NamedTupleToComponentArray(Δ) = ∇NamedTupleToComponentArray(ChainRulesCore.unthunk(Δ))

function ∇NamedTupleToComponentArray::AbstractArray)
if length(Δ) == length(y)
return ∇NamedTupleToComponentArray(ComponentArray(vec(Δ), getaxes(y)))
Expand Down

0 comments on commit 506efe9

Please sign in to comment.