Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added regridding #8

Merged
merged 3 commits into from
Jan 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 126 additions & 72 deletions evaltools/eval.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import numpy as np
import xarray as xr
import cordex as cx
import cf_xarray as cfxr
import xesmf as xe
from warnings import warn


def regional_mean(ds, regions=None, weights=None):
Expand Down Expand Up @@ -100,17 +104,12 @@ def height_correction(height1, height2):
return (height1 - height2) * 0.0065


def seasonal_mean(da):
"""
Calculate seasonal averages from a time series of monthly means.

Parameters:
da (xarray.DataArray): The DataArray to compute seasonal means for.
def seasonal_mean(da, skipna=True, min_count=1):
"""Calculate seasonal averages from time series of monthly means

Returns:
xarray.DataArray: The seasonal mean values.
based on: https://xarray.pydata.org/en/stable/examples/monthly-means.html
"""
# Get number of days for each month
# Get number od days for each month
month_length = da.time.dt.days_in_month
# Calculate the weights by grouping by 'time.season'.
weights = (
Expand All @@ -121,79 +120,134 @@ def seasonal_mean(da):
# np.testing.assert_allclose(weights.groupby("time.season").sum().values, np.ones(4))

# Calculate the weighted average
return (da * weights).groupby("time.season").sum(dim="time")
return (
(da * weights)
.groupby("time.season")
.sum(dim="time", skipna=skipna, min_count=min_count)
)


def add_bounds(ds):
if "longitude" not in ds.cf.bounds and "latitude" not in ds.cf.bounds:
ds = cx.transform_bounds(ds, trg_dims=("vertices_lon", "vertices_lat"))
ds = ds.assign_coords(
lon_b=cfxr.bounds_to_vertices(
ds.vertices_lon, bounds_dim="vertices", order="counterclockwise"
),
lat_b=cfxr.bounds_to_vertices(
ds.vertices_lat, bounds_dim="vertices", order="counterclockwise"
),
)
return ds


def mask_with_sftlf(ds, sftlf=None):
if sftlf is None and "sftlf" in ds:
sftlf = ds["sftlf"]
for var in ds.data_vars:
if var != "sftlf":
ds[var] = ds[var].where(sftlf > 0)
ds["mask"] = sftlf > 0
else:
warn(f"sftlf not found in dataset: {ds.source_id}")
return ds

def get_regridder(finer, coarser, method="bilinear", **kwargs):

def create_cordex_grid(domain_id):
"""
Regrid data bilinearly to a coarser grid.
Creates a CORDEX grid for the specified domain.

Parameters:
finer (xarray.Dataset): The dataset to regrid.
coarser (xarray.Dataset): The target grid dataset.
method (str, optional): The regridding method to use. Defaults to "bilinear".
**kwargs: Additional keyword arguments to pass to the regridding function.
Parameters
----------
domain_id : str
The domain ID for the CORDEX grid.

Returns:
xesmf.Regridder: The regridder object.
Returns
-------
xarray.Dataset
The CORDEX grid with assigned coordinates for longitude and latitude bounds.
"""
import xesmf as xe
grid = cx.domain(domain_id, bounds=True, mip_era="CMIP6")
lon_b = cfxr.bounds_to_vertices(
grid.vertices_lon, bounds_dim="vertices", order="counterclockwise"
)
lat_b = cfxr.bounds_to_vertices(
grid.vertices_lat, bounds_dim="vertices", order="counterclockwise"
)
return grid.assign_coords(lon_b=lon_b, lat_b=lat_b)


regridder = xe.Regridder(finer, coarser, method, **kwargs)
def create_regridder(source, target, method="bilinear"):
"""
Creates a regridder for regridding data from the source grid to the target grid.

Parameters
----------
source : xarray.Dataset
The source dataset to be regridded.
target : xarray.Dataset
The target grid dataset.
method : str, optional
The regridding method to use. Default is "bilinear".

Returns
-------
xesmf.Regridder
The regridder object.
"""
regridder = xe.Regridder(source, target, method=method)
return regridder


def compare_seasons(
ds1, ds2, regrid="ds1", do_height_correction=False, orog1=None, orog2=None
):
"""
Function to compare seasonal means of two datasets.

Paramters
---------
ds1 : xarray.Dataset
First variable data for comparision. Temporal resolution has to be less than monthly.
ds1 is mainly model output data.
ds2 is subtracted from ds1.
ds2 : xarray.Dataset
Second variable data for comparision. Temporal resolution has to be less than monthly.
ds2 is mainly observational or reanalysis data.
ds2 is subtracted from ds1.
regrid : {"ds1", "ds2"}, optional
Denotes the dataset to be bilinearly regridded. Specify the dataset with the finer spatial resolution:

- "ds1": Regrid ds1 to ds2's grid with coarser spatial resolution.
- "ds2": Regrid ds2 to ds1's grid with coarser spatial resolution.
do_height_correction : bool, optional
If ``do_height_correction=True``, do a height correction on ds1 using two orography files orog1 and orog2.
orog1 : xarray.Dataset, optional
Use only if ``do_height_correction=True``.
Specify a orography file referring to ds1.
orog2 : xarray.Dataset, optional
Use only if ``do_height_correction=True``.
Specify a orography file referring to ds2.
def regrid(ds, regridder):
"""
Regrids the dataset using the specified regridder.

Parameters
----------
ds : xarray.Dataset
The dataset to be regridded.
regridder : xesmf.Regridder
The regridder object.

Returns
-------
xarray.Dataset
The regridded dataset.
"""
ds_regrid = regridder(ds)
for var in ds.data_vars:
if var not in ["mask", "sftlf"]:
ds_regrid[var] = ds_regrid[var].where(ds_regrid["mask"] > 0.0)
return ds_regrid


def regrid_dsets(dsets, target_grid, method="bilinear"):
"""
Regrids multiple datasets to the target grid.

Parameters
----------
dsets : dict
A dictionary of datasets to be regridded, with keys as dataset IDs and values as xarray.Datasets.
target_grid : xarray.Dataset
The target grid dataset.
method : str, optional
The regridding method to use. Default is "bilinear".

Returns
-------
seasonal_comparision : xarray.Dataset
Spatial mean differences of two datasets

"""
ds1 = ds1.copy()
ds2 = ds2.copy()
ds1_seasmean = seasonal_mean(ds1)
ds2_seasmean = seasonal_mean(ds2)
if regrid == "ds1":
regridder = get_regridder(ds1, ds2)
print(regridder)
ds1_seasmean = regridder(ds1_seasmean)
elif regrid == "ds2":
regridder = get_regridder(ds2, ds1)
print(regridder)
ds2_seasmean = regridder(ds2_seasmean)

if do_height_correction is True:
orog1 = regridder(orog1)
ds1_seasmean += height_correction(orog1, orog2)
return ds1_seasmean - ds2_seasmean
# return xr.where(ds1_seasmean.mask, ds2_seasmean - ds1_seasmean, np.nan)
dict
A dictionary of regridded datasets.
"""
for dset_id, ds in dsets.items():
print(dset_id)
mapping = ds.cf["grid_mapping"].grid_mapping_name
if mapping == "rotated_latitude_longitude":
dsets[dset_id] = ds.cx.rewrite_coords(coords="all")
else:
print(f"regridding {dset_id} with grid_mapping: {mapping}")
regridder = create_regridder(ds, target_grid, method=method)
print(regridder)
dsets[dset_id] = regrid(ds, regridder)
return dsets
25 changes: 22 additions & 3 deletions evaltools/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from warnings import warn

from .utils import iid_to_dict, dict_to_iid
from .eval import mask_with_sftlf, add_bounds

xarray_open_kwargs = {"use_cftime": True, "decode_coords": "all", "chunks": None}
time_range_default = slice("1979", "2020")
Expand All @@ -25,7 +26,12 @@ def open_catalog(url=None):


def get_source_collection(
variable_id, frequency, driving_source_id="ERA5", add_fx=None, catalog=None
variable_id,
frequency,
driving_source_id="ERA5",
add_fx=None,
catalog=None,
**kwargs,
):
"""
Search the catalog for datasets matching the specified variable_id, frequency, and driving_source_id.
Expand All @@ -49,15 +55,16 @@ def get_source_collection(
frequency=frequency,
driving_source_id=driving_source_id,
require_all_on=["source_id"],
**kwargs,
)
source_ids = list(subset.df.source_id.unique())
print(f"Found: {source_ids} for variables: {variable_id}")
if add_fx:
if add_fx is True:
fx = catalog.search(source_id=source_ids, frequency="fx")
fx = catalog.search(source_id=source_ids, frequency="fx", **kwargs)
else:
fx = catalog.search(
source_id=source_ids, frequency="fx", variable_id=add_fx
source_id=source_ids, frequency="fx", variable_id=add_fx, **kwargs
)
if fx.df.empty:
warn(f"static variables not found: {variable_id}")
Expand Down Expand Up @@ -119,3 +126,15 @@ def open_and_sort(catalog, merge=None, concat=False, time_range="auto"):
join="override",
)
return dsets


def open_datasets(variables, frequency="mon", mask=True, add_missing_bounds=True):
catalog = get_source_collection(variables, frequency, add_fx=["areacella", "sftlf"])
dsets = open_and_sort(catalog, merge=True)
if mask is True:
for ds in dsets.values():
mask_with_sftlf(ds)
if add_missing_bounds is True:
for dset_id, ds in dsets.items():
dsets[dset_id] = add_bounds(ds)
return dsets
43 changes: 39 additions & 4 deletions evaltools/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
from collections import defaultdict


def iid_to_dict(dset_id, attrs):
default_attrs = [
"project_id",
"domain_id",
"institution_id",
"driving_source_id",
"driving_experiment_id",
"driving_variant_label",
"source_id",
"version_realization",
#'frequency',
#'variable_id',
"version",
]


def iid_to_dict(iid, attrs=None):
"""
Convert a dataset ID and its attributes to a dictionary.

Expand All @@ -12,11 +27,13 @@ def iid_to_dict(dset_id, attrs):
Returns:
dict: The dataset ID and attributes as a dictionary.
"""
values = dset_id.split(".")
if attrs is None:
attrs = default_attrs
values = iid.split(".")
return dict(zip(attrs, values))


def dict_to_iid(attrs, drop=None):
def dict_to_iid(attrs, drop=None, delimiter="."):
"""
Convert a dictionary of dataset attributes to a dataset ID.

Expand All @@ -28,7 +45,25 @@ def dict_to_iid(attrs, drop=None):
"""
if drop is None:
drop = []
return ".".join(v for k, v in attrs.items() if k not in drop)
return delimiter.join(v for k, v in attrs.items() if k not in drop)


def short_iid(iid, attrs=None, delimiter="."):
"""
Convert a dataset ID to a short ID.

Parameters:
iid (str): The dataset ID.
attrs (dict): The dataset attributes.

Returns:
str: The short ID.
"""
if attrs is None:
attrs = ["institution_id", "source_id", "driving_source_id", "experiment_id"]
return dict_to_iid(
{k: v for k, v in iid_to_dict(iid).items() if k in attrs}, delimiter=delimiter
)


def sort_by_grid_mapping(dsets):
Expand Down
Loading