Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Inverse Face Indices to Subsetted Grids #1122

Merged
merged 20 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions test/test_subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,20 @@ def test_inverse_face_indices():
coord = [0, 0]
subset = grid.subset.nearest_neighbor(coord, k=1, element="face centers", inverse_indices=True)

assert subset.inverse_face_indices is not None
assert subset.inverse_indices is not None

# Test bounding box subsetting
box = [(-10, 10), (-10, 10)]
subset = grid.subset.bounding_box(box[0], box[1], element="edge centers", inverse_indices=True)
subset = grid.subset.bounding_box(box[0], box[1], element="face centers", inverse_indices=True)

assert subset.inverse_face_indices is not None
assert subset.inverse_indices is not None

# Test bounding circle subsetting
center_coord = [0, 0]
subset = grid.subset.bounding_circle(center_coord, r=10, element="nodes", inverse_indices=True)
subset = grid.subset.bounding_circle(center_coord, r=10, element="face centers", inverse_indices=True)

assert subset.inverse_face_indices is not None
assert subset.inverse_indices is not None

# Ensure code raises exceptions when the element is edges or nodes
assert pytest.raises(Exception, grid.subset.bounding_circle, center_coord, r=10, element="edge centers", inverse_indices=True)
assert pytest.raises(Exception, grid.subset.bounding_circle, center_coord, r=10, element="node centers", inverse_indices=True)
10 changes: 7 additions & 3 deletions uxarray/cross_sections/dataarray_accessor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations


from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union, List, Set

if TYPE_CHECKING:
pass
Expand All @@ -22,7 +22,9 @@ def __repr__(self):

return prefix + methods_heading

def constant_latitude(self, lat: float, inverse_indices: bool = False):
def constant_latitude(
self, lat: float, inverse_indices: Union[List[str], Set[str], bool] = False
):
"""Extracts a cross-section of the data array by selecting all faces that
intersect with a specified line of constant latitude.

Expand Down Expand Up @@ -64,7 +66,9 @@ def constant_latitude(self, lat: float, inverse_indices: bool = False):

return self.uxda.isel(n_face=faces, inverse_indices=inverse_indices)

def constant_longitude(self, lon: float, inverse_indices: bool = False):
def constant_longitude(
self, lon: float, inverse_indices: Union[List[str], Set[str], bool] = False
):
"""Extracts a cross-section of the data array by selecting all faces that
intersect with a specified line of constant longitude.

Expand Down
6 changes: 3 additions & 3 deletions uxarray/cross_sections/grid_accessor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations


from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union, List, Set

if TYPE_CHECKING:
from uxarray.grid import Grid
Expand All @@ -25,7 +25,7 @@ def constant_latitude(
self,
lat: float,
return_face_indices: bool = False,
inverse_indices: bool = False,
inverse_indices: Union[List[str], Set[str], bool] = False,
):
"""Extracts a cross-section of the grid by selecting all faces that
intersect with a specified line of constant latitude.
Expand Down Expand Up @@ -82,7 +82,7 @@ def constant_longitude(
self,
lon: float,
return_face_indices: bool = False,
inverse_indices: bool = False,
inverse_indices: Union[List[str], Set[str], bool] = False,
):
"""Extracts a cross-section of the grid by selecting all faces that
intersect with a specified line of constant longitude.
Expand Down
39 changes: 32 additions & 7 deletions uxarray/grid/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from typing import (
Optional,
Union,
List,
Set,
)

# reader and writer imports
Expand Down Expand Up @@ -161,6 +163,7 @@ def __init__(
source_grid_spec: Optional[str] = None,
source_dims_dict: Optional[dict] = {},
is_subset=False,
philipc2 marked this conversation as resolved.
Show resolved Hide resolved
inverse_indices: Optional[xr.Dataset] = None,
):
# check if inputted dataset is a minimum representable 2D UGRID unstructured grid
if not _validate_minimum_ugrid(grid_ds):
Expand Down Expand Up @@ -194,6 +197,9 @@ def __init__(
self._ds.assign_attrs({"source_grid_spec": self.source_grid_spec})
self.is_subset = is_subset

if inverse_indices is not None:
self.inverse_indices = inverse_indices

# cached parameters for GeoDataFrame conversions
self._gdf_cached_parameters = {
"gdf": None,
Expand Down Expand Up @@ -244,9 +250,7 @@ def __init__(
cross_section = UncachedAccessor(GridCrossSectionAccessor)

@classmethod
def from_dataset(
cls, dataset, use_dual: Optional[bool] = False, is_subset=False, **kwargs
):
def from_dataset(cls, dataset, use_dual: Optional[bool] = False, **kwargs):
"""Constructs a ``Grid`` object from a dataset.

Parameters
Expand Down Expand Up @@ -307,7 +311,13 @@ def from_dataset(
except TypeError:
raise ValueError("Unsupported Grid Format")

return cls(grid_ds, source_grid_spec, source_dims_dict, is_subset=is_subset)
return cls(
grid_ds,
source_grid_spec,
source_dims_dict,
is_subset=kwargs.get("is_subset", False),
inverse_indices=kwargs.get("inverse_indices"),
)

@classmethod
def from_file(
Expand Down Expand Up @@ -1513,15 +1523,20 @@ def global_sphere_coverage(self):
return not self.partial_sphere_coverage

@property
def inverse_face_indices(self):
def inverse_indices(self):
philipc2 marked this conversation as resolved.
Show resolved Hide resolved
"""Indices for a subset that map each face in the subset back to the original grid"""
philipc2 marked this conversation as resolved.
Show resolved Hide resolved
if self.is_subset:
return self._ds["inverse_face_indices"]
return self._inverse_indices
else:
raise Exception(
"Grid is not a subset, therefore no inverse face indices exist"
)

@inverse_indices.setter
def inverse_indices(self, value):
assert isinstance(value, xr.Dataset)
self._inverse_indices = value
philipc2 marked this conversation as resolved.
Show resolved Hide resolved

def chunk(self, n_node="auto", n_edge="auto", n_face="auto"):
"""Converts all arrays to dask arrays with given chunks across grid
dimensions in-place.
Expand Down Expand Up @@ -2217,7 +2232,9 @@ def get_dual(self):

return dual

def isel(self, inverse_indices=False, **dim_kwargs):
def isel(
self, inverse_indices: Union[List[str], Set[str], bool] = False, **dim_kwargs
):
"""Indexes an unstructured grid along a given dimension (``n_node``,
``n_edge``, or ``n_face``) and returns a new grid.

Expand All @@ -2242,11 +2259,19 @@ def isel(self, inverse_indices=False, **dim_kwargs):
raise ValueError("Indexing must be along a single dimension.")

if "n_node" in dim_kwargs:
if inverse_indices:
raise Exception(
"Inverse indices are not yet supported for node selection, please use face centers"
)
return _slice_node_indices(
self, dim_kwargs["n_node"], inverse_indices=inverse_indices
)

elif "n_edge" in dim_kwargs:
if inverse_indices:
raise Exception(
"Inverse indices are not yet supported for edge selection, please use face centers"
)
return _slice_edge_indices(
self, dim_kwargs["n_edge"], inverse_indices=inverse_indices
)
Expand Down
44 changes: 39 additions & 5 deletions uxarray/grid/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,18 @@
import xarray as xr
from uxarray.constants import INT_FILL_VALUE, INT_DTYPE

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union, List, Set

if TYPE_CHECKING:
pass


def _slice_node_indices(grid, indices, inclusive=True, inverse_indices=False):
def _slice_node_indices(
grid,
indices,
inclusive=True,
inverse_indices: Union[List[str], Set[str], bool] = False,
):
"""Slices (indexes) an unstructured grid given a list/array of node
indices, returning a new Grid composed of elements that contain the nodes
specified in the indices.
Expand All @@ -36,7 +41,12 @@ def _slice_node_indices(grid, indices, inclusive=True, inverse_indices=False):
return _slice_face_indices(grid, face_indices, inverse_indices=inverse_indices)


def _slice_edge_indices(grid, indices, inclusive=True, inverse_indices=False):
def _slice_edge_indices(
grid,
indices,
inclusive=True,
inverse_indices: Union[List[str], Set[str], bool] = False,
):
"""Slices (indexes) an unstructured grid given a list/array of edge
indices, returning a new Grid composed of elements that contain the edges
specified in the indices.
Expand All @@ -62,7 +72,12 @@ def _slice_edge_indices(grid, indices, inclusive=True, inverse_indices=False):
return _slice_face_indices(grid, face_indices, inverse_indices=inverse_indices)


def _slice_face_indices(grid, indices, inclusive=True, inverse_indices=False):
def _slice_face_indices(
grid,
indices,
inclusive=True,
inverse_indices: Union[List[str], Set[str], bool] = False,
):
"""Slices (indexes) an unstructured grid given a list/array of face
indices, returning a new Grid composed of elements that contain the faces
specified in the indices.
Expand Down Expand Up @@ -132,6 +147,25 @@ def _slice_face_indices(grid, indices, inclusive=True, inverse_indices=False):
ds = ds.drop_vars(conn_name)

if inverse_indices:
ds["inverse_face_indices"] = indices
inverse_indices_ds = xr.Dataset()

index_types = {
"face centers": face_indices,
"edge centers": edge_indices,
"nodes": node_indices,
}
if isinstance(inverse_indices, bool):
inverse_indices_ds["face centers"] = face_indices
else:
for index_type in inverse_indices[0]:
if index_type in index_types:
inverse_indices_ds[index_type] = index_types[index_type]

return Grid.from_dataset(
ds,
source_grid_spec=grid.source_grid_spec,
is_subset=True,
inverse_indices=inverse_indices_ds,
)

return Grid.from_dataset(ds, source_grid_spec=grid.source_grid_spec, is_subset=True)
10 changes: 5 additions & 5 deletions uxarray/subset/dataarray_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from typing import TYPE_CHECKING, Union, Tuple, List, Optional
from typing import TYPE_CHECKING, Union, Tuple, List, Optional, Set

if TYPE_CHECKING:
pass
Expand Down Expand Up @@ -33,7 +33,7 @@ def bounding_box(
lat_bounds: Union[Tuple, List, np.ndarray],
element: Optional[str] = "nodes",
method: Optional[str] = "coords",
inverse_indices=False,
inverse_indices: Union[List[str], Set[str], bool] = False,
**kwargs,
):
"""Subsets an unstructured grid between two latitude and longitude
Expand Down Expand Up @@ -68,7 +68,7 @@ def bounding_circle(
center_coord: Union[Tuple, List, np.ndarray],
r: Union[float, int],
element: Optional[str] = "nodes",
inverse_indices=False,
inverse_indices: Union[List[str], Set[str], bool] = False,
**kwargs,
):
"""Subsets an unstructured grid by returning all elements within some
Expand All @@ -86,7 +86,7 @@ def bounding_circle(
Flag to indicate whether to store the original grids face indices for later use
"""
grid = self.uxda.uxgrid.subset.bounding_circle(
center_coord, r, element, inverse_indices=inverse_indices**kwargs
center_coord, r, element, inverse_indices=inverse_indices, **kwargs
)
return self.uxda._slice_from_grid(grid)

Expand All @@ -95,7 +95,7 @@ def nearest_neighbor(
center_coord: Union[Tuple, List, np.ndarray],
k: int,
element: Optional[str] = "nodes",
inverse_indices=False,
inverse_indices: Union[List[str], Set[str], bool] = False,
**kwargs,
):
"""Subsets an unstructured grid by returning the ``k`` closest
Expand Down
8 changes: 4 additions & 4 deletions uxarray/subset/grid_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from typing import TYPE_CHECKING, Union, Tuple, List, Optional
from typing import TYPE_CHECKING, Union, Tuple, List, Optional, Set

if TYPE_CHECKING:
from uxarray.grid import Grid
Expand Down Expand Up @@ -33,7 +33,7 @@ def bounding_box(
lat_bounds: Union[Tuple, List, np.ndarray],
element: Optional[str] = "nodes",
method: Optional[str] = "coords",
inverse_indices=False,
inverse_indices: Union[List[str], Set[str], bool] = False,
**kwargs,
):
"""Subsets an unstructured grid between two latitude and longitude
Expand Down Expand Up @@ -118,7 +118,7 @@ def bounding_circle(
center_coord: Union[Tuple, List, np.ndarray],
r: Union[float, int],
element: Optional[str] = "nodes",
inverse_indices=False,
inverse_indices: Union[List[str], Set[str], bool] = False,
**kwargs,
):
"""Subsets an unstructured grid by returning all elements within some
Expand Down Expand Up @@ -154,7 +154,7 @@ def nearest_neighbor(
center_coord: Union[Tuple, List, np.ndarray],
k: int,
element: Optional[str] = "nodes",
inverse_indices=False,
inverse_indices: Union[List[str], Set[str], bool] = False,
**kwargs,
):
"""Subsets an unstructured grid by returning the ``k`` closest
Expand Down
Loading