diff --git a/tests/test_regrid.py b/tests/test_regrid.py index 861b0124..9cfd1b6f 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -517,6 +517,18 @@ def test_unknown_variable(self): with pytest.raises(KeyError): regridder.horizontal("unknown", self.coarse_2d_ds) + def test_raises_error_if_axis_name_for_dim_cannot_be_determined(self): + ds = self.coarse_2d_ds.copy() + ds["lat"].attrs["standard_name"] = "latitude" + ds["lat"].attrs.pop("axis") + + regridder = regrid2.Regrid2Regridder(ds, self.fine_2d_ds) + + with pytest.raises( + ValueError, match="Could not determine axis name for dimension" + ): + regridder.horizontal("ts", ds) + @pytest.mark.filterwarnings("ignore:.*invalid value.*true_divide.*:RuntimeWarning") def test_regrid_input_mask(self): regridder = regrid2.Regrid2Regridder(self.coarse_2d_ds, self.fine_2d_ds) diff --git a/xcdat/regridder/regrid2.py b/xcdat/regridder/regrid2.py index 9d1f456e..1c49e934 100644 --- a/xcdat/regridder/regrid2.py +++ b/xcdat/regridder/regrid2.py @@ -3,7 +3,7 @@ import numpy as np import xarray as xr -from xcdat.axis import get_dim_keys +from xcdat.axis import CF_ATTR_MAP, get_dim_keys from xcdat.regridder.base import BaseRegridder, _preserve_bounds @@ -105,8 +105,6 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset: ds, data_var, output_data, - dst_lat_bnds, - dst_lon_bnds, self._input_grid, self._output_grid, ) @@ -228,8 +226,6 @@ def _build_dataset( ds: xr.Dataset, data_var: str, output_data: np.ndarray, - dst_lat_bnds, - dst_lon_bnds, input_grid: xr.Dataset, output_grid: xr.Dataset, ) -> xr.Dataset: @@ -242,11 +238,13 @@ def _build_dataset( dim = str(dim) try: - axis_name = [x for x, y in ds.cf.axes.items() if dim in y][0] - except Exception: + axis_name = [ + cf_axis for cf_axis, dims in ds.cf.axes.items() if dim in dims + ][0] + except IndexError as e: raise ValueError( f"Could not determine axis name for dimension {dim}" - ) from None + ) from e if axis_name in ["X", "Y"]: output_coords[dim] = output_grid.cf[axis_name] @@ -566,12 +564,20 @@ def _get_dimension(input_data_var, cf_axis_name): def _get_bounds_ensure_dtype(ds, axis): - try: - name = ds.cf.bounds[axis][0] - except (KeyError, IndexError) as e: - raise RuntimeError(f"Could not determine {axis!r} bounds") from e - else: - bounds = ds[name] + cf_keys = CF_ATTR_MAP[axis].values() + + bounds = None + + for key in cf_keys: + try: + name = ds.cf.bounds[key][0] + except (KeyError, IndexError): + pass + else: + bounds = ds[name] + + if bounds is None: + raise RuntimeError(f"Could not determine {axis!r} bounds") if bounds.dtype != np.float32: bounds = bounds.astype(np.float32)