Skip to content

Commit

Permalink
Fix GeoDataFrame Caching & Override Behavior (#611)
Browse files Browse the repository at this point in the history
* fix geodataframe cache

* simplify logic
  • Loading branch information
philipc2 authored Dec 7, 2023
1 parent 8e2ac94 commit 0135493
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 20 deletions.
34 changes: 33 additions & 1 deletion test/test_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,4 +390,36 @@ class TestGeoDataFrame(TestCase):
def test_to_gdf(self):
uxgrid = ux.open_grid(gridfile_geoflow)

gdf_with = uxgrid.to_geodataframe(exclude_antimeridian=False)
gdf_with_am = uxgrid.to_geodataframe(exclude_antimeridian=False)

gdf_without_am = uxgrid.to_geodataframe(exclude_antimeridian=True)

def test_cache_and_override(self):
"""Tests the cache and override functionality for GeoDataFrame
conversion."""

uxgrid = ux.open_grid(gridfile_geoflow)

gdf_a = uxgrid.to_geodataframe(exclude_antimeridian=False)

gdf_b = uxgrid.to_geodataframe(exclude_antimeridian=False)

assert gdf_a is gdf_b

gdf_c = uxgrid.to_geodataframe(exclude_antimeridian=True)

assert gdf_a is not gdf_c

gdf_d = uxgrid.to_geodataframe(exclude_antimeridian=True)

assert gdf_d is gdf_c

gdf_e = uxgrid.to_geodataframe(exclude_antimeridian=True,
override=True,
cache=False)

assert gdf_d is not gdf_e

gdf_f = uxgrid.to_geodataframe(exclude_antimeridian=True)

assert gdf_f is not gdf_e
9 changes: 0 additions & 9 deletions uxarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,6 @@ def to_geodataframe(self,
f"for face-centered data.")

if self.values.size == self.uxgrid.n_face:
# face-centered data
if self.uxgrid._gdf is not None:
# determine if we need to re-compute the cached GeoDataFrame
if exclude_antimeridian and len(
self.uxgrid.antimeridian_face_indices) != len(
self.uxgrid._gdf):
override = True
elif self.uxgrid.n_face != len(self.uxgrid._gdf):
override = True

gdf = self.uxgrid.to_geodataframe(
override=override,
Expand Down
14 changes: 6 additions & 8 deletions uxarray/grid/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(self,

# initialize cached data structures (visualization)
self._gdf = None
self._gdf_exclude_am = None
self._poly_collection = None
self._line_collection = None
self._centroid_points_df_proj = [None, None]
Expand Down Expand Up @@ -953,14 +954,10 @@ def to_geodataframe(self,
"""

if self._gdf is not None:
# determine if we need to recompute a cached GeoDataFrame
if exclude_antimeridian:
if len(self._gdf) != self.n_face - len(
self.antimeridian_face_indices):
override = True
elif not exclude_antimeridian:
if len(self._gdf) != self.n_face:
override = True
# determine if we need to recompute a cached GeoDataFrame based on antimeridian
if self._gdf_exclude_am != exclude_antimeridian:
# cached gdf should match the exclude_antimeridian_flag
override = True

# use cached geodataframe
if self._gdf is not None and not override:
Expand All @@ -973,6 +970,7 @@ def to_geodataframe(self,
# cache computed geodataframe
if cache:
self._gdf = gdf
self._gdf_exclude_am = exclude_antimeridian

return gdf

Expand Down
6 changes: 6 additions & 0 deletions uxarray/plot/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ def rasterize(self,
interpolation: Optional[str] = "linear",
npartitions: Optional[int] = 1,
cache: Optional[bool] = True,
override: Optional[bool] = False,
size: Optional[int] = 5,
**kwargs):
"""Raster plot of a data variable residing on an unstructured grid
Expand Down Expand Up @@ -348,6 +349,7 @@ def rasterize(self,
interpolation=interpolation,
npartitions=npartitions,
cache=cache,
override=override,
size=size,
**kwargs)

Expand All @@ -359,6 +361,8 @@ def polygons(self,
height: Optional[int] = 500,
colorbar: Optional[bool] = True,
cmap: Optional[str] = "Blues",
cache: Optional[bool] = True,
override: Optional[bool] = False,
**kwargs):
"""Vector polygon plot shaded using a face-centered data variable.
Expand All @@ -382,6 +386,8 @@ def polygons(self,
height=height,
colorbar=colorbar,
cmap=cmap,
cache=cache,
override=override,
**kwargs)

@functools.wraps(dataarray_plot.points)
Expand Down
15 changes: 13 additions & 2 deletions uxarray/plot/dataarray_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def rasterize(uxda: UxDataArray,
interpolation: Optional[str] = "linear",
npartitions: Optional[int] = 1,
cache: Optional[bool] = True,
override: Optional[bool] = False,
size: Optional[int] = 5,
**kwargs):
"""Rasterized Plot of a Data Variable Residing on an Unstructured Grid.
Expand Down Expand Up @@ -169,6 +170,8 @@ def rasterize(uxda: UxDataArray,
aggregator=aggregator,
interpolation=interpolation,
pixel_ratio=pixel_ratio,
cache=cache,
override=override,
**kwargs)
else:
raise ValueError(f"Unsupported method: {method}.")
Expand Down Expand Up @@ -318,6 +321,8 @@ def _polygon_raster(uxda: UxDataArray,
interpolation: Optional[str] = "linear",
xlabel: Optional[str] = "Longitude",
ylabel: Optional[str] = "Latitude",
cache: Optional[bool] = True,
override: Optional[bool] = False,
**kwargs):
"""Implementation of Polygon Rasterization."""

Expand All @@ -327,7 +332,9 @@ def _polygon_raster(uxda: UxDataArray,
else:
clabel = kwargs.get("clabel")

gdf = uxda.to_geodataframe(exclude_antimeridian=exclude_antimeridian)
gdf = uxda.to_geodataframe(exclude_antimeridian=exclude_antimeridian,
cache=cache,
override=override)

hv_polygons = hv.Polygons(gdf, vdims=[uxda.name])

Expand Down Expand Up @@ -382,6 +389,8 @@ def polygons(uxda: UxDataArray,
cmap: Optional[str] = "Blues",
xlabel: Optional[str] = "Longitude",
ylabel: Optional[str] = "Latitude",
cache: Optional[bool] = True,
override: Optional[bool] = False,
**kwargs):
"""Vector Polygon Plot of a Data Variable Residing on an Unstructured Grid.
Expand All @@ -407,7 +416,9 @@ def polygons(uxda: UxDataArray,
else:
clabel = kwargs.get("clabel")

gdf = uxda.to_geodataframe(exclude_antimeridian=exclude_antimeridian)
gdf = uxda.to_geodataframe(exclude_antimeridian=exclude_antimeridian,
cache=cache,
override=override)

hv_polygons = hv.Polygons(gdf, vdims=[uxda.name])

Expand Down

0 comments on commit 0135493

Please sign in to comment.