Skip to content

Commit

Permalink
Make flatten work on tables too (#217)
Browse files Browse the repository at this point in the history
* Make flatten work on tables too

Add a test for flatten

Fix eltype from table

* Clarify what's going on

Co-authored-by: Rafael Schouten <rafaelschouten@gmail.com>

---------

Co-authored-by: Rafael Schouten <rafaelschouten@gmail.com>
  • Loading branch information
asinghvi17 and rafaqz authored Sep 23, 2024
1 parent c4c3a29 commit 17151d9
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/primitives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -473,8 +473,14 @@ flatten(f, ::Type{Target}, geom) where {Target<:GI.AbstractTrait} = _flatten(f,

_flatten(f, ::Type{Target}, geom) where Target = _flatten(f, Target, GI.trait(geom), geom)
# Try to flatten over iterables
_flatten(f, ::Type{Target}, ::Nothing, iterable) where Target =
Iterators.flatten(Iterators.map(x -> _flatten(f, Target, x), iterable))
function _flatten(f, ::Type{Target}, ::Nothing, iterable) where Target
if Tables.istable(iterable)
column = Tables.getcolumn(iterable, first(GI.geometrycolumns(iterable)))
Iterators.map(x -> _flatten(f, Target, x), column) |> Iterators.flatten
else
Iterators.map(x -> _flatten(f, Target, x), iterable) |> Iterators.flatten
end
end
# Flatten feature collections
function _flatten(f, ::Type{Target}, ::GI.FeatureCollectionTrait, fc) where Target
Iterators.map(GI.getfeature(fc)) do feature
Expand Down
19 changes: 19 additions & 0 deletions test/primitives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,25 @@ end
@test GO._tuple_point.(GO.flatten(GI.PointTrait, very_wrapped)) == vcat(pv1, pv2)
@test collect(GO.flatten(GI.AbstractCurveTrait, [poly])) == [lr1, lr2]
@test collect(GO.flatten(GI.x, GI.PointTrait, very_wrapped)) == first.(vcat(pv1, pv2))
@testset "flatten with tables" begin
# Construct a simple table with a geometry column
geom_column = [GI.Point(1.0,1.0), GI.Point(2.0,2.0), GI.Point(3.0,3.0)]
table = (geometry = geom_column, id = [1, 2, 3])

# Test flatten on the table
flattened = collect(GO.flatten(GI.PointTrait, table))

@test length(flattened) == 3
@test all(p isa GI.Point for p in flattened)
@test flattened == geom_column

# Test flatten with a function
flattened_coords = collect(GO.flatten(p -> (GI.x(p), GI.y(p)), GI.PointTrait, table))

@test length(flattened_coords) == 3
@test all(c isa Tuple{Float64,Float64} for c in flattened_coords)
@test flattened_coords == [(1.0,1.0), (2.0,2.0), (3.0,3.0)]
end
end

@testset "reconstruct" begin
Expand Down

0 comments on commit 17151d9

Please sign in to comment.