Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DRAFT: Consistently use Xarray for Grid Reader #1092

Draft
wants to merge 19 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions test/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,8 +448,8 @@ def test_connectivity_build_n_nodes_per_face():
max_dimension = grid.n_max_face_nodes
min_dimension = 3

assert grid.n_nodes_per_face.min() >= min_dimension
assert grid.n_nodes_per_face.max() <= max_dimension
assert grid.n_nodes_per_face.values.min() >= min_dimension
assert grid.n_nodes_per_face.values.max() <= max_dimension

verts = [f0_deg, f1_deg, f2_deg, f3_deg, f4_deg, f5_deg, f6_deg]
grid_from_verts = ux.open_grid(verts)
Expand Down
1 change: 1 addition & 0 deletions test/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def test_replace_fill_values():
for dtype in dtypes:
# test face nodes with set dtype
face_nodes = np.array([[1, 2, -1], [-1, -1, -1]], dtype=dtype)
face_nodes = xr.DataArray(data=face_nodes)

# output of _replace_fill_values()
face_nodes_test = _replace_fill_values(
Expand Down
3 changes: 3 additions & 0 deletions test/test_mpas.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def test_add_fill_values():
nEdgesOnCell = np.array([2, 3, 2])
gold_output = np.array([[0, 1, fv, fv], [2, 3, 4, fv], [5, 6, fv, fv]], dtype=INT_DTYPE)

verticesOnCell = xr.DataArray(data=verticesOnCell, dims=['n_face', 'n_max_face_nodes'])
nEdgesOnCell = xr.DataArray(data=nEdgesOnCell, dims=['n_face'])

verticesOnCell = _replace_padding(verticesOnCell, nEdgesOnCell)
verticesOnCell = _replace_zeros(verticesOnCell)
verticesOnCell = _to_zero_index(verticesOnCell)
Expand Down
4 changes: 2 additions & 2 deletions test/test_zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def test_zonal_weights(self):
data_path = self.datafile_vortex_ne30
uxds = ux.open_dataset(grid_path, data_path)

za_1 = uxds['psi'].zonal_mean((-90, 90, 1), use_robust_weights=True)
za_2 = uxds['psi'].zonal_mean((-90, 90, 1), use_robust_weights=False)
za_1 = uxds['psi'].zonal_mean((-90, 90, 30), use_robust_weights=True)
za_2 = uxds['psi'].zonal_mean((-90, 90, 30), use_robust_weights=False)

nt.assert_almost_equal(za_1.data, za_2.data)

Expand Down
62 changes: 49 additions & 13 deletions uxarray/core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,28 @@
from uxarray.grid import Grid
from uxarray.core.dataset import UxDataset
from uxarray.core.utils import _map_dims_to_ugrid
from uxarray.io.utils import _parse_grid_type, _get_source_dims_dict


from warnings import warn


def rename_chunks(grid_ds, chunks):
    """Mirror chunk sizes given under UGRID dimension names onto the
    dimension names used by the source grid dataset.

    Parameters
    ----------
    grid_ds : xr.Dataset
        Grid dataset handed to ``_parse_grid_type`` and
        ``_get_source_dims_dict`` to work out its dimension-name mapping.
    chunks : dict
        Chunking specification. The per-dimension sizes are read from, and
        written back to, the nested mapping stored under ``chunks["chunks"]``.
        NOTE(review): the visible caller appears to pass the chunk mapping
        itself rather than a dict wrapping it under a ``"chunks"`` key --
        confirm which shape is intended.

    Returns
    -------
    chunks : dict
        The same object that was passed in, updated in place.
    """
    # TODO: might need to copy chunks -- the caller's dictionary is currently
    # modified in place, and the same object is returned.
    grid_spec = _parse_grid_type(grid_ds)

    source_dims_dict = _get_source_dims_dict(grid_ds, grid_spec)

    # correctly chunk standardized ugrid dimension names: for every
    # (original name, UGRID name) pair, copy a chunk size supplied under the
    # UGRID name to the original name; the UGRID-named entry is left in place
    for original_grid_dim, ugrid_grid_dim in source_dims_dict.items():
        if ugrid_grid_dim in chunks["chunks"]:
            chunks["chunks"][original_grid_dim] = chunks["chunks"][ugrid_grid_dim]

    return chunks


def open_grid(
grid_filename_or_obj: Union[
str, os.PathLike, xr.DataArray, np.ndarray, list, tuple, dict
Expand Down Expand Up @@ -84,6 +102,7 @@ def open_grid(

if isinstance(grid_filename_or_obj, xr.Dataset):
# construct a grid from a dataset file
# TODO: insert/rechunk here?
uxgrid = Grid.from_dataset(grid_filename_or_obj, use_dual=use_dual)

elif isinstance(grid_filename_or_obj, dict):
Expand All @@ -97,7 +116,16 @@ def open_grid(
# attempt to use Xarray directly for remaining input types
else:
try:
grid_ds = xr.open_dataset(grid_filename_or_obj, **kwargs)
# TODO: Insert chunking here
if "data_chunks" in kwargs:
data_chunks = kwargs["data_chunks"]
del kwargs["data_chunks"]

grid_ds = xr.open_dataset(grid_filename_or_obj, **kwargs)
chunks = rename_chunks(grid_ds, data_chunks)
grid_ds = xr.open_dataset(grid_filename_or_obj, chunks=chunks, **kwargs)
else:
grid_ds = xr.open_dataset(grid_filename_or_obj, **kwargs)

uxgrid = Grid.from_dataset(grid_ds, use_dual=use_dual)
except ValueError:
Expand Down Expand Up @@ -173,25 +201,33 @@ def open_dataset(
stacklevel=2,
)

# TODO:
if "chunks" in kwargs and "chunks" not in grid_kwargs:
chunks = kwargs["chunks"]

grid_kwargs["data_chunks"] = chunks

# Grid definition
uxgrid = open_grid(
grid_filename_or_obj, latlon=latlon, use_dual=use_dual, **grid_kwargs
)

if "chunks" in kwargs:
# correctly chunk standardized ugrid dimension names
source_dims_dict = uxgrid._source_dims_dict
for original_grid_dim, ugrid_grid_dim in source_dims_dict.items():
if ugrid_grid_dim in kwargs["chunks"]:
kwargs["chunks"][original_grid_dim] = kwargs["chunks"][ugrid_grid_dim]
# if "chunks" in kwargs:
# # correctly chunk standardized ugrid dimension names
# source_dims_dict = uxgrid._source_dims_dict
# for original_grid_dim, ugrid_grid_dim in source_dims_dict.items():
# if ugrid_grid_dim in kwargs["chunks"]:
# kwargs["chunks"][original_grid_dim] = kwargs["chunks"][ugrid_grid_dim]

# UxDataset
ds = xr.open_dataset(filename_or_obj, **kwargs) # type: ignore

# map each dimension to its UGRID equivalent
# TODO: maybe issues here?
ds = _map_dims_to_ugrid(ds, uxgrid._source_dims_dict, uxgrid)

uxds = UxDataset(ds, uxgrid=uxgrid, source_datasets=str(filename_or_obj))
# UxDataset.from_xarray(ds, uxgrid=uxgrid, source_d

return uxds

Expand Down Expand Up @@ -275,12 +311,12 @@ def open_mfdataset(
grid_filename_or_obj, latlon=latlon, use_dual=use_dual, **grid_kwargs
)

if "chunks" in kwargs:
# correctly chunk standardized ugrid dimension names
source_dims_dict = uxgrid._source_dims_dict
for original_grid_dim, ugrid_grid_dim in source_dims_dict.items():
if ugrid_grid_dim in kwargs["chunks"]:
kwargs["chunks"][original_grid_dim] = kwargs["chunks"][ugrid_grid_dim]
# if "chunks" in kwargs:
# # correctly chunk standardized ugrid dimension names
# source_dims_dict = uxgrid._source_dims_dict
# for original_grid_dim, ugrid_grid_dim in source_dims_dict.items():
# if ugrid_grid_dim in kwargs["chunks"]:
# kwargs["chunks"][original_grid_dim] = kwargs["chunks"][ugrid_grid_dim]

# UxDataset
ds = xr.open_mfdataset(paths, **kwargs) # type: ignore
Expand Down
1 change: 0 additions & 1 deletion uxarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,6 @@ def from_xarray(cls, ds: xr.Dataset, uxgrid: Grid = None, ugrid_dims: dict = Non
if uxgrid is not None:
if ugrid_dims is None and uxgrid._source_dims_dict is not None:
ugrid_dims = uxgrid._source_dims_dict
pass
# Grid is provided,
else:
# parse
Expand Down
80 changes: 41 additions & 39 deletions uxarray/grid/connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,58 +64,60 @@ def _replace_fill_values(grid_var, original_fill, new_fill, new_dtype=None):

Parameters
----------
grid_var : np.ndarray
grid variable to be modified
grid_var : xr.DataArray
Grid variable to be modified
original_fill : constant
original fill value used in (``grid_var``)
Original fill value used in (``grid_var``)
new_fill : constant
new fill value to be used in (``grid_var``)
New fill value to be used in (``grid_var``)
new_dtype : np.dtype, optional
new data type to convert (``grid_var``) to
New data type to convert (``grid_var``) to

Returns
----------
grid_var : xarray.Dataset
Input Dataset with correct fill value and dtype
-------
grid_var : xr.DataArray
Modified DataArray with updated fill values and dtype
"""

# locations of fill values
# Identify fill value locations
if original_fill is not None and np.isnan(original_fill):
fill_val_idx = np.isnan(grid_var)
grid_var[fill_val_idx] = 0.0 # todo?
# For NaN fill values
fill_val_idx = grid_var.isnull()
# Temporarily replace NaNs with a placeholder if dtype conversion is needed
if new_dtype is not None and np.issubdtype(new_dtype, np.floating):
grid_var = grid_var.fillna(0.0)
else:
# Choose an appropriate placeholder for non-floating types
grid_var = grid_var.fillna(new_fill)
else:
# For non-NaN fill values
fill_val_idx = grid_var == original_fill

# convert to new data type
if new_dtype != grid_var.dtype and new_dtype is not None:
# Convert to the new data type if specified
if new_dtype is not None and new_dtype != grid_var.dtype:
grid_var = grid_var.astype(new_dtype)

# ensure fill value can be represented with current integer data type
if np.issubdtype(new_dtype, np.integer):
int_min = np.iinfo(grid_var.dtype).min
int_max = np.iinfo(grid_var.dtype).max
# ensure new_fill is in range [int_min, int_max]
if new_fill < int_min or new_fill > int_max:
raise ValueError(
f"New fill value: {new_fill} not representable by"
f" integer dtype: {grid_var.dtype}"
)

# ensure non-nan fill value can be represented with current float data type
elif np.issubdtype(new_dtype, np.floating) and not np.isnan(new_fill):
float_min = np.finfo(grid_var.dtype).min
float_max = np.finfo(grid_var.dtype).max
# ensure new_fill is in range [float_min, float_max]
if new_fill < float_min or new_fill > float_max:
raise ValueError(
f"New fill value: {new_fill} not representable by"
f" float dtype: {grid_var.dtype}"
)
else:
raise ValueError(f"Data type {grid_var.dtype} not supportedfor grid variables")

# replace all zeros with a fill value
grid_var[fill_val_idx] = new_fill
# Validate that the new_fill can be represented in the new_dtype
if new_dtype is not None:
if np.issubdtype(new_dtype, np.integer):
int_min = np.iinfo(new_dtype).min
int_max = np.iinfo(new_dtype).max
if not (int_min <= new_fill <= int_max):
raise ValueError(
f"New fill value: {new_fill} not representable by integer dtype: {new_dtype}"
)
elif np.issubdtype(new_dtype, np.floating):
if not (
np.isnan(new_fill)
or (np.finfo(new_dtype).min <= new_fill <= np.finfo(new_dtype).max)
):
raise ValueError(
f"New fill value: {new_fill} not representable by float dtype: {new_dtype}"
)
else:
raise ValueError(f"Data type {new_dtype} not supported for grid variables")

grid_var = grid_var.where(~fill_val_idx, new_fill)

return grid_var

Expand Down
13 changes: 8 additions & 5 deletions uxarray/grid/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ def __init__(
)
# TODO: more checks for validate grid (lat/lon coords, etc)

# TODO:
self._load_on_access = True

# mapping of ugrid dimensions and variables to source dataset's conventions
self._source_dims_dict = source_dims_dict

Expand Down Expand Up @@ -1177,7 +1180,7 @@ def edge_node_x(self) -> xr.DataArray:
"""

if "edge_node_x" not in self._ds:
_edge_node_x = self.node_x.values[self.edge_node_connectivity.values]
_edge_node_x = self.node_x[self.edge_node_connectivity]

self._ds["edge_node_x"] = xr.DataArray(
data=_edge_node_x,
Expand All @@ -1194,7 +1197,7 @@ def edge_node_y(self) -> xr.DataArray:
"""

if "edge_node_y" not in self._ds:
_edge_node_y = self.node_y.values[self.edge_node_connectivity.values]
_edge_node_y = self.node_y[self.edge_node_connectivity]

self._ds["edge_node_y"] = xr.DataArray(
data=_edge_node_y,
Expand All @@ -1211,7 +1214,7 @@ def edge_node_z(self) -> xr.DataArray:
"""

if "edge_node_z" not in self._ds:
_edge_node_z = self.node_z.values[self.edge_node_connectivity.values]
_edge_node_z = self.node_z[self.edge_node_connectivity]

self._ds["edge_node_z"] = xr.DataArray(
data=_edge_node_z,
Expand Down Expand Up @@ -1613,7 +1616,7 @@ def chunk(self, n_node="auto", n_edge="auto", n_face="auto"):

def get_ball_tree(
self,
coordinates: Optional[str] = "nodes",
coordinates: Optional[str] = "face centers",
coordinate_system: Optional[str] = "spherical",
distance_metric: Optional[str] = "haversine",
reconstruct: bool = False,
Expand Down Expand Up @@ -1663,7 +1666,7 @@ def get_ball_tree(

def get_kd_tree(
self,
coordinates: Optional[str] = "nodes",
coordinates: Optional[str] = "face centers",
coordinate_system: Optional[str] = "cartesian",
distance_metric: Optional[str] = "minkowski",
reconstruct: bool = False,
Expand Down
4 changes: 2 additions & 2 deletions uxarray/grid/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class KDTree:
def __init__(
self,
grid,
coordinates: Optional[str] = "nodes",
coordinates: Optional[str] = "face centers",
coordinate_system: Optional[str] = "cartesian",
distance_metric: Optional[str] = "minkowski",
reconstruct: bool = False,
Expand Down Expand Up @@ -433,7 +433,7 @@ class BallTree:
def __init__(
self,
grid,
coordinates: Optional[str] = "nodes",
coordinates: Optional[str] = "face centers",
coordinate_system: Optional[str] = "spherical",
distance_metric: Optional[str] = "haversine",
reconstruct: bool = False,
Expand Down
24 changes: 17 additions & 7 deletions uxarray/grid/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,22 +107,29 @@ def _slice_face_indices(

face_indices = indices

# TODO: stop using .values, use Xarray directly
# nodes of each face (inclusive)
node_indices = np.unique(grid.face_node_connectivity.values[face_indices].ravel())
node_indices = node_indices[node_indices != INT_FILL_VALUE]

# edges of each face (inclusive)
edge_indices = np.unique(grid.face_edge_connectivity.values[face_indices].ravel())
edge_indices = edge_indices[edge_indices != INT_FILL_VALUE]

# index original dataset to obtain a 'subgrid'
ds = ds.isel(n_node=node_indices)
ds = ds.isel(n_face=face_indices)
ds = ds.isel(n_edge=edge_indices)

# Only slice edge dimension if we already have the connectivity
if "face_edge_connectivity" in grid._ds:
# TODO: add a warning here?
edge_indices = np.unique(
grid.face_edge_connectivity.values[face_indices].ravel()
)
edge_indices = edge_indices[edge_indices != INT_FILL_VALUE]
ds = ds.isel(n_edge=edge_indices)
ds["subgrid_edge_indices"] = xr.DataArray(edge_indices, dims=["n_edge"])
else:
edge_indices = None

ds["subgrid_node_indices"] = xr.DataArray(node_indices, dims=["n_node"])
ds["subgrid_face_indices"] = xr.DataArray(face_indices, dims=["n_face"])
ds["subgrid_edge_indices"] = xr.DataArray(edge_indices, dims=["n_edge"])

# mapping to update existing connectivity
node_indices_dict = {
Expand Down Expand Up @@ -152,9 +159,12 @@ def _slice_face_indices(

index_types = {
"face": face_indices,
"edge": edge_indices,
"node": node_indices,
}

if edge_indices is not None:
index_types["edge"] = edge_indices

if isinstance(inverse_indices, bool):
inverse_indices_ds["face"] = face_indices
else:
Expand Down
Loading
Loading