Skip to content

Commit

Permalink
FIX: implementation was not friendly to dask-based datasets from open_mfdataset (#8)
Browse files Browse the repository at this point in the history

* fix small issue with loading and dask

* added some manual loading to make dask happy for mf datasets

* filter warnings in this way because with dask they are hard to pin down

* numpy2 compatibility issue
  • Loading branch information
nocollier authored Jul 2, 2024
1 parent b0a990d commit 273ea86
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 5 deletions.
11 changes: 10 additions & 1 deletion ilamb3/analysis/bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
ILAMBAnalysis : The abstract base class from which this derives.
"""

import warnings
from typing import Literal, Union

import numpy as np
Expand Down Expand Up @@ -126,6 +127,8 @@ def __call__(

# Temporal means across the time period
ref, com = cmp.make_comparable(ref, com, varname)
ref.pint.dequantify().load().pint.quantify()
com.pint.dequantify().load().pint.quantify()
ref_mean = (
dset.integrate_time(ref, varname, mean=True)
if "time" in ref[varname].dims
Expand Down Expand Up @@ -169,6 +172,7 @@ def __call__(
raise ValueError("Reference and comparison not uniformly site/spatial.")

# Compute score by different methods
ref_, com_, norm_, uncert_ = cmp.rename_dims(ref_, com_, norm_, uncert_)
bias = com_ - ref_
if method == "Collier2018":
score = np.exp(-(np.abs(bias) - uncert_).clip(0) / norm_)
Expand Down Expand Up @@ -258,6 +262,11 @@ def _scalar(
)
# Bias Score
bias_scalar_score = _scalar(com_out, "bias_score", region, True, True)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", "divide by zero encountered in divide", RuntimeWarning
)
bias_scalar_score = float(bias_scalar_score.pint.dequantify())
dfs.append(
[
"Comparison",
Expand All @@ -266,7 +275,7 @@ def _scalar(
"Bias Score",
"score",
"1",
float(bias_scalar_score.pint.dequantify()),
bias_scalar_score,
]
)

Expand Down
8 changes: 7 additions & 1 deletion ilamb3/analysis/relationship.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,13 @@ def build_response(self, nbin: int = 25, eps: float = 3e-3):
np.log10(self.dep_limits[0]), np.log10(self.dep_limits[1]), nbin + 1
)
dist, xedges, yedges = np.histogram2d(
ind, dep, bins=[xedges, yedges], range=[self._ind_limits, self._dep_limits]
ind,
dep,
bins=[xedges, yedges],
range=[
[v.values for v in self._ind_limits],
[v.values for v in self._dep_limits],
],
)
dist = np.ma.masked_values(dist.T, 0).astype(float)
dist /= dist.sum()
Expand Down
36 changes: 36 additions & 0 deletions ilamb3/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,12 @@ def _to_tuple(da: xr.DataArray) -> tuple[int]:
ta0, taf = dset.get_time_extent(dsa)
tb0, tbf = dset.get_time_extent(dsb)

# At this point we need actual data, so load
ta0.load()
taf.load()
tb0.load()
tbf.load()

# Convert to a date tuple (year, month, day) and find the maximal overlap
tmin = max(_to_tuple(ta0), _to_tuple(tb0))
tmax = min(_to_tuple(taf), _to_tuple(tbf))
Expand Down Expand Up @@ -258,3 +264,33 @@ def extract_sites(
)
assert (dist < model_res).all()
return ds_spatial


def rename_dims(*args):
    """
    Rename the dimensions of the given objects to uniform canonical names.

    Canonical dimension names are ``time``, ``lat``, and ``lon``. Any dimension
    that `dset.get_dim_name` identifies as one of these is renamed accordingly;
    dimensions that cannot be identified are left untouched.

    Parameters
    ----------
    *args
        Any number of `xr.Dataset` or `xr.DataArray` objects for which we will
        change the dimension names.

    Returns
    -------
    list
        The input objects, in order, with the dimension names changed.

    Raises
    ------
    TypeError
        If any argument is not an `xr.Dataset` or `xr.DataArray`.
    """

    def _populate_renames(ds):
        # Build the {current_name: canonical_name} mapping, skipping any
        # canonical dimension this particular object does not have.
        out = {}
        for dim in ["time", "lat", "lon"]:
            try:
                out[dset.get_dim_name(ds, dim)] = dim
            except KeyError:
                pass
        return out

    # Validate with a real exception rather than `assert`, which is stripped
    # when Python runs with optimizations enabled (-O).
    for arg in args:
        if not isinstance(arg, (xr.DataArray, xr.Dataset)):
            raise TypeError(
                f"Expected xr.Dataset or xr.DataArray, got {type(arg)}"
            )
    return [arg.rename(_populate_renames(arg)) for arg in args]
5 changes: 2 additions & 3 deletions ilamb3/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,8 @@ def _measure1d(time): # numpydoc ignore=GL08
# compute from the bounds
delt = dset[timeb_name]
nbnd = delt.dims[-1]
delt = delt.diff(nbnd).squeeze()
delt *= 1e-9 / 86400 # [ns] to [d]
measure = delt.astype("float")
delt = delt.diff(nbnd).squeeze().compute()
measure = delt.astype("float") * 1e-9 / 86400 # [ns] to [d]
measure = measure.pint.quantify("d")
return measure

Expand Down

0 comments on commit 273ea86

Please sign in to comment.