Skip to content

Commit

Permalink
Merge pull request #249 from dingraha/mat_index_bug
Browse files Browse the repository at this point in the history
Mat index bug
  • Loading branch information
ChrisRackauckas authored Jan 31, 2025
2 parents a826dfb + e1024d3 commit 862e05c
Show file tree
Hide file tree
Showing 9 changed files with 196 additions and 22 deletions.
2 changes: 1 addition & 1 deletion docs/src/indexing_behavior.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ ComponentVector{Int64}(b = [4, 1], c = (a = 2, b = [6, 30]))
But what if our range doesn't capture a full component? We can see below that using `KeepIndex` on the first five elements returns a `ComponentVector` with those elements but only the `a` and `b` names, since the `c` component wasn't fully captured.
```jldoctest indexing-label-retain
julia> ca[KeepIndex(1:5)]
5-element ComponentVector{Int64} with axis Axis(a = 1, b = 2:3):
5-element ComponentVector{Int64} with axis Axis(a = 1, b = ViewAxis(2:3, Shaped1DAxis((2,)))):
5
4
1
Expand Down
6 changes: 3 additions & 3 deletions docs/src/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ ComponentVector{Int64}(a = 11, b = 2, c = 3, new = 42)
Higher dimensional `ComponentArray`s can be created too, but it's a little messy at the moment. The nice thing for modeling is that dimension expansion through broadcasted operations can create higher-dimensional `ComponentArray`s automatically, so Jacobian cache arrays that are created internally with `false .* x .* x'` will be `ComponentArray`s with proper axes. Check out the [ODE with Jacobian](https://github.com/SciML/ComponentArrays.jl/blob/master/examples/ODE_jac_example.jl) example in the examples folder to see how this looks in practice.
```jldoctest quickstart
julia> x2 = x .* x'
7×7 ComponentMatrix{Float64} with axes Axis(a = 1, b = 2:4, c = ViewAxis(5:7, Axis(a = 1, b = 2:3))) × Axis(a = 1, b = 2:4, c = ViewAxis(5:7, Axis(a = 1, b = 2:3)))
7×7 ComponentMatrix{Float64} with axes Axis(a = 1, b = ViewAxis(2:4, Shaped1DAxis((3,))), c = ViewAxis(5:7, Axis(a = 1, b = ViewAxis(2:3, Shaped1DAxis((2,)))))) × Axis(a = 1, b = ViewAxis(2:4, Shaped1DAxis((3,))), c = ViewAxis(5:7, Axis(a = 1, b = ViewAxis(2:3, Shaped1DAxis((2,))))))
1.0 2.0 1.0 4.0 400.0 1.0 2.0
2.0 4.0 2.0 8.0 800.0 2.0 4.0
1.0 2.0 1.0 4.0 400.0 1.0 2.0
Expand All @@ -54,7 +54,7 @@ julia> x2 = x .* x'
2.0 4.0 2.0 8.0 800.0 2.0 4.0
julia> x2[:c,:c]
3×3 ComponentMatrix{Float64} with axes Axis(a = 1, b = 2:3) × Axis(a = 1, b = 2:3)
3×3 ComponentMatrix{Float64} with axes Axis(a = 1, b = ViewAxis(2:3, Shaped1DAxis((2,)))) × Axis(a = 1, b = ViewAxis(2:3, Shaped1DAxis((2,))))
160000.0 400.0 800.0
400.0 1.0 2.0
800.0 2.0 4.0
Expand All @@ -66,7 +66,7 @@ julia> x2[:a,:c]
ComponentVector{Float64}(a = 400.0, b = [1.0, 2.0])
julia> x2[:b,:c]
3×3 ComponentMatrix{Float64} with axes FlatAxis() × Axis(a = 1, b = 2:3)
3×3 ComponentMatrix{Float64} with axes Shaped1DAxis((3,)) × Axis(a = 1, b = ViewAxis(2:3, Shaped1DAxis((2,))))
800.0 2.0 4.0
400.0 1.0 2.0
1600.0 4.0 8.0
Expand Down
2 changes: 1 addition & 1 deletion src/ComponentArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export fastindices # Deprecated
include("lazyarray.jl")

include("axis.jl")
export AbstractAxis, Axis, PartitionedAxis, ShapedAxis, ViewAxis, FlatAxis
export AbstractAxis, Axis, PartitionedAxis, ShapedAxis, Shaped1DAxis, ViewAxis, FlatAxis

include("componentarray.jl")
export ComponentArray, ComponentVector, ComponentMatrix, getaxes, getdata, valkeys
Expand Down
22 changes: 18 additions & 4 deletions src/axis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,22 @@ example)
"""
struct ShapedAxis{Shape} <: AbstractAxis{nothing} end
@inline ShapedAxis(Shape) = ShapedAxis{Shape}()
ShapedAxis(::Tuple{<:Int}) = FlatAxis()
# ShapedAxis(::Tuple{<:Int}) = FlatAxis()
Base.length(::ShapedAxis{Shape}) where{Shape} = prod(Shape)

struct Shaped1DAxis{Shape} <: AbstractAxis{nothing} end
ShapedAxis(shape::Tuple{<:Int}) = Shaped1DAxis{shape}()
Shaped1DAxis(shape::Tuple{<:Int}) = Shaped1DAxis{shape}()
Base.length(::Shaped1DAxis{Shape}) where {Shape} = only(Shape)

const Shape = ShapedAxis

unshape(ax) = ax
unshape(ax::ShapedAxis) = Axis(indexmap(ax))
unshape(ax::Shaped1DAxis) = Axis(indexmap(ax))

Base.size(::ShapedAxis{Shape}) where {Shape} = Shape
Base.size(::Shaped1DAxis{Shape}) where {Shape} = Shape



Expand Down Expand Up @@ -133,9 +141,9 @@ Axis(::Number) = NullAxis()
Axis(::NamedTuple{()}) = FlatAxis()
Axis(x) = FlatAxis()

const NotShapedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis} where {IdxMap}
const NotPartitionedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis, ShapedAxis{Shape}} where {Shape, IdxMap}
const NotShapedOrPartitionedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis} where {IdxMap}
const NotShapedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis, Shaped1DAxis} where {IdxMap}
const NotPartitionedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis, ShapedAxis{Shape}, Shaped1DAxis} where {Shape, IdxMap}
const NotShapedOrPartitionedAxis = Union{Axis{IdxMap}, FlatAxis, Shaped1DAxis} where {IdxMap}


Base.merge(axs::Vararg{Axis}) = Axis(merge(indexmap.(axs)...))
Expand All @@ -149,6 +157,10 @@ reindex(i, offset) = i .+ offset
reindex(ax::FlatAxis, _) = ax
reindex(ax::Axis, offset) = Axis(map(x->reindex(x, offset), indexmap(ax)))
reindex(ax::ViewAxis, offset) = ViewAxis(viewindex(ax) .+ offset, indexmap(ax))
function reindex(ax::ViewAxis{OldInds,IdxMap,Ax}, offset) where {OldInds,IdxMap,Ax<:Shaped1DAxis}
NewInds = viewindex(ax) .+ offset
return ViewAxis(NewInds, Ax())
end

# Get AbstractAxis index
@inline Base.getindex(::AbstractAxis, idx) = ComponentIndex(idx)
Expand All @@ -175,6 +187,7 @@ end

_maybe_view_axis(inds, ax::AbstractAxis) = ViewAxis(inds, ax)
_maybe_view_axis(inds, ::NullAxis) = inds[1]
_maybe_view_axis(inds, ax::Union{ShapedAxis,Shaped1DAxis}) = ViewAxis(inds, ax)

struct CombinedAxis{C,A} <: AbstractUnitRange{Int}
component_axis::C
Expand All @@ -188,6 +201,7 @@ _component_axis(ax) = FlatAxis()

_array_axis(ax::CombinedAxis) = ax.array_axis
_array_axis(ax) = ax
_array_axis(ax::Int) = Shaped1DAxis((ax,))

Base.first(ax::CombinedAxis) = first(_array_axis(ax))

Expand Down
3 changes: 2 additions & 1 deletion src/compat/static_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ end

_maybe_SArray(x::SubArray, ::Val{N}, ::FlatAxis) where {N} = SVector{N}(x)
_maybe_SArray(x::Base.ReshapedArray, ::Val, ::ShapedAxis{Sz}) where {Sz} = SArray{Tuple{Sz...}}(x)
_maybe_SArray(x, ::Val, ::Shaped1DAxis{Sz}) where {Sz} = SArray{Tuple{Sz...}}(x)
_maybe_SArray(x, vals...) = x

@generated function static_getproperty(ca::ComponentVector, ::Val{s}) where {s}
Expand Down Expand Up @@ -32,4 +33,4 @@ macro static_unpack(expr)
push!(out.args, :($esc_name = static_getproperty($parent_var_name, $(Val(name)))))
end
return out
end
end
9 changes: 7 additions & 2 deletions src/componentarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ ComponentArray{T}(::UndefInitializer, ax::Axes) where {T,Axes<:Tuple} =

# Entry from data array and AbstractAxis types dispatches to correct shapes and partitions
# then packs up axes into a tuple for inner constructor
ComponentArray(data, ::FlatAxis...) = data
# ComponentArray(data, ::FlatAxis...) = data
ComponentArray(data, ::Union{FlatAxis,Shaped1DAxis}...) = data
ComponentArray(data, ax::NotShapedOrPartitionedAxis...) = ComponentArray(data, ax)
ComponentArray(data, ax::NotPartitionedAxis...) = ComponentArray(maybe_reshape(data, ax...), unshape.(ax)...)
function ComponentArray(data, ax::AbstractAxis...)
Expand Down Expand Up @@ -179,6 +180,10 @@ function make_idx(data, nt::Union{NamedTuple, AbstractDict}, last_val)
)...)
return (data, ViewAxis(last_index(last_val) .+ (1:len), kvs))
end
function make_idx(data, nt::NamedTuple{(), Tuple{}}, last_val)
out = last_index(last_val) .+ (1:length(nt))
return (data, ViewAxis(out, ShapedAxis((length(nt),))))
end
function make_idx(data, pair::Pair, last_val)
data, ax = make_idx(data, pair.second, last_val)
len = recursive_length(data)
Expand Down Expand Up @@ -245,7 +250,7 @@ end
# Reshape ComponentArrays with ShapedAxis axes
maybe_reshape(data, ::NotShapedOrPartitionedAxis...) = data
function maybe_reshape(data, axs::AbstractAxis...)
shapes = filter_by_type(ShapedAxis, axs...) .|> size
shapes = filter_by_type(Union{ShapedAxis,Shaped1DAxis}, axs...) .|> size
shapes = reduce((tup, s) -> (tup..., s...), shapes)
return reshape(data, shapes)
end
Expand Down
4 changes: 3 additions & 1 deletion src/componentindex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ struct ComponentIndex{Idx, Ax<:AbstractAxis}
ax::Ax
end
ComponentIndex(idx) = ComponentIndex(idx, FlatAxis())
ComponentIndex(idx::CartesianIndex) = ComponentIndex(idx, ShapedAxis((1,)))
ComponentIndex(idx::AbstractArray{<:Integer}) = ComponentIndex(idx, ShapedAxis(size(idx)))
ComponentIndex(idx::Int) = ComponentIndex(idx, NullAxis())
ComponentIndex(vax::ViewAxis{Inds,IdxMap,Ax}) where {Inds,IdxMap,Ax} = ComponentIndex(Inds, vax.ax)

Expand Down Expand Up @@ -44,4 +46,4 @@ function _getindex_keep(ax::AbstractAxis, sym::Symbol)
end
new_ax = reindex(new_ax, -first(idx)+1)
return ComponentIndex(idx, new_ax)
end
end
2 changes: 2 additions & 0 deletions src/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ Base.show(io::IO, ::PartitionedAxis{PartSz, IdxMap, Ax}) where {PartSz, IdxMap,

Base.show(io::IO, ::ShapedAxis{Shape}) where {Shape} =
print(io, "ShapedAxis($Shape)")
Base.show(io::IO, ::Shaped1DAxis{Shape}) where {Shape} =
print(io, "Shaped1DAxis($Shape)")

Base.show(io::IO, ::MIME"text/plain", ::ViewAxis{Inds, IdxMap, Ax}) where {Inds, IdxMap, Ax} =
print(io, "ViewAxis($Inds, $(Ax()))")
Expand Down
Loading

0 comments on commit 862e05c

Please sign in to comment.