Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Inverse Face Indices to Subsetted Grids #1122

Merged
merged 20 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Indexing
:toctree: generated/

Grid.isel
Grid.inverse_indices

Dimensions
~~~~~~~~~~
Expand Down
33 changes: 31 additions & 2 deletions test/test_subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,19 +104,48 @@ def test_grid_bounding_box_subset():
bbox_antimeridian[0], bbox_antimeridian[1], element=element)




def test_uxda_isel():
uxds = ux.open_dataset(GRID_PATHS[0], DATA_PATHS[0])

sub = uxds['bottomDepth'].isel(n_face=[1, 2, 3])

assert len(sub) == 3


def test_uxda_isel_with_coords():
uxds = ux.open_dataset(GRID_PATHS[0], DATA_PATHS[0])
uxds = uxds.assign_coords({"lon_face": uxds.uxgrid.face_lon})
sub = uxds['bottomDepth'].isel(n_face=[1, 2, 3])

assert "lon_face" in sub.coords
assert len(sub.coords['lon_face']) == 3


def test_inverse_indices():
philipc2 marked this conversation as resolved.
Show resolved Hide resolved
grid = ux.open_grid(GRID_PATHS[0])

# Test nearest neighbor subsetting
coord = [0, 0]
subset = grid.subset.nearest_neighbor(coord, k=1, element="face centers", inverse_indices=True)

assert subset.inverse_indices is not None

# Test bounding box subsetting
box = [(-10, 10), (-10, 10)]
subset = grid.subset.bounding_box(box[0], box[1], element="face centers", inverse_indices=True)

assert subset.inverse_indices is not None

# Test bounding circle subsetting
center_coord = [0, 0]
subset = grid.subset.bounding_circle(center_coord, r=10, element="face centers", inverse_indices=True)

assert subset.inverse_indices is not None

# Ensure code raises exceptions when the element is edges or nodes
assert pytest.raises(Exception, grid.subset.bounding_circle, center_coord, r=10, element="edge centers", inverse_indices=True)
assert pytest.raises(Exception, grid.subset.bounding_circle, center_coord, r=10, element="nodes", inverse_indices=True)

# Test isel directly
subset = grid.isel(n_face=[1], inverse_indices=True)
philipc2 marked this conversation as resolved.
Show resolved Hide resolved
assert subset.inverse_indices is not None
14 changes: 10 additions & 4 deletions uxarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,7 +1035,7 @@ def _edge_centered(self) -> bool:
"n_edge" dimension)"""
return "n_edge" in self.dims

def isel(self, ignore_grid=False, *args, **kwargs):
def isel(self, ignore_grid=False, inverse_indices=False, *args, **kwargs):
"""Grid-informed implementation of xarray's ``isel`` method, which
enables indexing across grid dimensions.

Expand Down Expand Up @@ -1069,11 +1069,17 @@ def isel(self, ignore_grid=False, *args, **kwargs):
raise ValueError("Only one grid dimension can be sliced at a time")

if "n_node" in kwargs:
sliced_grid = self.uxgrid.isel(n_node=kwargs["n_node"])
sliced_grid = self.uxgrid.isel(
n_node=kwargs["n_node"], inverse_indices=inverse_indices
)
elif "n_edge" in kwargs:
sliced_grid = self.uxgrid.isel(n_edge=kwargs["n_edge"])
sliced_grid = self.uxgrid.isel(
n_edge=kwargs["n_edge"], inverse_indices=inverse_indices
)
else:
sliced_grid = self.uxgrid.isel(n_face=kwargs["n_face"])
sliced_grid = self.uxgrid.isel(
n_face=kwargs["n_face"], inverse_indices=inverse_indices
)

return self._slice_from_grid(sliced_grid)

Expand Down
20 changes: 15 additions & 5 deletions uxarray/cross_sections/dataarray_accessor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations


from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union, List, Set

if TYPE_CHECKING:
pass
Expand All @@ -22,7 +22,9 @@ def __repr__(self):

return prefix + methods_heading

def constant_latitude(self, lat: float):
def constant_latitude(
self, lat: float, inverse_indices: Union[List[str], Set[str], bool] = False
):
"""Extracts a cross-section of the data array by selecting all faces that
intersect with a specified line of constant latitude.

Expand All @@ -31,6 +33,9 @@ def constant_latitude(self, lat: float):
lat : float
The latitude at which to extract the cross-section, in degrees.
Must be between -90.0 and 90.0
inverse_indices : Union[List[str], Set[str], bool], optional
Indicates whether to store the original grids indices. Passing `True` stores the original face centers,
other reverse indices can be stored by passing any or all of the following: (["face", "edge", "node"], True)

Returns
-------
Expand Down Expand Up @@ -60,9 +65,11 @@ def constant_latitude(self, lat: float):

faces = self.uxda.uxgrid.get_faces_at_constant_latitude(lat)

return self.uxda.isel(n_face=faces)
return self.uxda.isel(n_face=faces, inverse_indices=inverse_indices)

def constant_longitude(self, lon: float):
def constant_longitude(
self, lon: float, inverse_indices: Union[List[str], Set[str], bool] = False
):
"""Extracts a cross-section of the data array by selecting all faces that
intersect with a specified line of constant longitude.

Expand All @@ -71,6 +78,9 @@ def constant_longitude(self, lon: float):
lon : float
The latitude at which to extract the cross-section, in degrees.
Must be between -180.0 and 180.0
inverse_indices : Union[List[str], Set[str], bool], optional
Indicates whether to store the original grids indices. Passing `True` stores the original face centers,
other reverse indices can be stored by passing any or all of the following: (["face", "edge", "node"], True)

Returns
-------
Expand Down Expand Up @@ -102,7 +112,7 @@ def constant_longitude(self, lon: float):
lon,
)

return self.uxda.isel(n_face=faces)
return self.uxda.isel(n_face=faces, inverse_indices=inverse_indices)

def gca(self, *args, **kwargs):
raise NotImplementedError
Expand Down
18 changes: 15 additions & 3 deletions uxarray/cross_sections/grid_accessor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations


from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union, List, Set

if TYPE_CHECKING:
from uxarray.grid import Grid
Expand All @@ -25,6 +25,7 @@ def constant_latitude(
self,
lat: float,
return_face_indices: bool = False,
inverse_indices: Union[List[str], Set[str], bool] = False,
):
"""Extracts a cross-section of the grid by selecting all faces that
intersect with a specified line of constant latitude.
Expand All @@ -36,6 +37,9 @@ def constant_latitude(
Must be between -90.0 and 90.0
return_face_indices : bool, optional
If True, also returns the indices of the faces that intersect with the line of constant latitude.
inverse_indices : Union[List[str], Set[str], bool], optional
Indicates whether to store the original grids indices. Passing `True` stores the original face centers,
other reverse indices can be stored by passing any or all of the following: (["face", "edge", "node"], True)

Returns
-------
Expand Down Expand Up @@ -66,7 +70,9 @@ def constant_latitude(
if len(faces) == 0:
raise ValueError(f"No intersections found at lat={lat}.")

grid_at_constant_lat = self.uxgrid.isel(n_face=faces)
grid_at_constant_lat = self.uxgrid.isel(
n_face=faces, inverse_indices=inverse_indices
)

if return_face_indices:
return grid_at_constant_lat, faces
Expand All @@ -77,6 +83,7 @@ def constant_longitude(
self,
lon: float,
return_face_indices: bool = False,
inverse_indices: Union[List[str], Set[str], bool] = False,
):
"""Extracts a cross-section of the grid by selecting all faces that
intersect with a specified line of constant longitude.
Expand All @@ -88,6 +95,9 @@ def constant_longitude(
Must be between -90.0 and 90.0
return_face_indices : bool, optional
If True, also returns the indices of the faces that intersect with the line of constant longitude.
inverse_indices : Union[List[str], Set[str], bool], optional
Indicates whether to store the original grids indices. Passing `True` stores the original face centers,
other reverse indices can be stored by passing any or all of the following: (["face", "edge", "node"], True)

Returns
-------
Expand Down Expand Up @@ -117,7 +127,9 @@ def constant_longitude(
if len(faces) == 0:
raise ValueError(f"No intersections found at lon={lon}")

grid_at_constant_lon = self.uxgrid.isel(n_face=faces)
grid_at_constant_lon = self.uxgrid.isel(
n_face=faces, inverse_indices=inverse_indices
)

if return_face_indices:
return grid_at_constant_lon, faces
Expand Down
53 changes: 50 additions & 3 deletions uxarray/grid/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from typing import (
Optional,
Union,
List,
Set,
)

# reader and writer imports
Expand Down Expand Up @@ -137,6 +139,12 @@ class Grid:
source_dims_dict : dict, default={}
Mapping of dimensions from the source dataset to their UGRID equivalent (i.e. {nCell : n_face})

is_subset : bool, default=False
Flag to mark if the grid is a subset or not

inverse_indices: xr.Dataset, default=None
A dataset of indices that correspond to the original grid, if the grid being constructed is a subset

Examples
----------

Expand All @@ -160,6 +168,8 @@ def __init__(
grid_ds: xr.Dataset,
source_grid_spec: Optional[str] = None,
source_dims_dict: Optional[dict] = {},
is_subset: bool = False,
inverse_indices: Optional[xr.Dataset] = None,
):
# check if inputted dataset is a minimum representable 2D UGRID unstructured grid
if not _validate_minimum_ugrid(grid_ds):
Expand Down Expand Up @@ -191,6 +201,10 @@ def __init__(
# initialize attributes
self._antimeridian_face_indices = None
self._ds.assign_attrs({"source_grid_spec": self.source_grid_spec})
self.is_subset = is_subset

if inverse_indices is not None:
self._inverse_indices = inverse_indices

# cached parameters for GeoDataFrame conversions
self._gdf_cached_parameters = {
Expand Down Expand Up @@ -252,6 +266,8 @@ def from_dataset(cls, dataset, use_dual: Optional[bool] = False, **kwargs):
containing ASCII files represents a FESOM2 grid.
use_dual : bool, default=False
When reading in MPAS formatted datasets, indicates whether to use the Dual Mesh
is_subset : bool, default=False
Bool flag to indicate whether a grid is a subset
"""

if isinstance(dataset, xr.Dataset):
Expand Down Expand Up @@ -301,7 +317,13 @@ def from_dataset(cls, dataset, use_dual: Optional[bool] = False, **kwargs):
except TypeError:
raise ValueError("Unsupported Grid Format")

return cls(grid_ds, source_grid_spec, source_dims_dict)
return cls(
grid_ds,
source_grid_spec,
source_dims_dict,
is_subset=kwargs.get("is_subset", False),
inverse_indices=kwargs.get("inverse_indices"),
)

@classmethod
def from_file(
Expand Down Expand Up @@ -1506,6 +1528,16 @@ def global_sphere_coverage(self):
(i.e. contains no holes)"""
return not self.partial_sphere_coverage

@property
def inverse_indices(self) -> xr.Dataset:
"""Indices for a subset that map each face in the subset back to the original grid"""
philipc2 marked this conversation as resolved.
Show resolved Hide resolved
if self.is_subset:
return self._inverse_indices
else:
raise Exception(
"Grid is not a subset, therefore no inverse face indices exist"
)

def chunk(self, n_node="auto", n_edge="auto", n_face="auto"):
"""Converts all arrays to dask arrays with given chunks across grid
dimensions in-place.
Expand Down Expand Up @@ -2201,7 +2233,9 @@ def get_dual(self):

return dual

def isel(self, **dim_kwargs):
def isel(
self, inverse_indices: Union[List[str], Set[str], bool] = False, **dim_kwargs
):
"""Indexes an unstructured grid along a given dimension (``n_node``,
``n_edge``, or ``n_face``) and returns a new grid.

Expand All @@ -2211,6 +2245,9 @@ def isel(self, **dim_kwargs):
exclusive and clipped indexing is in the works.

Parameters
inverse_indices : Union[List[str], Set[str], bool], default=False
Indicates whether to store the original grids indices. Passing `True` stores the original face indices,
other reverse indices can be stored by passing any or all of the following: (["face", "edge", "node"], True)
**dims_kwargs: kwargs
Dimension to index, one of ['n_node', 'n_edge', 'n_face']

Expand All @@ -2226,13 +2263,23 @@ def isel(self, **dim_kwargs):
raise ValueError("Indexing must be along a single dimension.")

if "n_node" in dim_kwargs:
if inverse_indices:
raise Exception(
"Inverse indices are not yet supported for node selection, please use face centers"
)
return _slice_node_indices(self, dim_kwargs["n_node"])

elif "n_edge" in dim_kwargs:
if inverse_indices:
raise Exception(
"Inverse indices are not yet supported for edge selection, please use face centers"
)
return _slice_edge_indices(self, dim_kwargs["n_edge"])

elif "n_face" in dim_kwargs:
return _slice_face_indices(self, dim_kwargs["n_face"])
return _slice_face_indices(
self, dim_kwargs["n_face"], inverse_indices=inverse_indices
)

else:
raise ValueError(
Expand Down
Loading
Loading