Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Inverse Face Indices to Subsetted Grids #1122

Merged
merged 20 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions test/test_subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,19 +104,40 @@ def test_grid_bounding_box_subset():
bbox_antimeridian[0], bbox_antimeridian[1], element=element)




def test_uxda_isel():
uxds = ux.open_dataset(GRID_PATHS[0], DATA_PATHS[0])

sub = uxds['bottomDepth'].isel(n_face=[1, 2, 3])

assert len(sub) == 3


def test_uxda_isel_with_coords():
uxds = ux.open_dataset(GRID_PATHS[0], DATA_PATHS[0])
uxds = uxds.assign_coords({"lon_face": uxds.uxgrid.face_lon})
sub = uxds['bottomDepth'].isel(n_face=[1, 2, 3])

assert "lon_face" in sub.coords
assert len(sub.coords['lon_face']) == 3


def test_inverse_face_indices():
grid = ux.open_grid(GRID_PATHS[0])

# Test nearest neighbor subsetting
coord = [0, 0]
subset = grid.subset.nearest_neighbor(coord, k=1, element="face centers", inverse_indices=True)

assert subset.inverse_face_indices is not None

# Test bounding box subsetting
box = [(-10, 10), (-10, 10)]
subset = grid.subset.bounding_box(box[0], box[1], element="edge centers", inverse_indices=True)

assert subset.inverse_face_indices is not None

# Test bounding circle subsetting
center_coord = [0, 0]
subset = grid.subset.bounding_circle(center_coord, r=10, element="nodes", inverse_indices=True)

assert subset.inverse_face_indices is not None
14 changes: 10 additions & 4 deletions uxarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,7 +1035,7 @@ def _edge_centered(self) -> bool:
"n_edge" dimension)"""
return "n_edge" in self.dims

def isel(self, ignore_grid=False, *args, **kwargs):
def isel(self, ignore_grid=False, inverse_indices=False, *args, **kwargs):
"""Grid-informed implementation of xarray's ``isel`` method, which
enables indexing across grid dimensions.
Expand Down Expand Up @@ -1069,11 +1069,17 @@ def isel(self, ignore_grid=False, *args, **kwargs):
raise ValueError("Only one grid dimension can be sliced at a time")

if "n_node" in kwargs:
sliced_grid = self.uxgrid.isel(n_node=kwargs["n_node"])
sliced_grid = self.uxgrid.isel(
n_node=kwargs["n_node"], inverse_indices=inverse_indices
)
elif "n_edge" in kwargs:
sliced_grid = self.uxgrid.isel(n_edge=kwargs["n_edge"])
sliced_grid = self.uxgrid.isel(
n_edge=kwargs["n_edge"], inverse_indices=inverse_indices
)
else:
sliced_grid = self.uxgrid.isel(n_face=kwargs["n_face"])
sliced_grid = self.uxgrid.isel(
n_face=kwargs["n_face"], inverse_indices=inverse_indices
)

return self._slice_from_grid(sliced_grid)

Expand Down
12 changes: 8 additions & 4 deletions uxarray/cross_sections/dataarray_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __repr__(self):

return prefix + methods_heading

def constant_latitude(self, lat: float):
def constant_latitude(self, lat: float, inverse_indices: bool = False):
"""Extracts a cross-section of the data array by selecting all faces that
intersect with a specified line of constant latitude.

Expand All @@ -31,6 +31,8 @@ def constant_latitude(self, lat: float):
lat : float
The latitude at which to extract the cross-section, in degrees.
Must be between -90.0 and 90.0
inverse_indices : bool, optional
If True, stores the original grid indices

Returns
-------
Expand Down Expand Up @@ -60,9 +62,9 @@ def constant_latitude(self, lat: float):

faces = self.uxda.uxgrid.get_faces_at_constant_latitude(lat)

return self.uxda.isel(n_face=faces)
return self.uxda.isel(n_face=faces, inverse_indices=inverse_indices)

def constant_longitude(self, lon: float):
def constant_longitude(self, lon: float, inverse_indices: bool = False):
"""Extracts a cross-section of the data array by selecting all faces that
intersect with a specified line of constant longitude.

Expand All @@ -71,6 +73,8 @@ def constant_longitude(self, lon: float):
lon : float
The latitude at which to extract the cross-section, in degrees.
Must be between -180.0 and 180.0
inverse_indices : bool, optional
If True, stores the original grid indices

Returns
-------
Expand Down Expand Up @@ -102,7 +106,7 @@ def constant_longitude(self, lon: float):
lon,
)

return self.uxda.isel(n_face=faces)
return self.uxda.isel(n_face=faces, inverse_indices=inverse_indices)

def gca(self, *args, **kwargs):
raise NotImplementedError
Expand Down
14 changes: 12 additions & 2 deletions uxarray/cross_sections/grid_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def constant_latitude(
self,
lat: float,
return_face_indices: bool = False,
inverse_indices: bool = False,
):
"""Extracts a cross-section of the grid by selecting all faces that
intersect with a specified line of constant latitude.
Expand All @@ -36,6 +37,8 @@ def constant_latitude(
Must be between -90.0 and 90.0
return_face_indices : bool, optional
If True, also returns the indices of the faces that intersect with the line of constant latitude.
inverse_indices : bool, optional
If True, stores the original grid indices

Returns
-------
Expand Down Expand Up @@ -66,7 +69,9 @@ def constant_latitude(
if len(faces) == 0:
raise ValueError(f"No intersections found at lat={lat}.")

grid_at_constant_lat = self.uxgrid.isel(n_face=faces)
grid_at_constant_lat = self.uxgrid.isel(
n_face=faces, inverse_indices=inverse_indices
)

if return_face_indices:
return grid_at_constant_lat, faces
Expand All @@ -77,6 +82,7 @@ def constant_longitude(
self,
lon: float,
return_face_indices: bool = False,
inverse_indices: bool = False,
):
"""Extracts a cross-section of the grid by selecting all faces that
intersect with a specified line of constant longitude.
Expand All @@ -88,6 +94,8 @@ def constant_longitude(
Must be between -90.0 and 90.0
return_face_indices : bool, optional
If True, also returns the indices of the faces that intersect with the line of constant longitude.
inverse_indices : bool, optional
If True, stores the original grid indices

Returns
-------
Expand Down Expand Up @@ -117,7 +125,9 @@ def constant_longitude(
if len(faces) == 0:
raise ValueError(f"No intersections found at lon={lon}")

grid_at_constant_lon = self.uxgrid.isel(n_face=faces)
grid_at_constant_lon = self.uxgrid.isel(
n_face=faces, inverse_indices=inverse_indices
)

if return_face_indices:
return grid_at_constant_lon, faces
Expand Down
34 changes: 28 additions & 6 deletions uxarray/grid/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def __init__(
grid_ds: xr.Dataset,
source_grid_spec: Optional[str] = None,
source_dims_dict: Optional[dict] = {},
is_subset=False,
philipc2 marked this conversation as resolved.
Show resolved Hide resolved
):
# check if inputted dataset is a minimum representable 2D UGRID unstructured grid
if not _validate_minimum_ugrid(grid_ds):
Expand Down Expand Up @@ -191,6 +192,7 @@ def __init__(
# initialize attributes
self._antimeridian_face_indices = None
self._ds.assign_attrs({"source_grid_spec": self.source_grid_spec})
self.is_subset = is_subset

# cached parameters for GeoDataFrame conversions
self._gdf_cached_parameters = {
Expand Down Expand Up @@ -242,7 +244,9 @@ def __init__(
cross_section = UncachedAccessor(GridCrossSectionAccessor)

@classmethod
def from_dataset(cls, dataset, use_dual: Optional[bool] = False, **kwargs):
def from_dataset(
cls, dataset, use_dual: Optional[bool] = False, is_subset=False, **kwargs
philipc2 marked this conversation as resolved.
Show resolved Hide resolved
):
"""Constructs a ``Grid`` object from a dataset.

Parameters
Expand All @@ -252,6 +256,8 @@ def from_dataset(cls, dataset, use_dual: Optional[bool] = False, **kwargs):
containing ASCII files represents a FESOM2 grid.
use_dual : bool, default=False
When reading in MPAS formatted datasets, indicates whether to use the Dual Mesh
is_subset : bool, default=False
Bool flag to indicate whether a grid is a subset
"""

if isinstance(dataset, xr.Dataset):
Expand Down Expand Up @@ -301,7 +307,7 @@ def from_dataset(cls, dataset, use_dual: Optional[bool] = False, **kwargs):
except TypeError:
raise ValueError("Unsupported Grid Format")

return cls(grid_ds, source_grid_spec, source_dims_dict)
return cls(grid_ds, source_grid_spec, source_dims_dict, is_subset=is_subset)

@classmethod
def from_file(
Expand Down Expand Up @@ -1506,6 +1512,16 @@ def global_sphere_coverage(self):
(i.e. contains no holes)"""
return not self.partial_sphere_coverage

@property
def inverse_face_indices(self):
"""Indices for a subset that map each face in the subset back to the original grid"""
philipc2 marked this conversation as resolved.
Show resolved Hide resolved
if self.is_subset:
return self._ds["inverse_face_indices"]
else:
raise Exception(
"Grid is not a subset, therefore no inverse face indices exist"
)

def chunk(self, n_node="auto", n_edge="auto", n_face="auto"):
"""Converts all arrays to dask arrays with given chunks across grid
dimensions in-place.
Expand Down Expand Up @@ -2201,7 +2217,7 @@ def get_dual(self):

return dual

def isel(self, **dim_kwargs):
def isel(self, inverse_indices=False, **dim_kwargs):
philipc2 marked this conversation as resolved.
Show resolved Hide resolved
"""Indexes an unstructured grid along a given dimension (``n_node``,
``n_edge``, or ``n_face``) and returns a new grid.

Expand All @@ -2226,13 +2242,19 @@ def isel(self, **dim_kwargs):
raise ValueError("Indexing must be along a single dimension.")

if "n_node" in dim_kwargs:
return _slice_node_indices(self, dim_kwargs["n_node"])
return _slice_node_indices(
self, dim_kwargs["n_node"], inverse_indices=inverse_indices
)

elif "n_edge" in dim_kwargs:
return _slice_edge_indices(self, dim_kwargs["n_edge"])
return _slice_edge_indices(
self, dim_kwargs["n_edge"], inverse_indices=inverse_indices
)

elif "n_face" in dim_kwargs:
return _slice_face_indices(self, dim_kwargs["n_face"])
return _slice_face_indices(
self, dim_kwargs["n_face"], inverse_indices=inverse_indices
)

else:
raise ValueError(
Expand Down
16 changes: 9 additions & 7 deletions uxarray/grid/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
pass


def _slice_node_indices(grid, indices, inclusive=True):
def _slice_node_indices(grid, indices, inclusive=True, inverse_indices=False):
"""Slices (indexes) an unstructured grid given a list/array of node
indices, returning a new Grid composed of elements that contain the nodes
specified in the indices.
Expand All @@ -33,10 +33,10 @@ def _slice_node_indices(grid, indices, inclusive=True):
face_indices = np.unique(grid.node_face_connectivity.values[indices].ravel())
face_indices = face_indices[face_indices != INT_FILL_VALUE]

return _slice_face_indices(grid, face_indices)
return _slice_face_indices(grid, face_indices, inverse_indices=inverse_indices)


def _slice_edge_indices(grid, indices, inclusive=True):
def _slice_edge_indices(grid, indices, inclusive=True, inverse_indices=False):
"""Slices (indexes) an unstructured grid given a list/array of edge
indices, returning a new Grid composed of elements that contain the edges
specified in the indices.
Expand All @@ -59,10 +59,10 @@ def _slice_edge_indices(grid, indices, inclusive=True):
face_indices = np.unique(grid.edge_face_connectivity.values[indices].ravel())
face_indices = face_indices[face_indices != INT_FILL_VALUE]

return _slice_face_indices(grid, face_indices)
return _slice_face_indices(grid, face_indices, inverse_indices=inverse_indices)


def _slice_face_indices(grid, indices, inclusive=True):
def _slice_face_indices(grid, indices, inclusive=True, inverse_indices=False):
"""Slices (indexes) an unstructured grid given a list/array of face
indices, returning a new Grid composed of elements that contain the faces
specified in the indices.
Expand All @@ -77,7 +77,6 @@ def _slice_face_indices(grid, indices, inclusive=True):
Whether to perform inclusive (i.e. elements must contain at least one desired feature from a slice) as opposed
to exclusive (i.e elements be made up all desired features from a slice)
"""

if inclusive is False:
raise ValueError("Exclusive slicing is not yet supported.")

Expand Down Expand Up @@ -132,4 +131,7 @@ def _slice_face_indices(grid, indices, inclusive=True):
# drop any conn that would require re-computation
ds = ds.drop_vars(conn_name)

return Grid.from_dataset(ds, source_grid_spec=grid.source_grid_spec)
if inverse_indices:
ds["inverse_face_indices"] = indices

return Grid.from_dataset(ds, source_grid_spec=grid.source_grid_spec, is_subset=True)
15 changes: 12 additions & 3 deletions uxarray/subset/dataarray_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def bounding_box(
lat_bounds: Union[Tuple, List, np.ndarray],
element: Optional[str] = "nodes",
method: Optional[str] = "coords",
inverse_indices=False,
**kwargs,
):
"""Subsets an unstructured grid between two latitude and longitude
Expand All @@ -53,9 +54,11 @@ def bounding_box(
face centers, or edge centers lie within the bounds.
element: str
Element for use with `coords` comparison, one of `nodes`, `face centers`, or `edge centers`
inverse_indices : bool
Flag to indicate whether to store the original grids face indices for later use
"""
grid = self.uxda.uxgrid.subset.bounding_box(
lon_bounds, lat_bounds, element, method
lon_bounds, lat_bounds, element, method, inverse_indices=inverse_indices
)

return self.uxda._slice_from_grid(grid)
Expand All @@ -65,6 +68,7 @@ def bounding_circle(
center_coord: Union[Tuple, List, np.ndarray],
r: Union[float, int],
element: Optional[str] = "nodes",
inverse_indices=False,
**kwargs,
):
"""Subsets an unstructured grid by returning all elements within some
Expand All @@ -78,9 +82,11 @@ def bounding_circle(
Radius of bounding circle (in degrees)
element: str
Element for use with `coords` comparison, one of `nodes`, `face centers`, or `edge centers`
inverse_indices : bool
Flag to indicate whether to store the original grids face indices for later use
"""
grid = self.uxda.uxgrid.subset.bounding_circle(
center_coord, r, element, **kwargs
center_coord, r, element, inverse_indices=inverse_indices**kwargs
)
return self.uxda._slice_from_grid(grid)

Expand All @@ -89,6 +95,7 @@ def nearest_neighbor(
center_coord: Union[Tuple, List, np.ndarray],
k: int,
element: Optional[str] = "nodes",
inverse_indices=False,
**kwargs,
):
"""Subsets an unstructured grid by returning the ``k`` closest
Expand All @@ -102,10 +109,12 @@ def nearest_neighbor(
Number of neighbors to query
element: str
Element for use with `coords` comparison, one of `nodes`, `face centers`, or `edge centers`
inverse_indices : bool
Flag to indicate whether to store the original grids face indices for later use
"""

grid = self.uxda.uxgrid.subset.nearest_neighbor(
center_coord, k, element, **kwargs
center_coord, k, element, inverse_indices=inverse_indices, **kwargs
)

return self.uxda._slice_from_grid(grid)
Loading
Loading