diff --git a/test/test_grid.py b/test/test_grid.py index 4c8d98c76..f8d271d5c 100644 --- a/test/test_grid.py +++ b/test/test_grid.py @@ -448,8 +448,8 @@ def test_connectivity_build_n_nodes_per_face(): max_dimension = grid.n_max_face_nodes min_dimension = 3 - assert grid.n_nodes_per_face.min() >= min_dimension - assert grid.n_nodes_per_face.max() <= max_dimension + assert grid.n_nodes_per_face.values.min() >= min_dimension + assert grid.n_nodes_per_face.values.max() <= max_dimension verts = [f0_deg, f1_deg, f2_deg, f3_deg, f4_deg, f5_deg, f6_deg] grid_from_verts = ux.open_grid(verts) diff --git a/test/test_helpers.py b/test/test_helpers.py index 3713228bf..96939232b 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -152,6 +152,7 @@ def test_replace_fill_values(): for dtype in dtypes: # test face nodes with set dtype face_nodes = np.array([[1, 2, -1], [-1, -1, -1]], dtype=dtype) + face_nodes = xr.DataArray(data=face_nodes) # output of _replace_fill_values() face_nodes_test = _replace_fill_values( diff --git a/test/test_mpas.py b/test/test_mpas.py index 7e6dc6cfb..cf1a915b3 100644 --- a/test/test_mpas.py +++ b/test/test_mpas.py @@ -73,6 +73,9 @@ def test_add_fill_values(): nEdgesOnCell = np.array([2, 3, 2]) gold_output = np.array([[0, 1, fv, fv], [2, 3, 4, fv], [5, 6, fv, fv]], dtype=INT_DTYPE) + verticesOnCell = xr.DataArray(data=verticesOnCell, dims=['n_face', 'n_max_face_nodes']) + nEdgesOnCell = xr.DataArray(data=nEdgesOnCell, dims=['n_face']) + verticesOnCell = _replace_padding(verticesOnCell, nEdgesOnCell) verticesOnCell = _replace_zeros(verticesOnCell) verticesOnCell = _to_zero_index(verticesOnCell) diff --git a/test/test_zonal.py b/test/test_zonal.py index 6ee578f72..e37584122 100644 --- a/test/test_zonal.py +++ b/test/test_zonal.py @@ -76,8 +76,8 @@ def test_zonal_weights(self): data_path = self.datafile_vortex_ne30 uxds = ux.open_dataset(grid_path, data_path) - za_1 = uxds['psi'].zonal_mean((-90, 90, 1), use_robust_weights=True) - za_2 = uxds['psi'].zonal_mean((-90, 90, 1), use_robust_weights=False) + za_1 = uxds['psi'].zonal_mean((-90, 90, 30), use_robust_weights=True) + za_2 = uxds['psi'].zonal_mean((-90, 90, 30), use_robust_weights=False) nt.assert_almost_equal(za_1.data, za_2.data) diff --git a/uxarray/core/api.py b/uxarray/core/api.py index d7420737f..ade944995 100644 --- a/uxarray/core/api.py +++ b/uxarray/core/api.py @@ -9,10 +9,28 @@ from uxarray.grid import Grid from uxarray.core.dataset import UxDataset from uxarray.core.utils import _map_dims_to_ugrid +from uxarray.io.utils import _parse_grid_type, _get_source_dims_dict + from warnings import warn +def rename_chunks(grid_ds, chunks): + # TODO: might need to copy chunks + print(chunks) + grid_spec = _parse_grid_type(grid_ds) + + source_dims_dict = _get_source_dims_dict(grid_ds, grid_spec) + + # correctly chunk standardized ugrid dimension names + for original_grid_dim, ugrid_grid_dim in source_dims_dict.items(): + if ugrid_grid_dim in chunks["chunks"]: + chunks["chunks"][original_grid_dim] = chunks["chunks"][ugrid_grid_dim] + + print(chunks) + return chunks + + def open_grid( grid_filename_or_obj: Union[ str, os.PathLike, xr.DataArray, np.ndarray, list, tuple, dict @@ -84,6 +102,7 @@ def open_grid( if isinstance(grid_filename_or_obj, xr.Dataset): # construct a grid from a dataset file + # TODO: insert/rechunk here? uxgrid = Grid.from_dataset(grid_filename_or_obj, use_dual=use_dual) elif isinstance(grid_filename_or_obj, dict): @@ -97,7 +116,16 @@ def open_grid( # attempt to use Xarray directly for remaining input types else: try: - grid_ds = xr.open_dataset(grid_filename_or_obj, **kwargs) + # TODO: Insert chunking here + if "data_chunks" in kwargs: + data_chunks = kwargs["data_chunks"] + del kwargs["data_chunks"] + + grid_ds = xr.open_dataset(grid_filename_or_obj, **kwargs) + chunks = rename_chunks(grid_ds, data_chunks) + grid_ds = xr.open_dataset(grid_filename_or_obj, chunks=chunks, **kwargs) + else: + grid_ds = xr.open_dataset(grid_filename_or_obj, **kwargs) uxgrid = Grid.from_dataset(grid_ds, use_dual=use_dual) except ValueError: @@ -173,25 +201,33 @@ def open_dataset( stacklevel=2, ) + # TODO: + if "chunks" in kwargs and "chunks" not in grid_kwargs: + chunks = kwargs["chunks"] + + grid_kwargs["data_chunks"] = chunks + # Grid definition uxgrid = open_grid( grid_filename_or_obj, latlon=latlon, use_dual=use_dual, **grid_kwargs ) - if "chunks" in kwargs: - # correctly chunk standardized ugrid dimension names - source_dims_dict = uxgrid._source_dims_dict - for original_grid_dim, ugrid_grid_dim in source_dims_dict.items(): - if ugrid_grid_dim in kwargs["chunks"]: - kwargs["chunks"][original_grid_dim] = kwargs["chunks"][ugrid_grid_dim] + # if "chunks" in kwargs: + # # correctly chunk standardized ugrid dimension names + # source_dims_dict = uxgrid._source_dims_dict + # for original_grid_dim, ugrid_grid_dim in source_dims_dict.items(): + # if ugrid_grid_dim in kwargs["chunks"]: + # kwargs["chunks"][original_grid_dim] = kwargs["chunks"][ugrid_grid_dim] # UxDataset ds = xr.open_dataset(filename_or_obj, **kwargs) # type: ignore # map each dimension to its UGRID equivalent + # TODO: maybe issues here? ds = _map_dims_to_ugrid(ds, uxgrid._source_dims_dict, uxgrid) uxds = UxDataset(ds, uxgrid=uxgrid, source_datasets=str(filename_or_obj)) + # UxDataset.from_xarray(ds, uxgrid=uxgrid, source_d return uxds @@ -275,12 +311,12 @@ def open_mfdataset( grid_filename_or_obj, latlon=latlon, use_dual=use_dual, **grid_kwargs ) - if "chunks" in kwargs: - # correctly chunk standardized ugrid dimension names - source_dims_dict = uxgrid._source_dims_dict - for original_grid_dim, ugrid_grid_dim in source_dims_dict.items(): - if ugrid_grid_dim in kwargs["chunks"]: - kwargs["chunks"][original_grid_dim] = kwargs["chunks"][ugrid_grid_dim] + # if "chunks" in kwargs: + # # correctly chunk standardized ugrid dimension names + # source_dims_dict = uxgrid._source_dims_dict + # for original_grid_dim, ugrid_grid_dim in source_dims_dict.items(): + # if ugrid_grid_dim in kwargs["chunks"]: + # kwargs["chunks"][original_grid_dim] = kwargs["chunks"][ugrid_grid_dim] # UxDataset ds = xr.open_mfdataset(paths, **kwargs) # type: ignore diff --git a/uxarray/core/dataset.py b/uxarray/core/dataset.py index 80e2f8027..4ab1a932f 100644 --- a/uxarray/core/dataset.py +++ b/uxarray/core/dataset.py @@ -276,7 +276,6 @@ def from_xarray(cls, ds: xr.Dataset, uxgrid: Grid = None, ugrid_dims: dict = Non if uxgrid is not None: if ugrid_dims is None and uxgrid._source_dims_dict is not None: ugrid_dims = uxgrid._source_dims_dict - pass # Grid is provided, else: # parse diff --git a/uxarray/grid/connectivity.py b/uxarray/grid/connectivity.py index 78e936117..c51997533 100644 --- a/uxarray/grid/connectivity.py +++ b/uxarray/grid/connectivity.py @@ -64,58 +64,60 @@ def _replace_fill_values(grid_var, original_fill, new_fill, new_dtype=None): Parameters ---------- - grid_var : np.ndarray - grid variable to be modified + grid_var : xr.DataArray + Grid variable to be modified original_fill : constant - original fill value used in (``grid_var``) + Original fill value used in (``grid_var``) new_fill : constant - new fill value to be used in (``grid_var``) + New fill value to be used in (``grid_var``) new_dtype : np.dtype, optional - new data type to convert (``grid_var``) to + New data type to convert (``grid_var``) to Returns - ---------- - grid_var : xarray.Dataset - Input Dataset with correct fill value and dtype + ------- + grid_var : xr.DataArray + Modified DataArray with updated fill values and dtype """ - # locations of fill values + # Identify fill value locations if original_fill is not None and np.isnan(original_fill): - fill_val_idx = np.isnan(grid_var) - grid_var[fill_val_idx] = 0.0 # todo? + # For NaN fill values + fill_val_idx = grid_var.isnull() + # Temporarily replace NaNs with a placeholder if dtype conversion is needed + if new_dtype is not None and np.issubdtype(new_dtype, np.floating): + grid_var = grid_var.fillna(0.0) + else: + # Choose an appropriate placeholder for non-floating types + grid_var = grid_var.fillna(new_fill) else: + # For non-NaN fill values fill_val_idx = grid_var == original_fill - # convert to new data type - if new_dtype != grid_var.dtype and new_dtype is not None: + # Convert to the new data type if specified + if new_dtype is not None and new_dtype != grid_var.dtype: grid_var = grid_var.astype(new_dtype) - # ensure fill value can be represented with current integer data type - if np.issubdtype(new_dtype, np.integer): - int_min = np.iinfo(grid_var.dtype).min - int_max = np.iinfo(grid_var.dtype).max - # ensure new_fill is in range [int_min, int_max] - if new_fill < int_min or new_fill > int_max: - raise ValueError( - f"New fill value: {new_fill} not representable by" - f" integer dtype: {grid_var.dtype}" - ) - - # ensure non-nan fill value can be represented with current float data type - elif np.issubdtype(new_dtype, np.floating) and not np.isnan(new_fill): - float_min = np.finfo(grid_var.dtype).min - float_max = np.finfo(grid_var.dtype).max - # ensure new_fill is in range [float_min, float_max] - if new_fill < float_min or new_fill > float_max: - raise ValueError( - f"New fill value: {new_fill} not representable by" - f" float dtype: {grid_var.dtype}" - ) - else: - raise ValueError(f"Data type {grid_var.dtype} not supportedfor grid variables") - - # replace all zeros with a fill value - grid_var[fill_val_idx] = new_fill + # Validate that the new_fill can be represented in the new_dtype + if new_dtype is not None: + if np.issubdtype(new_dtype, np.integer): + int_min = np.iinfo(new_dtype).min + int_max = np.iinfo(new_dtype).max + if not (int_min <= new_fill <= int_max): + raise ValueError( + f"New fill value: {new_fill} not representable by integer dtype: {new_dtype}" + ) + elif np.issubdtype(new_dtype, np.floating): + if not ( + np.isnan(new_fill) + or (np.finfo(new_dtype).min <= new_fill <= np.finfo(new_dtype).max) + ): + raise ValueError( + f"New fill value: {new_fill} not representable by float dtype: {new_dtype}" + ) + else: + raise ValueError(f"Data type {new_dtype} not supported for grid variables") + + grid_var = grid_var.where(~fill_val_idx, new_fill) return grid_var diff --git a/uxarray/grid/grid.py b/uxarray/grid/grid.py index 86a8006a7..aa189c309 100644 --- a/uxarray/grid/grid.py +++ b/uxarray/grid/grid.py @@ -193,6 +193,9 @@ def __init__( ) # TODO: more checks for validate grid (lat/lon coords, etc) + # TODO: + self._load_on_access = True + # mapping of ugrid dimensions and variables to source dataset's conventions self._source_dims_dict = source_dims_dict @@ -1177,7 +1180,7 @@ def edge_node_x(self) -> xr.DataArray: """ if "edge_node_x" not in self._ds: - _edge_node_x = self.node_x.values[self.edge_node_connectivity.values] + _edge_node_x = self.node_x[self.edge_node_connectivity] self._ds["edge_node_x"] = xr.DataArray( data=_edge_node_x, @@ -1194,7 +1197,7 @@ def edge_node_y(self) -> xr.DataArray: """ if "edge_node_y" not in self._ds: - _edge_node_y = self.node_y.values[self.edge_node_connectivity.values] + _edge_node_y = self.node_y[self.edge_node_connectivity] self._ds["edge_node_y"] = xr.DataArray( data=_edge_node_y, @@ -1211,7 +1214,7 @@ def edge_node_z(self) -> xr.DataArray: """ if "edge_node_z" not in self._ds: - _edge_node_z = self.node_z.values[self.edge_node_connectivity.values] + _edge_node_z = self.node_z[self.edge_node_connectivity] self._ds["edge_node_z"] = xr.DataArray( data=_edge_node_z, @@ -1613,7 +1616,7 @@ def chunk(self, n_node="auto", n_edge="auto", n_face="auto"): def get_ball_tree( self, - coordinates: Optional[str] = "nodes", + coordinates: Optional[str] = "face centers", coordinate_system: Optional[str] = "spherical", distance_metric: Optional[str] = "haversine", reconstruct: bool = False, @@ -1663,7 +1666,7 @@ def get_ball_tree( def get_kd_tree( self, - coordinates: Optional[str] = "nodes", + coordinates: Optional[str] = "face centers", coordinate_system: Optional[str] = "cartesian", distance_metric: Optional[str] = "minkowski", reconstruct: bool = False, diff --git a/uxarray/grid/neighbors.py b/uxarray/grid/neighbors.py index 2dcb70602..966dc7e17 100644 --- a/uxarray/grid/neighbors.py +++ b/uxarray/grid/neighbors.py @@ -45,7 +45,7 @@ class KDTree: def __init__( self, grid, - coordinates: Optional[str] = "nodes", + coordinates: Optional[str] = "face centers", coordinate_system: Optional[str] = "cartesian", distance_metric: Optional[str] = "minkowski", reconstruct: bool = False, @@ -433,7 +433,7 @@ class BallTree: def __init__( self, grid, - coordinates: Optional[str] = "nodes", + coordinates: Optional[str] = "face centers", coordinate_system: Optional[str] = "spherical", distance_metric: Optional[str] = "haversine", reconstruct: bool = False, diff --git a/uxarray/grid/slice.py b/uxarray/grid/slice.py index 94e8e0eb8..4257f05e1 100644 --- a/uxarray/grid/slice.py +++ b/uxarray/grid/slice.py @@ -107,22 +107,29 @@ def _slice_face_indices( face_indices = indices + # TODO: stop using .values, use Xarray directly # nodes of each face (inclusive) node_indices = np.unique(grid.face_node_connectivity.values[face_indices].ravel()) node_indices = node_indices[node_indices != INT_FILL_VALUE] - # edges of each face (inclusive) - edge_indices = np.unique(grid.face_edge_connectivity.values[face_indices].ravel()) - edge_indices = edge_indices[edge_indices != INT_FILL_VALUE] - # index original dataset to obtain a 'subgrid' ds = ds.isel(n_node=node_indices) ds = ds.isel(n_face=face_indices) - ds = ds.isel(n_edge=edge_indices) + + # Only slice edge dimension if we already have the connectivity + if "face_edge_connectivity" in grid._ds: + # TODO: add a warning here? + edge_indices = np.unique( + grid.face_edge_connectivity.values[face_indices].ravel() + ) + edge_indices = edge_indices[edge_indices != INT_FILL_VALUE] + ds = ds.isel(n_edge=edge_indices) + ds["subgrid_edge_indices"] = xr.DataArray(edge_indices, dims=["n_edge"]) + else: + edge_indices = None ds["subgrid_node_indices"] = xr.DataArray(node_indices, dims=["n_node"]) ds["subgrid_face_indices"] = xr.DataArray(face_indices, dims=["n_face"]) - ds["subgrid_edge_indices"] = xr.DataArray(edge_indices, dims=["n_edge"]) # mapping to update existing connectivity node_indices_dict = { @@ -152,9 +159,12 @@ def _slice_face_indices( index_types = { "face": face_indices, - "edge": edge_indices, "node": node_indices, } + + if edge_indices is not None: + index_types["edge"] = edge_indices + if isinstance(inverse_indices, bool): inverse_indices_ds["face"] = face_indices else: diff --git a/uxarray/grid/utils.py b/uxarray/grid/utils.py index d9a6621c9..3dca814d4 100644 --- a/uxarray/grid/utils.py +++ b/uxarray/grid/utils.py @@ -3,6 +3,28 @@ from numba import njit +import functools +from typing import Callable +import dask.array as da + + +def load_on_access(func: Callable) -> Callable: + """Decorator that loads a Xarray DataArray backed by Dask into + memory when the user first invokes a property, based on the _load_on_access _load_on_access flag. + """ + + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + data = func(self, *args, **kwargs) + # Check if loading is requested and if the data uses Dask + if self._load_on_access and isinstance(data.data, da.Array): + # Load the data into memory + data.load() + return data + return data + + return wrapper + @njit(cache=True) def _small_angle_of_2_vectors(u, v): diff --git a/uxarray/io/_esmf.py b/uxarray/io/_esmf.py index f16abc657..a23d785f1 100644 --- a/uxarray/io/_esmf.py +++ b/uxarray/io/_esmf.py @@ -6,6 +6,15 @@ from uxarray.conventions import ugrid +def _esmf_to_ugrid_dims(in_ds): + source_dims_dict = { + "nodeCount": ugrid.NODE_DIM, + "elementCount": ugrid.FACE_DIM, + "maxNodePElement": ugrid.N_MAX_FACE_NODES_DIM, + } + return source_dims_dict + + def _read_esmf(in_ds): """Reads in an Xarray dataset containing an ESMF formatted Grid dataset and encodes it in the UGRID conventions. @@ -38,32 +47,28 @@ def _read_esmf(in_ds): out_ds = xr.Dataset() - source_dims_dict = { - "nodeCount": ugrid.NODE_DIM, - "elementCount": ugrid.FACE_DIM, - "maxNodePElement": ugrid.N_MAX_FACE_NODES_DIM, - } + source_dims_dict = _esmf_to_ugrid_dims(in_ds) if in_ds["nodeCoords"].units == "degrees": # Spherical Coordinates (in degrees) - node_lon = in_ds["nodeCoords"].isel(coordDim=0).values + node_lon = in_ds["nodeCoords"].isel(coordDim=0) out_ds[ugrid.NODE_COORDINATES[0]] = xr.DataArray( node_lon, dims=[ugrid.NODE_DIM], attrs=ugrid.NODE_LON_ATTRS ) - node_lat = in_ds["nodeCoords"].isel(coordDim=1).values + node_lat = in_ds["nodeCoords"].isel(coordDim=1) out_ds[ugrid.NODE_COORDINATES[1]] = xr.DataArray( node_lat, dims=[ugrid.NODE_DIM], attrs=ugrid.NODE_LAT_ATTRS ) if "centerCoords" in in_ds: # parse center coords (face centers) if available - face_lon = in_ds["centerCoords"].isel(coordDim=0).values + face_lon = in_ds["centerCoords"].isel(coordDim=0) out_ds[ugrid.FACE_COORDINATES[0]] = xr.DataArray( face_lon, dims=[ugrid.FACE_DIM], attrs=ugrid.FACE_LON_ATTRS ) - face_lat = in_ds["centerCoords"].isel(coordDim=1).values + face_lat = in_ds["centerCoords"].isel(coordDim=1) out_ds[ugrid.FACE_COORDINATES[1]] = xr.DataArray( face_lat, dims=[ugrid.FACE_DIM], attrs=ugrid.FACE_LAT_ATTRS ) @@ -73,7 +78,7 @@ def _read_esmf(in_ds): "Reading in ESMF grids with Cartesian coordinates not yet supported" ) - n_nodes_per_face = in_ds["numElementConn"].values.astype(INT_DTYPE) + n_nodes_per_face = in_ds["numElementConn"].astype(INT_DTYPE) out_ds["n_nodes_per_face"] = xr.DataArray( data=n_nodes_per_face, dims=ugrid.N_NODES_PER_FACE_DIMS, @@ -86,7 +91,7 @@ def _read_esmf(in_ds): # assume start index is 1 if one is not provided start_index = 1 - face_node_connectivity = in_ds["elementConn"].values.astype(INT_DTYPE) + face_node_connectivity = in_ds["elementConn"].astype(INT_DTYPE) for i, max_nodes in enumerate(n_nodes_per_face): # convert to zero index and standardize fill values diff --git a/uxarray/io/_exodus.py b/uxarray/io/_exodus.py index 8c91913fa..7c4fff213 100644 --- a/uxarray/io/_exodus.py +++ b/uxarray/io/_exodus.py @@ -34,8 +34,8 @@ def _read_exodus(ext_ds): if ext_ds.sizes[dim] > max_face_nodes: max_face_nodes = ext_ds.sizes[dim] - # create an empty conn array for storing all blk face_nodes_data - conn = np.empty((0, max_face_nodes)) + # # create an empty conn array for storing all blk face_nodes_data + # conn = np.empty((0, max_face_nodes)) for key, value in ext_ds.variables.items(): if key == "qa_records": @@ -70,10 +70,10 @@ def _read_exodus(ext_ds): elif "connect" in key: # check if num face nodes is less than max. if value.data.shape[1] <= max_face_nodes: - conn = np.full( - (value.data.shape[1], max_face_nodes), 0, dtype=conn.dtype - ) - conn = value.data + # face_nodes = np.full( + # (value.data.shape[1], max_face_nodes), 0, dtype=conn.dtype + # ) + face_nodes = value else: raise RuntimeError("found face_nodes_dim greater than n_max_face_nodes") @@ -88,7 +88,7 @@ def _read_exodus(ext_ds): # standardize fill values and data type face nodes face_nodes = _replace_fill_values( - grid_var=conn[:] - 1, + grid_var=face_nodes[:] - 1, original_fill=-1, new_fill=INT_FILL_VALUE, new_dtype=INT_DTYPE, diff --git a/uxarray/io/_fesom2.py b/uxarray/io/_fesom2.py index 4e9a30330..24503644b 100644 --- a/uxarray/io/_fesom2.py +++ b/uxarray/io/_fesom2.py @@ -218,8 +218,8 @@ def _read_fesom2_netcdf(in_ds): source_dims_dict = {"ncells": "n_face"} ugrid_ds = xr.Dataset() - node_lon = in_ds["lon"].data - node_lat = in_ds["lat"].data + node_lon = in_ds["lon"] + node_lat = in_ds["lat"] ugrid_ds["node_lon"] = xr.DataArray( data=node_lon, dims=ugrid.NODE_DIM, attrs=ugrid.NODE_LON_ATTRS @@ -228,7 +228,7 @@ def _read_fesom2_netcdf(in_ds): data=node_lat, dims=ugrid.NODE_DIM, attrs=ugrid.NODE_LAT_ATTRS ) - face_node_connectivity = in_ds["triag_nodes"].data - 1 + face_node_connectivity = in_ds["triag_nodes"] - 1 ugrid_ds["face_node_connectivity"] = xr.DataArray( data=face_node_connectivity, diff --git a/uxarray/io/_icon.py b/uxarray/io/_icon.py index 01ed891f2..755c43768 100644 --- a/uxarray/io/_icon.py +++ b/uxarray/io/_icon.py @@ -3,16 +3,21 @@ import numpy as np +def _icon_to_ugrid_dims(in_ds): + source_dims_dict = {"vertex": "n_node", "edge": "n_edge", "cell": "n_face"} + return source_dims_dict + + def _primal_to_ugrid(in_ds, out_ds): """Encodes the Primal Mesh of an ICON Grid into the UGRID conventions.""" - source_dims_dict = {"vertex": "n_node", "edge": "n_edge", "cell": "n_face"} + source_dims_dict = _icon_to_ugrid_dims(in_ds) # rename dimensions to match ugrid conventions in_ds = in_ds.rename_dims(source_dims_dict) # node coordinates - node_lon = np.rad2deg(in_ds["vlon"]) - node_lat = np.rad2deg(in_ds["vlat"]) + node_lon = 180.0 * in_ds["vlon"] / np.pi + node_lat = 180.0 * in_ds["vlat"] / np.pi out_ds["node_lon"] = xr.DataArray( data=node_lon, dims=ugrid.NODE_DIM, attrs=ugrid.NODE_LON_ATTRS @@ -22,8 +27,8 @@ def _primal_to_ugrid(in_ds, out_ds): ) # edge coordinates - edge_lon = np.rad2deg(in_ds["elon"]) - edge_lat = np.rad2deg(in_ds["elat"]) + edge_lon = 180.0 * in_ds["elon"] / np.pi + edge_lat = 180.0 * in_ds["elat"] / np.pi out_ds["edge_lon"] = xr.DataArray( data=edge_lon, dims=ugrid.EDGE_DIM, attrs=ugrid.EDGE_LON_ATTRS @@ -33,8 +38,8 @@ def _primal_to_ugrid(in_ds, out_ds): ) # face coordinates - face_lon = np.rad2deg(in_ds["clon"]) - face_lat = np.rad2deg(in_ds["clat"]) + face_lon = 180.0 * in_ds["clon"] / np.pi + face_lat = 180.0 * in_ds["clat"] / np.pi out_ds["face_lon"] = xr.DataArray( data=face_lon, dims=ugrid.FACE_DIM, attrs=ugrid.FACE_LON_ATTRS diff --git a/uxarray/io/_mpas.py b/uxarray/io/_mpas.py index 539c596cf..26caa5967 100644 --- a/uxarray/io/_mpas.py +++ b/uxarray/io/_mpas.py @@ -5,18 +5,29 @@ from uxarray.conventions import ugrid, descriptors -def _primal_to_ugrid(in_ds, out_ds): - """Encodes the MPAS Primal-Mesh in the UGRID conventions. +def _mpas_to_ugrid_dims(in_ds, primal=True): + """TODO:""" + source_dims_dict = {} + if primal: + source_dims_dict["nVertices"] = ugrid.NODE_DIM + source_dims_dict[in_ds["verticesOnCell"].dims[0]] = ugrid.FACE_DIM + source_dims_dict[in_ds["verticesOnCell"].dims[1]] = ugrid.N_MAX_FACE_NODES_DIM - Parameters - ---------- - in_ds : xarray.Dataset - Input MPAS dataset - out_ds : xarray.Dataset - Output dataset where the MPAS Primal-Mesh is encoded in the UGRID - conventions - """ + if "verticesOnEdge" in in_ds: + source_dims_dict[in_ds["verticesOnEdge"].dims[0]] = "n_edge" + else: + source_dims_dict[in_ds["latCell"].dims[0]] = ugrid.NODE_DIM + source_dims_dict[in_ds["cellsOnVertex"].dims[0]] = ugrid.FACE_DIM + source_dims_dict[in_ds["cellsOnVertex"].dims[1]] = ugrid.N_MAX_FACE_NODES_DIM + if "cellsOnEdge" in in_ds: + source_dims_dict[in_ds["cellsOnEdge"].dims[0]] = "n_edge" + + return source_dims_dict + + +def _primal_to_ugrid(in_ds, out_ds): + """Encodes the MPAS Primal-Mesh in the UGRID conventions.""" source_dims_dict = {} if "lonVertex" in in_ds: @@ -77,17 +88,7 @@ def _primal_to_ugrid(in_ds, out_ds): def _dual_to_ugrid(in_ds, out_ds): - """Encodes the MPAS Dual-Mesh in the UGRID conventions. - - Parameters - ---------- - in_ds : xarray.Dataset - Input MPAS dataset - out_ds : xarray.Dataset - Output dataset where the MPAS Dual-Mesh is encoded in the UGRID - conventions - """ - + """Encodes the MPAS Dual-Mesh in the UGRID conventions.""" source_dims_dict = {} if "lonCell" in in_ds: @@ -142,491 +143,366 @@ def _dual_to_ugrid(in_ds, out_ds): def _parse_node_latlon_coords(in_ds, out_ds, mesh_type): - """Parses cartesian corner node coordinates for either the Primal or Dual - Mesh.""" + """Parses cartesian corner node coordinates for either the Primal or Dual Mesh.""" if mesh_type == "primal": - node_lon = np.rad2deg(in_ds["lonVertex"].values) - node_lat = np.rad2deg(in_ds["latVertex"].values) + node_lon = 180.0 * in_ds["lonVertex"] / np.pi + node_lat = 180.0 * in_ds["latVertex"] / np.pi + + # Ensure correct dimension name + node_lon = node_lon.rename({"nVertices": ugrid.NODE_DIM}) + node_lat = node_lat.rename({"nVertices": ugrid.NODE_DIM}) else: - node_lon = np.rad2deg(in_ds["lonCell"].values) - node_lat = np.rad2deg(in_ds["latCell"].values) + node_lon = 180.0 * in_ds["lonCell"] / np.pi + node_lat = 180.0 * in_ds["latCell"] / np.pi - out_ds["node_lon"] = xr.DataArray( - node_lon, dims=[ugrid.NODE_DIM], attrs=ugrid.NODE_LON_ATTRS - ) + # Ensure correct dimension name + node_lon = node_lon.rename({"nCells": ugrid.NODE_DIM}) + node_lat = node_lat.rename({"nCells": ugrid.NODE_DIM}) - out_ds["node_lat"] = xr.DataArray( - node_lat, dims=[ugrid.NODE_DIM], attrs=ugrid.NODE_LAT_ATTRS - ) + out_ds["node_lon"] = node_lon.assign_attrs(ugrid.NODE_LON_ATTRS) + out_ds["node_lat"] = node_lat.assign_attrs(ugrid.NODE_LAT_ATTRS) def _parse_node_xyz_coords(in_ds, out_ds, mesh_type): - """Parses cartesian corner node coordinates for either the Primal or Dual - Mesh.""" + """Parses cartesian corner node coordinates for either the Primal or Dual Mesh.""" if mesh_type == "primal": - node_x = in_ds["xVertex"].values - node_y = in_ds["yVertex"].values - node_z = in_ds["zVertex"].values + node_x = in_ds["xVertex"] + node_y = in_ds["yVertex"] + node_z = in_ds["zVertex"] + + # Ensure correct dimension name + node_x = node_x.rename({"nVertices": ugrid.NODE_DIM}) + node_y = node_y.rename({"nVertices": ugrid.NODE_DIM}) + node_z = node_z.rename({"nVertices": ugrid.NODE_DIM}) else: - # corners of dual-mesh cells (artesian) - node_x = in_ds["xCell"].values - node_y = in_ds["yCell"].values - node_z = in_ds["zCell"].values + node_x = in_ds["xCell"] + node_y = in_ds["yCell"] + node_z = in_ds["zCell"] - out_ds["node_x"] = xr.DataArray( - data=node_x, dims=[ugrid.NODE_DIM], attrs=ugrid.NODE_X_ATTRS - ) + # Ensure correct dimension name + node_x = node_x.rename({"nCells": ugrid.NODE_DIM}) + node_y = node_y.rename({"nCells": ugrid.NODE_DIM}) + node_z = node_z.rename({"nCells": ugrid.NODE_DIM}) - out_ds["node_y"] = xr.DataArray( - data=node_y, dims=[ugrid.NODE_DIM], attrs=ugrid.NODE_Y_ATTRS - ) - - out_ds["node_z"] = xr.DataArray( - data=node_z, dims=[ugrid.NODE_DIM], attrs=ugrid.NODE_Z_ATTRS - ) + out_ds["node_x"] = node_x.assign_attrs(ugrid.NODE_X_ATTRS) + out_ds["node_y"] = node_y.assign_attrs(ugrid.NODE_Y_ATTRS) + out_ds["node_z"] = node_z.assign_attrs(ugrid.NODE_Z_ATTRS) def _parse_face_latlon_coords(in_ds, out_ds, mesh_type): - """Parses latlon face center coordinates for either the Primal or Dual - Mesh.""" + """Parses latlon face center coordinates for either the Primal or Dual Mesh.""" if mesh_type == "primal": - face_lon = np.rad2deg(in_ds["lonCell"].values) - face_lat = np.rad2deg(in_ds["latCell"].values) + face_lon = 180.0 * in_ds["lonCell"] / np.pi + face_lat = 180.0 * in_ds["latCell"] / np.pi + + # Ensure correct dimension name + face_lon = face_lon.rename({"nCells": ugrid.FACE_DIM}) + face_lat = face_lat.rename({"nCells": ugrid.FACE_DIM}) else: - face_lon = np.rad2deg(in_ds["lonVertex"].values) - face_lat = np.rad2deg(in_ds["latVertex"].values) + face_lon = 180.0 * in_ds["lonVertex"] / np.pi + face_lat = 180.0 * in_ds["latVertex"] / np.pi - out_ds["face_lon"] = xr.DataArray( - face_lon, dims=[ugrid.FACE_DIM], attrs=ugrid.FACE_LON_ATTRS - ) + # Ensure correct dimension name + face_lon = face_lon.rename({"nVertices": ugrid.FACE_DIM}) + face_lat = face_lat.rename({"nVertices": ugrid.FACE_DIM}) - out_ds["face_lat"] = xr.DataArray( - face_lat, dims=[ugrid.FACE_DIM], attrs=ugrid.FACE_LAT_ATTRS - ) + out_ds["face_lon"] = face_lon.assign_attrs(ugrid.FACE_LON_ATTRS) + out_ds["face_lat"] = face_lat.assign_attrs(ugrid.FACE_LAT_ATTRS) def _parse_face_xyz_coords(in_ds, out_ds, mesh_type): - """Parses cartesian face center coordinates for either the Primal or Dual - Mesh.""" + """Parses cartesian face center coordinates for either the Primal or Dual Mesh.""" if mesh_type == "primal": - face_x = in_ds["xCell"].values - face_y = in_ds["yCell"].values - face_z = in_ds["zCell"].values + face_x = in_ds["xCell"] + face_y = in_ds["yCell"] + face_z = in_ds["zCell"] + + # Ensure correct dimension name + face_x = face_x.rename({"nCells": ugrid.FACE_DIM}) + face_y = face_y.rename({"nCells": ugrid.FACE_DIM}) + face_z = face_z.rename({"nCells": ugrid.FACE_DIM}) else: - # centers of dual-mesh cells (in degrees) - face_x = in_ds["xVertex"].values - face_y = in_ds["yVertex"].values - face_z = in_ds["zVertex"].values + face_x = in_ds["xVertex"] + face_y = in_ds["yVertex"] + face_z = in_ds["zVertex"] - out_ds["face_x"] = xr.DataArray( - data=face_x, dims=[ugrid.FACE_DIM], attrs=ugrid.FACE_X_ATTRS - ) + # Ensure correct dimension name + face_x = face_x.rename({"nVertices": ugrid.FACE_DIM}) + face_y = face_y.rename({"nVertices": ugrid.FACE_DIM}) + face_z = face_z.rename({"nVertices": ugrid.FACE_DIM}) - out_ds["face_y"] = xr.DataArray( - data=face_y, dims=[ugrid.FACE_DIM], attrs=ugrid.FACE_Y_ATTRS - ) - - out_ds["face_z"] = xr.DataArray( - data=face_z, dims=[ugrid.FACE_DIM], attrs=ugrid.FACE_Z_ATTRS - ) + out_ds["face_x"] = face_x.assign_attrs(ugrid.FACE_X_ATTRS) + out_ds["face_y"] = face_y.assign_attrs(ugrid.FACE_Y_ATTRS) + out_ds["face_z"] = face_z.assign_attrs(ugrid.FACE_Z_ATTRS) def _parse_edge_latlon_coords(in_ds, out_ds, mesh_type): - """Parses latlon edge node coordinates for either the Primal or Dual - Mesh.""" - - edge_lon = np.rad2deg(in_ds["lonEdge"].values) - edge_lat = np.rad2deg(in_ds["latEdge"].values) + """Parses latlon edge node coordinates.""" + edge_lon = 180.0 * in_ds["lonEdge"] / np.pi + edge_lat = 180.0 * in_ds["latEdge"] / np.pi - out_ds["edge_lon"] = xr.DataArray( - edge_lon, dims=[ugrid.EDGE_DIM], attrs=ugrid.EDGE_LON_ATTRS - ) + # Ensure correct dimension name + edge_lon = edge_lon.rename({"nEdges": ugrid.EDGE_DIM}) + edge_lat = edge_lat.rename({"nEdges": ugrid.EDGE_DIM}) - out_ds["edge_lat"] = xr.DataArray( - edge_lat, dims=[ugrid.EDGE_DIM], attrs=ugrid.EDGE_LAT_ATTRS - ) + out_ds["edge_lon"] = edge_lon.assign_attrs(ugrid.EDGE_LON_ATTRS) + out_ds["edge_lat"] = edge_lat.assign_attrs(ugrid.EDGE_LAT_ATTRS) def _parse_edge_xyz_coords(in_ds, out_ds, mesh_type): - """Parses cartesian edge node coordinates for either the Primal or Dual - Mesh.""" - edge_x = in_ds["xEdge"].values - edge_y = in_ds["yEdge"].values - edge_z = in_ds["zEdge"].values - - out_ds["edge_x"] = xr.DataArray( - data=edge_x, dims=[ugrid.EDGE_DIM], attrs=ugrid.EDGE_X_ATTRS - ) + """Parses cartesian edge node coordinates.""" + edge_x = in_ds["xEdge"] + edge_y = in_ds["yEdge"] + edge_z = in_ds["zEdge"] - out_ds["edge_y"] = xr.DataArray( - data=edge_y, dims=[ugrid.EDGE_DIM], attrs=ugrid.EDGE_Y_ATTRS - ) + # Ensure correct dimension name + edge_x = edge_x.rename({"nEdges": ugrid.EDGE_DIM}) + edge_y = edge_y.rename({"nEdges": ugrid.EDGE_DIM}) + edge_z = edge_z.rename({"nEdges": ugrid.EDGE_DIM}) - out_ds["edge_z"] = xr.DataArray( - data=edge_z, dims=[ugrid.EDGE_DIM], attrs=ugrid.EDGE_Z_ATTRS - ) + out_ds["edge_x"] = edge_x.assign_attrs(ugrid.EDGE_X_ATTRS) + out_ds["edge_y"] = edge_y.assign_attrs(ugrid.EDGE_Y_ATTRS) + out_ds["edge_z"] = edge_z.assign_attrs(ugrid.EDGE_Z_ATTRS) def _parse_face_nodes(in_ds, out_ds, mesh_type): """Parses face node connectivity for either the Primal or Dual Mesh.""" if mesh_type == "primal": - verticesOnCell = np.array(in_ds["verticesOnCell"].values, dtype=INT_DTYPE) + verticesOnCell = in_ds["verticesOnCell"].astype(INT_DTYPE) + nEdgesOnCell = in_ds["nEdgesOnCell"].astype(INT_DTYPE) - nEdgesOnCell = np.array(in_ds["nEdgesOnCell"].values, dtype=INT_DTYPE) - - # replace padded values with fill values + # Replace padded values with fill values verticesOnCell = _replace_padding(verticesOnCell, nEdgesOnCell) - # replace missing/zero values with fill values + # Replace missing/zero values with fill values verticesOnCell = _replace_zeros(verticesOnCell) - # convert to zero-indexed - verticesOnCell = _to_zero_index(verticesOnCell) - - face_nodes = verticesOnCell + # Convert to zero-indexed + face_nodes = _to_zero_index(verticesOnCell) else: - cellsOnVertex = np.array(in_ds["cellsOnVertex"].values, dtype=INT_DTYPE) + cellsOnVertex = in_ds["cellsOnVertex"].astype(INT_DTYPE) - # replace missing/zero values with fill values + # Replace missing/zero values with fill values cellsOnVertex = _replace_zeros(cellsOnVertex) - # convert to zero-indexed - cellsOnVertex = _to_zero_index(cellsOnVertex) + # Convert to zero-indexed + face_nodes = _to_zero_index(cellsOnVertex) - face_nodes = cellsOnVertex - - out_ds["face_node_connectivity"] = xr.DataArray( - data=face_nodes, - dims=ugrid.FACE_NODE_CONNECTIVITY_DIMS, - attrs=ugrid.FACE_NODE_CONNECTIVITY_ATTRS, - ) + out_ds["face_node_connectivity"] = face_nodes.assign_attrs( + ugrid.FACE_NODE_CONNECTIVITY_ATTRS + ).rename(dict(zip(face_nodes.dims, ugrid.FACE_NODE_CONNECTIVITY_DIMS))) def _parse_edge_nodes(in_ds, out_ds, mesh_type): """Parses edge node connectivity for either the Primal or Dual Mesh.""" if mesh_type == "primal": - # vertex indices that saddle a given edge - verticesOnEdge = np.array(in_ds["verticesOnEdge"].values, dtype=INT_DTYPE) + verticesOnEdge = in_ds["verticesOnEdge"].astype(INT_DTYPE) - # replace missing/zero values with fill value + # Replace missing/zero values with fill values verticesOnEdge = _replace_zeros(verticesOnEdge) - # convert to zero-indexed - verticesOnEdge = _to_zero_index(verticesOnEdge) + # Convert to zero-indexed + edge_nodes = _to_zero_index(verticesOnEdge) - edge_nodes = verticesOnEdge else: - # vertex indices that saddle a given edge - cellsOnEdge = np.array(in_ds["cellsOnEdge"].values, dtype=INT_DTYPE) + cellsOnEdge = in_ds["cellsOnEdge"].astype(INT_DTYPE) - # replace missing/zero values with fill values + # Replace missing/zero values with fill values cellsOnEdge = _replace_zeros(cellsOnEdge) - # convert to zero-indexed - cellsOnEdge = _to_zero_index(cellsOnEdge) + # Convert to zero-indexed + edge_nodes = _to_zero_index(cellsOnEdge) - edge_nodes = cellsOnEdge - - out_ds["edge_node_connectivity"] = xr.DataArray( - data=edge_nodes, - dims=ugrid.EDGE_NODE_CONNECTIVITY_DIMS, - attrs=ugrid.EDGE_NODE_CONNECTIVITY_ATTRS, - ) + out_ds["edge_node_connectivity"] = edge_nodes.assign_attrs( + ugrid.EDGE_NODE_CONNECTIVITY_ATTRS + ).rename(dict(zip(edge_nodes.dims, ugrid.EDGE_NODE_CONNECTIVITY_DIMS))) def _parse_node_faces(in_ds, out_ds, mesh_type): """Parses node face connectivity for either the Primal or Dual Mesh.""" if mesh_type == "primal": - cellsOnVertex = np.array(in_ds["cellsOnVertex"].values, dtype=INT_DTYPE) + cellsOnVertex = in_ds["cellsOnVertex"].astype(INT_DTYPE) - # replace missing/zero values with fill values + # Replace missing/zero values with fill values cellsOnVertex = _replace_zeros(cellsOnVertex) - # convert to zero-indexed - cellsOnVertex = _to_zero_index(cellsOnVertex) - - node_faces = cellsOnVertex + # Convert to zero-indexed + node_faces = _to_zero_index(cellsOnVertex) else: - verticesOnCell = np.array(in_ds["verticesOnCell"].values, dtype=INT_DTYPE) - - nEdgesOnCell = np.array(in_ds["nEdgesOnCell"].values, dtype=INT_DTYPE) + verticesOnCell = in_ds["verticesOnCell"].astype(INT_DTYPE) + nEdgesOnCell = in_ds["nEdgesOnCell"].astype(INT_DTYPE) - # replace padded values with fill values + # Replace padded values with fill values verticesOnCell = _replace_padding(verticesOnCell, nEdgesOnCell) - # replace missing/zero values with fill values + # Replace missing/zero values with fill values verticesOnCell = _replace_zeros(verticesOnCell) - # convert to zero-indexed - verticesOnCell = _to_zero_index(verticesOnCell) + # Convert to zero-indexed + node_faces = _to_zero_index(verticesOnCell) - node_faces = verticesOnCell - - out_ds["node_face_connectivity"] = xr.DataArray( - data=node_faces, - dims=ugrid.NODE_FACE_CONNECTIVITY_DIMS, - attrs=ugrid.NODE_FACE_CONNECTIVITY_ATTRS, - ) + out_ds["node_face_connectivity"] = node_faces.assign_attrs( + ugrid.NODE_FACE_CONNECTIVITY_ATTRS + ).rename(dict(zip(node_faces.dims, ugrid.NODE_FACE_CONNECTIVITY_DIMS))) def _parse_face_edges(in_ds, out_ds, mesh_type): """Parses face edge connectivity for either the Primal or Dual Mesh.""" if mesh_type == "primal": - edgesOnCell = np.array(in_ds["edgesOnCell"].values, dtype=INT_DTYPE) - - nEdgesOnCell = np.array(in_ds["nEdgesOnCell"].values, dtype=INT_DTYPE) + edgesOnCell = in_ds["edgesOnCell"].astype(INT_DTYPE) + nEdgesOnCell = in_ds["nEdgesOnCell"].astype(INT_DTYPE) - # replace padded values with fill values + # Replace padded values with fill values edgesOnCell = _replace_padding(edgesOnCell, nEdgesOnCell) - # replace missing/zero values with fill values + # Replace missing/zero values with fill values edgesOnCell = _replace_zeros(edgesOnCell) - # convert to zero-indexed - edgesOnCell = _to_zero_index(edgesOnCell) - - face_edges = edgesOnCell + # Convert to zero-indexed + face_edges = _to_zero_index(edgesOnCell) else: - edgesOnVertex = np.array(in_ds["edgesOnVertex"].values, dtype=INT_DTYPE) + edgesOnVertex = in_ds["edgesOnVertex"].astype(INT_DTYPE) - # replace missing/zero values with fill values + # Replace missing/zero values with fill values edgesOnVertex = _replace_zeros(edgesOnVertex) - # convert to zero-indexed - edgesOnVertex = _to_zero_index(edgesOnVertex) - - face_edges = edgesOnVertex + # Convert to zero-indexed + face_edges = _to_zero_index(edgesOnVertex) - out_ds["face_edge_connectivity"] = xr.DataArray( - data=face_edges, - dims=ugrid.FACE_EDGE_CONNECTIVITY_DIMS, - attrs=ugrid.FACE_EDGE_CONNECTIVITY_ATTRS, - ) + out_ds["face_edge_connectivity"] = face_edges.assign_attrs( + ugrid.FACE_EDGE_CONNECTIVITY_ATTRS + ).rename(dict(zip(face_edges.dims, ugrid.FACE_EDGE_CONNECTIVITY_DIMS))) def _parse_edge_faces(in_ds, out_ds, mesh_type): - """Parses edge node connectivity for either the Primal or Dual Mesh.""" + """Parses edge face connectivity for either the Primal or Dual Mesh.""" if mesh_type == "primal": - # vertex indices that saddle a given edge - cellsOnEdge = np.array(in_ds["cellsOnEdge"].values, dtype=INT_DTYPE) + cellsOnEdge = in_ds["cellsOnEdge"].astype(INT_DTYPE) - # replace missing/zero values with fill values + # Replace missing/zero values with fill values cellsOnEdge = _replace_zeros(cellsOnEdge) - # convert to zero-indexed - cellsOnEdge = _to_zero_index(cellsOnEdge) - - edge_faces = cellsOnEdge + # Convert to zero-indexed + edge_faces = _to_zero_index(cellsOnEdge) else: - # vertex indices that saddle a given edge - verticesOnEdge = np.array(in_ds["verticesOnEdge"].values, dtype=INT_DTYPE) + verticesOnEdge = in_ds["verticesOnEdge"].astype(INT_DTYPE) - # replace missing/zero values with fill value + # Replace missing/zero values with fill values verticesOnEdge = _replace_zeros(verticesOnEdge) - # convert to zero-indexed - verticesOnEdge = _to_zero_index(verticesOnEdge) - - edge_faces = verticesOnEdge + # Convert to zero-indexed + edge_faces = _to_zero_index(verticesOnEdge) - out_ds["edge_face_connectivity"] = xr.DataArray( - data=edge_faces, - dims=ugrid.EDGE_FACE_CONNECTIVITY_DIMS, - attrs=ugrid.EDGE_FACE_CONNECTIVITY_ATTRS, - ) + out_ds["edge_face_connectivity"] = edge_faces.assign_attrs( + ugrid.EDGE_FACE_CONNECTIVITY_ATTRS + ).rename(dict(zip(edge_faces.dims, ugrid.EDGE_FACE_CONNECTIVITY_DIMS))) def _parse_edge_node_distances(in_ds, out_ds): """Parses ``edge_node_distances``""" - edge_node_distances = in_ds["dvEdge"].values + edge_node_distances = in_ds["dvEdge"] - out_ds["edge_node_distances"] = xr.DataArray( - data=edge_node_distances, - dims=descriptors.EDGE_NODE_DISTANCES_DIMS, - attrs=descriptors.EDGE_NODE_DISTANCES_ATTRS, - ) + out_ds["edge_node_distances"] = edge_node_distances.assign_attrs( + descriptors.EDGE_NODE_DISTANCES_ATTRS + ).rename({"nEdges": ugrid.EDGE_DIM}) def _parse_edge_face_distances(in_ds, out_ds): """Parses ``edge_face_distances``""" - edge_face_distances = in_ds["dcEdge"].values + edge_face_distances = in_ds["dcEdge"] - out_ds["edge_face_distances"] = xr.DataArray( - data=edge_face_distances, - dims=descriptors.EDGE_FACE_DISTANCES_DIMS, - attrs=descriptors.EDGE_FACE_DISTANCES_ATTRS, - ) + out_ds["edge_face_distances"] = edge_face_distances.assign_attrs( + descriptors.EDGE_FACE_DISTANCES_ATTRS + ).rename({"nEdges": ugrid.EDGE_DIM}) def _parse_global_attrs(in_ds, out_ds): - """Helper to parse MPAS global attributes. - - Parameters - ---------- - in_ds : xarray.Dataset - Input MPAS dataset - out_ds : xarray.Dataset - Output UGRID dataset with parsed global attributes - """ - + """Helper to parse MPAS global attributes.""" out_ds.attrs = in_ds.attrs def _parse_face_faces(in_ds, out_ds): """Parses face-face connectivity for Primal Mesh.""" - cellsOnCell = np.array(in_ds["cellsOnCell"].values, dtype=INT_DTYPE) - nEdgesOnCell = np.array(in_ds["nEdgesOnCell"].values, dtype=INT_DTYPE) + cellsOnCell = in_ds["cellsOnCell"].astype(INT_DTYPE) + nEdgesOnCell = in_ds["nEdgesOnCell"].astype(INT_DTYPE) - # replace padded values with fill values + # Replace padded values with fill values cellsOnCell = _replace_padding(cellsOnCell, nEdgesOnCell) - # replace missing/zero values with fill values + # Replace missing/zero values with fill values cellsOnCell = _replace_zeros(cellsOnCell) - # make zero-indexed - cellsOnCell = _to_zero_index(cellsOnCell) + # Convert to zero-indexed + face_face_connectivity = _to_zero_index(cellsOnCell) - face_face_connectivity = cellsOnCell - - out_ds["face_face_connectivity"] = xr.DataArray( - data=face_face_connectivity, - dims=ugrid.FACE_FACE_CONNECTIVITY_DIMS, - attrs=ugrid.FACE_FACE_CONNECTIVITY_ATTRS, - ) + out_ds["face_face_connectivity"] = face_face_connectivity.assign_attrs( + ugrid.FACE_FACE_CONNECTIVITY_ATTRS + ).rename(dict(zip(face_face_connectivity.dims, ugrid.FACE_FACE_CONNECTIVITY_DIMS))) def _parse_face_areas(in_ds, out_ds, mesh_type): """Parses the face area for either a primal or dual grid.""" - if mesh_type == "primal": - face_area = in_ds["areaCell"].data + face_area = in_ds["areaCell"] else: - face_area = in_ds["areaTriangle"].data + face_area = in_ds["areaTriangle"] - out_ds["face_areas"] = xr.DataArray( - data=face_area, - dims=descriptors.FACE_AREAS_DIMS, - attrs=descriptors.FACE_AREAS_ATTRS, + out_ds["face_areas"] = face_area.assign_attrs(descriptors.FACE_AREAS_ATTRS).rename( + {face_area.dims[0]: ugrid.FACE_DIM} ) def _parse_boundary_node_indices(in_ds, out_ds, mesh_type): - """Parses the face area for either a primal or dual grid.""" - - boundary_node_mask = in_ds["boundaryVertex"].values - boundary_node_indices = np.argwhere(boundary_node_mask).flatten().astype(INT_DTYPE) - - out_ds["boundary_node_indices"] = xr.DataArray( - data=boundary_node_indices, - dims=[ - "n_boundary_nodes", - ], + """Parses the boundary node indices.""" + boundary_node_mask = in_ds["boundaryVertex"] + boundary_node_indices = boundary_node_mask.where(boundary_node_mask).dropna( + dim=boundary_node_mask.dims[0] ) + # Convert to integer indices + boundary_node_indices = boundary_node_indices.coords[ + boundary_node_indices.dims[0] + ].astype(INT_DTYPE) -def _replace_padding(verticesOnCell, nEdgesOnCell): - """Replaces the padded values in verticesOnCell defined by nEdgesOnCell - with a fill-value. + # Ensure zero-indexed + boundary_node_indices = boundary_node_indices - 1 - Parameters - ---------- - verticesOnCell : numpy.ndarray - Vertex indices that surround a given cell - - nEdgesOnCell : numpy.ndarray - Number of edges on a given cell - - Returns - ------- - verticesOnCell : numpy.ndarray - Vertex indices that surround a given cell with padded values replaced - by fill values, done in-place - """ - - # max vertices/edges per cell - maxEdges = verticesOnCell.shape[1] + out_ds["boundary_node_indices"] = boundary_node_indices.rename( + {"nVertices": "n_boundary_nodes"} + ) - # mask for non-padded values - mask = np.arange(maxEdges) < nEdgesOnCell[:, None] - # replace remaining padding or zeros with INT_FILL_VALUE - verticesOnCell[np.logical_not(mask)] = INT_FILL_VALUE +def _replace_padding(verticesOnCell, nEdgesOnCell): + """Replaces padded values in verticesOnCell with fill-value.""" + maxEdges = verticesOnCell.sizes[verticesOnCell.dims[1]] + edge_indices = xr.DataArray(np.arange(maxEdges), dims=[verticesOnCell.dims[1]]) + mask = edge_indices >= nEdgesOnCell + verticesOnCell = verticesOnCell.where(~mask, INT_FILL_VALUE) return verticesOnCell def _replace_zeros(grid_var): - """Replaces all instances of a zero (invalid/missing MPAS value) with a - fill value. - - Parameters - ---------- - grid_var : numpy.ndarray - Grid variable that may contain zeros that need to be replaced - - Returns - ------- - grid_var : numpy.ndarray - Grid variable with zero replaced by fill values, done in-place - """ - - # replace all zeros with INT_FILL_VALUE - grid_var[grid_var == 0] = INT_FILL_VALUE - + """Replaces zeros with fill-value.""" + grid_var = grid_var.where(grid_var != 0, INT_FILL_VALUE) return grid_var def _to_zero_index(grid_var): - """Given an input using that is one-indexed, subtracts one from all non- - fill value entries to convert to zero-indexed. - - Parameters - ---------- - grid_var : numpy.ndarray - Grid variable that is one-indexed - - Returns - ------- - grid_var : numpy.ndarray - Grid variable that is converted to zero-indexed, done in-place - """ - - # convert non-fill values to zero-indexed - grid_var[grid_var != INT_FILL_VALUE] -= 1 - + """Converts one-indexed data to zero-indexed.""" + grid_var = xr.where(grid_var != INT_FILL_VALUE, grid_var - 1, grid_var) return grid_var def _read_mpas(ext_ds, use_dual=False): - """Function to read in a MPAS Grid dataset and encode either the Primal or - Dual Mesh in the UGRID conventions. - - Adheres to the MPAS Mesh Specifications outlined in the following document: - https://mpas-dev.github.io/files/documents/MPAS-MeshSpec.pdf - - Parameters - ---------- - ext_ds : xarray.Dataset, required - MPAS datafile of interest - use_dual : bool, optional - Flag to select whether to encode the Dual-Mesh. Defaults to False - - Returns - ------- - ds : xarray.Dataset - UGRID dataset derived from inputted MPAS dataset - """ - - # empty dataset that will contain our encoded MPAS mesh + """Reads an MPAS Grid dataset and encodes either the Primal or Dual Mesh in the UGRID conventions.""" ds = xr.Dataset() - # convert dual-mesh to UGRID if use_dual: source_dim_map = _dual_to_ugrid(ext_ds, ds) - # convert primal-mesh to UGRID else: source_dim_map = _primal_to_ugrid(ext_ds, ds) diff --git a/uxarray/io/_scrip.py b/uxarray/io/_scrip.py index f6b874e93..2ced1a8e9 100644 --- a/uxarray/io/_scrip.py +++ b/uxarray/io/_scrip.py @@ -75,12 +75,15 @@ def _to_ugrid(in_ds, out_ds): # standardize fill values and data type face nodes face_nodes = _replace_fill_values( - unq_inv, original_fill=-1, new_fill=INT_FILL_VALUE, new_dtype=INT_DTYPE + xr.DataArray(data=unq_inv), + original_fill=-1, + new_fill=INT_FILL_VALUE, + new_dtype=INT_DTYPE, ) # set the face nodes data compiled in "connect" section out_ds["face_node_connectivity"] = xr.DataArray( - data=face_nodes, + data=face_nodes.data, dims=ugrid.FACE_NODE_CONNECTIVITY_DIMS, attrs=ugrid.FACE_NODE_CONNECTIVITY_ATTRS, ) diff --git a/uxarray/io/_ugrid.py b/uxarray/io/_ugrid.py index 3629a843b..c7be98cd3 100644 --- a/uxarray/io/_ugrid.py +++ b/uxarray/io/_ugrid.py @@ -128,27 +128,29 @@ def _standardize_connectivity(ds, conn_name): ---------- ds : xarray.Dataset Input Dataset + conn_name : str + The name of the connectivity variable to standardize. Returns - ---------- + ------- ds : xarray.Dataset - Input Dataset with correct index variables + Input Dataset with standardized connectivity variable. """ - # original connectivity - conn = ds[conn_name].values + # Extract the connectivity variable + conn = ds[conn_name] - # original fill value, if one exists - if "_FillValue" in ds[conn_name].attrs: - original_fv = ds[conn_name]._FillValue - elif np.isnan(ds[conn_name].values).any(): + # Determine the original fill value + if "_FillValue" in conn.attrs: + original_fv = conn.attrs["_FillValue"] + elif conn.isnull().any(): original_fv = np.nan else: original_fv = None - # if current dtype and fill value are not standardized + # Check if dtype or fill value needs to be standardized if conn.dtype != INT_DTYPE or original_fv != INT_FILL_VALUE: - # replace fill values and set correct dtype + # Replace fill values and set the correct dtype new_conn = _replace_fill_values( grid_var=conn, original_fill=original_fv, @@ -156,17 +158,34 @@ def _standardize_connectivity(ds, conn_name): new_dtype=INT_DTYPE, ) - if "start_index" in ds[conn_name].attrs: - new_conn[new_conn != INT_FILL_VALUE] -= INT_DTYPE(ds[conn_name].start_index) + # Check if 'start_index' attribute exists + if "start_index" in conn.attrs: + # Retrieve and convert 'start_index' + start_index = INT_DTYPE(conn.attrs["start_index"]) + + # Perform conditional subtraction using `.where()` + new_conn = new_conn.where( + new_conn == INT_FILL_VALUE, new_conn - start_index + ) else: + # Identify non-fill value indices fill_value_indices = new_conn != INT_FILL_VALUE - start_index = new_conn[fill_value_indices].min() - new_conn[fill_value_indices] -= INT_DTYPE(start_index) - # reassign data to use updated connectivity - ds[conn_name].data = new_conn + # Compute the minimum start_index from non-fill values + start_index = new_conn.where(fill_value_indices).min().item() + + # Convert start_index to the desired integer dtype + start_index = INT_DTYPE(start_index) + + # Perform conditional subtraction using `.where()` + new_conn = new_conn.where( + new_conn == INT_FILL_VALUE, new_conn - start_index + ) + + # Update the connectivity variable in the dataset + ds = ds.assign({conn_name: new_conn}) - # use new fill value + # Update the '_FillValue' attribute ds[conn_name].attrs["_FillValue"] = INT_FILL_VALUE return ds @@ -174,8 +193,6 @@ def _standardize_connectivity(ds, conn_name): def _is_ugrid(ds): """Check mesh topology and dimension.""" - # getkeys_filter_by_attribute(filepath, attr_name, attr_val) - # return type KeysView node_coords_dv = ds.filter_by_attrs(node_coordinates=lambda v: v is not None) face_conn_dv = ds.filter_by_attrs(face_node_connectivity=lambda v: v is not None) topo_dim_dv = ds.filter_by_attrs(topology_dimension=lambda v: v is not None) diff --git a/uxarray/io/utils.py b/uxarray/io/utils.py index 35af59158..dfcb93422 100644 --- a/uxarray/io/utils.py +++ b/uxarray/io/utils.py @@ -1,7 +1,11 @@ -from uxarray.io._ugrid import _is_ugrid import numpy as np import xarray as xr +from uxarray.io._ugrid import _is_ugrid, _read_ugrid +from uxarray.io._mpas import _mpas_to_ugrid_dims +from uxarray.io._icon import _icon_to_ugrid_dims +from uxarray.io._esmf import _esmf_to_ugrid_dims + def _parse_grid_type(dataset): """Checks input and contents to determine grid type. Supports detection of @@ -123,3 +127,17 @@ def _is_structured(dataset: xr.Dataset, tol: float = 1e-5) -> bool: print("Longitude coordinates are not regularly spaced.") return lat_regular and lon_regular, lon_name, lat_name + + +def _get_source_dims_dict(grid_ds, grid_spec): + if grid_spec == "MPAS": + return _mpas_to_ugrid_dims(grid_ds) + if grid_spec == "UGRID": + _, dim_dict = _read_ugrid(grid_ds) + return dim_dict + elif grid_spec == "ICON": + return _icon_to_ugrid_dims(grid_ds) + elif grid_spec == "ESMF": + return _esmf_to_ugrid_dims(grid_ds) + else: + return dict() diff --git a/uxarray/subset/dataarray_accessor.py b/uxarray/subset/dataarray_accessor.py index 1775baa68..f5f64d9f0 100644 --- a/uxarray/subset/dataarray_accessor.py +++ b/uxarray/subset/dataarray_accessor.py @@ -67,7 +67,7 @@ def bounding_circle( self, center_coord: Union[Tuple, List, np.ndarray], r: Union[float, int], - element: Optional[str] = "nodes", + element: Optional[str] = "face centers", inverse_indices: Union[List[str], Set[str], bool] = False, **kwargs, ): @@ -97,7 +97,7 @@ def nearest_neighbor( self, center_coord: Union[Tuple, List, np.ndarray], k: int, - element: Optional[str] = "nodes", + element: Optional[str] = "face centers", inverse_indices: Union[List[str], Set[str], bool] = False, **kwargs, ): diff --git a/uxarray/subset/grid_accessor.py b/uxarray/subset/grid_accessor.py index e42720db7..971ab477f 100644 --- a/uxarray/subset/grid_accessor.py +++ b/uxarray/subset/grid_accessor.py @@ -69,7 +69,7 @@ def bounding_circle( self, center_coord: Union[Tuple, List, np.ndarray], r: Union[float, int], - element: Optional[str] = "nodes", + element: Optional[str] = "face centers", inverse_indices: Union[List[str], Set[str], bool] = False, **kwargs, ): @@ -108,7 +108,7 @@ def nearest_neighbor( self, center_coord: Union[Tuple, List, np.ndarray], k: int, - element: Optional[str] = "nodes", + element: Optional[str] = "face centers", inverse_indices: Union[List[str], Set[str], bool] = False, **kwargs, ):