Add caching options #412

Open

wants to merge 10 commits into dev
2 changes: 1 addition & 1 deletion nlmod/__init__.py
@@ -3,7 +3,7 @@

NLMOD_DATADIR = os.path.join(os.path.dirname(__file__), "data")

from . import dims, gis, gwf, gwt, modpath, plot, read, sim, util
from . import config, dims, gis, gwf, gwt, modpath, plot, read, sim, util
from .dims import base, get_ds, grid, layers, resample, time, to_model_ds
from .util import download_mfbinaries
from .version import __version__, show_versions
133 changes: 108 additions & 25 deletions nlmod/cache.py
@@ -14,6 +14,9 @@
import pandas as pd
import xarray as xr
from dask.diagnostics import ProgressBar
from xarray.testing import assert_identical

from .config import NLMOD_CACHE_OPTIONS

logger = logging.getLogger(__name__)

@@ -194,6 +197,7 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs):
with open(fname_pickle_cache, "rb") as f:
func_args_dic_cache = pickle.load(f)
pickle_check = True

except (pickle.UnpicklingError, ModuleNotFoundError):
logger.info("could not read pickle, not using cache")
pickle_check = False
@@ -216,30 +220,62 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs):

if pickle_check:
# Ensure that the pickle pairs with the netcdf, see #66.
if nc_hash:
if NLMOD_CACHE_OPTIONS["nc_hash"] and nc_hash:
with open(fname_cache, "rb") as myfile:
cache_bytes = myfile.read()
func_args_dic["_nc_hash"] = hashlib.sha256(
cache_bytes
).hexdigest()

if dataset is not None:
# Check the coords of the dataset argument
func_args_dic["_dataset_coords_hash"] = dask.base.tokenize(
dict(dataset.coords)
)

# Check the data_vars of the dataset argument
func_args_dic["_dataset_data_vars_hash"] = dask.base.tokenize(
dict(dataset.data_vars)
)
if NLMOD_CACHE_OPTIONS["dataset_coords_hash"]:
# Check the coords of the dataset argument
func_args_dic["_dataset_coords_hash"] = dask.base.tokenize(
dict(dataset.coords)
)
else:
func_args_dic_cache.pop("_dataset_coords_hash", None)
logger.warning(
"cache -> dataset coordinates not checked, "
"disabled in global config. See "
"`nlmod.config.NLMOD_CACHE_OPTIONS`."
)
if not NLMOD_CACHE_OPTIONS[
"explicit_dataset_coordinate_comparison"
]:
logger.warning(
"It is recommended to turn on "
"`explicit_dataset_coordinate_comparison` "
"in global config when hash check is turned off!"
)

if NLMOD_CACHE_OPTIONS["dataset_data_vars_hash"]:
# Check the data_vars of the dataset argument
func_args_dic["_dataset_data_vars_hash"] = (
dask.base.tokenize(dict(dataset.data_vars))
)
else:
func_args_dic_cache.pop("_dataset_data_vars_hash", None)
logger.warning(
"cache -> dataset data vars not checked, "
"disabled in global config. See "
"`nlmod.config.NLMOD_CACHE_OPTIONS`."
)

# check if cache was created with same function arguments as
# function call
argument_check = _same_function_arguments(
func_args_dic, func_args_dic_cache
)

# explicit check on input dataset coordinates and cached dataset
if NLMOD_CACHE_OPTIONS[
"explicit_dataset_coordinate_comparison"
] and isinstance(dataset, (xr.DataArray, xr.Dataset)):
b = _explicit_dataset_coordinate_comparison(dataset, cached_ds)
# update argument check
argument_check = argument_check and b

cached_ds = _check_for_data_array(cached_ds)
if modification_check and argument_check and pickle_check:
msg = f"using cached data -> {cachename}"
@@ -276,19 +312,33 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs):
result.to_netcdf(fname_cache)

# add netcdf hash to function arguments dic, see #66
if nc_hash:
if NLMOD_CACHE_OPTIONS["nc_hash"] and nc_hash:
with open(fname_cache, "rb") as myfile:
cache_bytes = myfile.read()
func_args_dic["_nc_hash"] = hashlib.sha256(cache_bytes).hexdigest()

# Add dataset argument hash to function arguments dic
if dataset is not None:
func_args_dic["_dataset_coords_hash"] = dask.base.tokenize(
dict(dataset.coords)
)
func_args_dic["_dataset_data_vars_hash"] = dask.base.tokenize(
dict(dataset.data_vars)
)
if NLMOD_CACHE_OPTIONS["dataset_coords_hash"]:
func_args_dic["_dataset_coords_hash"] = dask.base.tokenize(
dict(dataset.coords)
)
else:
logger.warning(
"cache -> not writing dataset coordinates hash to "
"pickle file, disabled in global config. See "
"`nlmod.config.NLMOD_CACHE_OPTIONS`."
)
if NLMOD_CACHE_OPTIONS["dataset_data_vars_hash"]:
func_args_dic["_dataset_data_vars_hash"] = dask.base.tokenize(
dict(dataset.data_vars)
)
else:
logger.warning(
"cache -> not writing dataset data vars hash to "
"pickle file, disabled in global config. See "
"`nlmod.config.NLMOD_CACHE_OPTIONS`."
)

# pickle function arguments
with open(fname_pickle_cache, "wb") as fpklz:
@@ -422,15 +472,14 @@ def decorator(*args, cachedir=None, cachename=None, **kwargs):


def _same_function_arguments(func_args_dic, func_args_dic_cache):
"""Checks if two dictionaries with function arguments are identical by
checking:
"""Checks if two dictionaries with function arguments are identical.

The following items are checked:
1. if they have the same keys
2. if the items have the same type
3. if the items have the same values (only possible for the types: int,
float, bool, str, bytes, list,
tuple, dict, np.ndarray,
xr.DataArray,
flopy.mf6.ModflowGwf).
3. if the items have the same values (only implemented for the types: int,
float, bool, str, bytes, list, tuple, dict, np.ndarray, xr.DataArray,
flopy.mf6.ModflowGwf).

Parameters
----------
@@ -744,7 +793,6 @@ def ds_contains(
if coords_2d or coords_3d:
coords.append("x")
coords.append("y")
datavars.append("area")
attrs.append("extent")
attrs.append("gridtype")

@@ -832,3 +880,38 @@ def ds_contains(
coords={k: ds.coords[k] for k in coords},
attrs={k: ds.attrs[k] for k in attrs},
)


def _explicit_dataset_coordinate_comparison(ds_in, ds_cache):
"""Perform explicit dataset coordinate comparison.

Uses `xarray.testing.assert_identical()`.

Parameters
----------
ds_in : xr.Dataset
Input dataset.
ds_cache : xr.Dataset
Cached dataset.

Returns
-------
bool
True if coordinates are identical, else False.
"""
logger.debug("cache -> performing explicit dataset coordinate comparison")
for coord in ds_cache.coords:
logger.debug(f"cache -> comparing coordinate {coord}")
try:
assert_identical(ds_in[coord], ds_cache[coord])
except AssertionError as e:
logger.debug(f"cache -> coordinate {coord} not equal")
logger.debug(e)
return False
logger.debug("cache -> all coordinates equal")
return True
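
The two comparison strategies above differ in strictness and cost: the hash check tokenizes the coordinate dict with dask and compares the digests, while the explicit check compares each coordinate with `xarray.testing.assert_identical` (names, dims, values, and attributes) and can report which coordinate differs. A minimal standalone sketch of the difference — the datasets here are illustrative, not part of the PR:

```python
import dask.base
import xarray as xr
from xarray.testing import assert_identical

ds_in = xr.Dataset(coords={"x": [0.0, 1.0, 2.0]})
ds_cache = xr.Dataset(coords={"x": [0.0, 1.0, 2.0]})

# Hash-based check, as used when "dataset_coords_hash" is enabled:
# cheap, but opaque when it fails.
hash_equal = dask.base.tokenize(dict(ds_in.coords)) == dask.base.tokenize(
    dict(ds_cache.coords)
)

# Explicit check, mirroring _explicit_dataset_coordinate_comparison above:
# slower, but pinpoints the first coordinate that is not identical.
def coords_identical(ds_in, ds_cache):
    for coord in ds_cache.coords:
        try:
            assert_identical(ds_in[coord], ds_cache[coord])
        except AssertionError:
            return False
    return True

print(hash_equal, coords_identical(ds_in, ds_cache))  # True True
```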
79 changes: 79 additions & 0 deletions nlmod/config.py
@@ -0,0 +1,79 @@
from contextlib import contextmanager

NLMOD_CACHE_OPTIONS = {
# compare hash for stored netcdf, default is True:
"nc_hash": True,
# compare hash for dataset coordinates, default is True:
"dataset_coords_hash": True,
# compare hash for dataset data variables, default is True:
"dataset_data_vars_hash": True,
# perform explicit comparison of dataset coordinates, default is False:
"explicit_dataset_coordinate_comparison": False,
}

_DEFAULT_CACHE_OPTIONS = {
"nc_hash": True,
"dataset_coords_hash": True,
"dataset_data_vars_hash": True,
"explicit_dataset_coordinate_comparison": False,
}


@contextmanager
def cache_options(**kwargs):
"""Context manager for nlmod cache options."""
set_options(**kwargs)
try:
yield get_options()
finally:
reset_options(list(kwargs.keys()))


def set_options(**kwargs):
"""
Set options for the nlmod package.

Parameters
----------
**kwargs : dict
Options to set.

"""
for key, value in kwargs.items():
if key in NLMOD_CACHE_OPTIONS:
NLMOD_CACHE_OPTIONS[key] = value
else:
raise ValueError(
f"Unknown option: {key}. Options are: "
f"{list(NLMOD_CACHE_OPTIONS.keys())}"
)


def get_options(key=None):
"""
Get options for the nlmod package.

Parameters
----------
key : str, optional
Option to get.

Returns
-------
dict or value
The options or the value of the requested option.

"""
if key is None:
return NLMOD_CACHE_OPTIONS
else:
return NLMOD_CACHE_OPTIONS[key]


def reset_options(options=None):
"""Reset options to default."""
if options is None:
set_options(**_DEFAULT_CACHE_OPTIONS)
else:
for opt in options:
set_options(**{opt: _DEFAULT_CACHE_OPTIONS[opt]})
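
Taken together, these functions support both global and temporarily scoped configuration. A short usage sketch, assuming nlmod is installed from this branch:

```python
import nlmod

# Inspect the current cache options.
print(nlmod.config.get_options())

# Change an option globally; it stays set until reset.
nlmod.config.set_options(dataset_data_vars_hash=False)

# Change an option only within a block; cache_options resets the
# keys it touched on exit.
with nlmod.config.cache_options(explicit_dataset_coordinate_comparison=True):
    # Cached nlmod functions called here compare coordinates explicitly.
    print(nlmod.config.get_options("explicit_dataset_coordinate_comparison"))

# Restore all options to their defaults.
nlmod.config.reset_options()
```

Note one design quirk: on exit, the context manager resets the modified keys to the package defaults, not to whatever values they held before entering the block.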
6 changes: 3 additions & 3 deletions nlmod/dims/attributes_encodings.py
@@ -167,9 +167,9 @@ def get_encodings(
if np.issubdtype(da.dtype, np.character):
continue

assert (
"_FillValue" not in da.attrs
), f"Custom fillvalues are not supported. {varname} has a fillvalue set."
assert "_FillValue" not in da.attrs, (
f"Custom fillvalues are not supported. {varname} has a fillvalue set."
)

encoding = {
"zlib": True,