From 3abe0df25814a22eba5089362b08ef536275bc71 Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe Date: Tue, 19 Nov 2019 16:25:31 -0700 Subject: [PATCH 1/2] Add support for caching datasets as zarr stores --- tests/cached_data/test-dset.nc | Bin 6544 -> 6544 bytes tests/test_core.py | 60 ++++++++++++++---------- xpersist/core.py | 83 +++++++++++++++++++++------------ 3 files changed, 89 insertions(+), 54 deletions(-) diff --git a/tests/cached_data/test-dset.nc b/tests/cached_data/test-dset.nc index b35115ed3f45badc37bede8840e916fad930ec0f..b793b02b7b83a135639350185277b6671e8b7f27 100644 GIT binary patch delta 179 zcmbPWJi&OvamKY1Ph=^W>lx}~q@6 zXZ$~T0i!q*3&UhaW|7GP%q){HGR|?}F=1d}Vqj$8W8h)nV2JliElEyEGjYkx%}fCr zB*4MIzydZ<0ZcN2Nf_bRHv8k|Vx}TSM&`-Sm{*FdWrE2-=*aUW*EVltQD)x6F+mUj D2=6BE delta 112 zcmbPWJi&OvamM0_C$bdG^o(>eQqoLq4fPBS^h_DRK%#Q%H-8V8ASPa!$y=DzCRZ@3 zGajA1fKi;0dGZd%t!xYoq5|yso7XZGF)}huR$y5vQp^O=2_x6mGH=`L$PvrDiDQBw E0Md6H1^@s6 diff --git a/tests/test_core.py b/tests/test_core.py index 7dcd325..8278410 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,19 +1,20 @@ import os import shutil from glob import glob +from tempfile import TemporaryDirectory + import numpy as np +import pytest import xarray as xr import xpersist as xp -import pytest - here = os.path.abspath(os.path.dirname(__file__)) xp.settings['cache_dir'] = os.path.join(here, 'cached_data') def rm_tmpfile(): - for p in ['tmp-*.nc', 'persisted_Dataset-*.nc']: + for p in ['tmp-*.nc', 'PersistedDataset-*.nc']: for f in glob(os.path.join(here, 'cached_data', p)): os.remove(f) @@ -24,26 +25,27 @@ def cleanup(): yield rm_tmpfile() + def func(scaleby): - return xr.Dataset({'x': xr.DataArray(np.ones((50,))*scaleby)}) + return xr.Dataset({'x': xr.DataArray(np.ones((50,)) * scaleby)}) # must be first test def test_xpersist_actions(): - ds = xp.persist_ds(func, name='test-dset')(10) - file, action = xp.persisted_Dataset._actions.popitem() + _ = xp.persist_ds(func, name='test-dset')(10) + file, action = xp.PersistedDataset._actions.popitem() assert action == 'read_cache_trusted' - ds = xp.persist_ds(func, name='test-dset')(10) - file, action = xp.persisted_Dataset._actions.popitem() + _ = xp.persist_ds(func, name='test-dset')(10) + file, action = xp.PersistedDataset._actions.popitem() assert action == 'read_cache_verified' - ds = xp.persist_ds(func, name='test-dset')(11) - file, action = xp.persisted_Dataset._actions.popitem() + _ = xp.persist_ds(func, name='test-dset')(11) + file, action = xp.PersistedDataset._actions.popitem() assert action == 'overwrite_cache' - ds = xp.persist_ds(func, name='tmp-test-dset')(11) - file, action = xp.persisted_Dataset._actions.popitem() + _ = xp.persist_ds(func, name='tmp-test-dset')(11) + file, action = xp.PersistedDataset._actions.popitem() assert action == 'create_cache' @@ -60,7 +62,7 @@ def test_make_cache_dir(): shutil.rmtree(new) xp.settings['cache_dir'] = new - ds = xp.persist_ds(func, name='test-dset')(10) + _ = xp.persist_ds(func, name='test-dset')(10) assert os.path.exists(new) @@ -68,34 +70,42 @@ def test_make_cache_dir(): xp.settings['cache_dir'] = old - def test_xpersist_noname(): - ds = xp.persist_ds(func)(10) - file, action = xp.persisted_Dataset._actions.popitem() + _ = xp.persist_ds(func)(10) + file, action = xp.PersistedDataset._actions.popitem() assert action == 'create_cache' def test_clobber(): - ds = xp.persist_ds(func, name='test-dset')(10) - file, action = xp.persisted_Dataset._actions.popitem() + _ = xp.persist_ds(func, name='test-dset')(10) + file, action = xp.PersistedDataset._actions.popitem() assert action == 'read_cache_verified' - ds = xp.persist_ds(func, name='test-dset', clobber=True)(11) - file, action = xp.persisted_Dataset._actions.popitem() + _ = xp.persist_ds(func, name='test-dset', clobber=True)(11) + file, action = xp.PersistedDataset._actions.popitem() assert action == 'overwrite_cache' def test_trusted(): - ds = xp.persist_ds(func, name='test-dset')(10) - file, action = xp.persisted_Dataset._actions.popitem() + _ = xp.persist_ds(func, name='test-dset')(10) + file, action = xp.PersistedDataset._actions.popitem() assert action == 'read_cache_verified' - ds = xp.persist_ds(func, name='test-dset', trust_cache=True)(11) - file, action = xp.persisted_Dataset._actions.popitem() + _ = xp.persist_ds(func, name='test-dset', trust_cache=True)(11) + file, action = xp.PersistedDataset._actions.popitem() assert action == 'read_cache_trusted' + def test_validate_dset(): dsp = xp.persist_ds(func, name='test-dset')(10) - file, action = xp.persisted_Dataset._actions.popitem() + file, action = xp.PersistedDataset._actions.popitem() ds = xr.open_dataset(file) xr.testing.assert_identical(dsp, ds) + + +def test_save_as_zarr(): + with TemporaryDirectory() as local_store: + dsp = xp.persist_ds(func, name='test-dset', path=local_store, format='zarr')(10) + zarr_store, action = xp.PersistedDataset._actions.popitem() + ds = xr.open_zarr(zarr_store, consolidated=True) + xr.testing.assert_identical(dsp, ds) diff --git a/xpersist/core.py b/xpersist/core.py index 84ed8ea..2b55fcc 100644 --- a/xpersist/core.py +++ b/xpersist/core.py @@ -1,18 +1,19 @@ import os +import shutil -from toolz import curry - -import xarray as xr import dask +import xarray as xr +from toolz import curry from . import settings -__all__ = ["persisted_Dataset", "persist_ds"] +__all__ = ['PersistedDataset', 'persist_ds'] _actions = {'read_cache_trusted', 'read_cache_verified', 'overwrite_cache', 'create_cache'} -_formats = {'nc'} +_formats = {'nc', 'zarr'} -class persisted_Dataset(object): + +class PersistedDataset(object): """ Generate an `xarray.Dataset` from a function and cache the result to file. If the cache file exists, don't recompute, but read back in from file. @@ -30,8 +31,16 @@ class persisted_Dataset(object): # class property _actions = {} - def __init__(self, func, name=None, path=None, trust_cache=False, clobber=False, - format='nc', open_ds_kwargs={}): + def __init__( + self, + func, + name=None, + path=None, + trust_cache=False, + clobber=False, + format='nc', + open_ds_kwargs={}, + ): """set instance attributes""" self._func = func self._name = name @@ -52,34 +61,37 @@ def _check_token_assign_action(self, token): # if we don't yet know about this file, assume it's the right one; # this enables usage on first call in a Python session, for instance - known_cache = self._cache_file in persisted_Dataset._tokens + known_cache = self._cache_file in PersistedDataset._tokens if not known_cache or self._trust_cache and not self._clobber: print(f'assuming cache is correct') - persisted_Dataset._tokens[self._cache_file] = token - persisted_Dataset._actions[self._cache_file] = 'read_cache_trusted' + PersistedDataset._tokens[self._cache_file] = token + PersistedDataset._actions[self._cache_file] = 'read_cache_trusted' # if the cache file is present and we know about it, # check the token; if the token doesn't match, remove the file elif known_cache: - if token != persisted_Dataset._tokens[self._cache_file] or self._clobber: + if token != PersistedDataset._tokens[self._cache_file] or self._clobber: print(f'name mismatch, removing: {self._cache_file}') - os.remove(self._cache_file) - persisted_Dataset._actions[self._cache_file] = 'overwrite_cache' + if self._format != 'zarr': + os.remove(self._cache_file) + else: + shutil.rmtree(self._cache_file, ignore_errors=True) + PersistedDataset._actions[self._cache_file] = 'overwrite_cache' else: - persisted_Dataset._actions[self._cache_file] = 'read_cache_verified' + PersistedDataset._actions[self._cache_file] = 'read_cache_verified' else: - persisted_Dataset._tokens[self._cache_file] = token - persisted_Dataset._actions[self._cache_file] = 'create_cache' + PersistedDataset._tokens[self._cache_file] = token + PersistedDataset._actions[self._cache_file] = 'create_cache' if os.path.dirname(self._cache_file) and not os.path.exists(self._path): print(f'making {self._path}') os.makedirs(self._path) - assert persisted_Dataset._actions[self._cache_file] in _actions + assert PersistedDataset._actions[self._cache_file] in _actions @property def _basename(self): - if self._name.endswith('.'+self._format): + if self._name.endswith('.' + self._format): return self._name else: return f'{self._name}.{self._format}' @@ -95,19 +107,27 @@ def _cache_exists(self): def __call__(self, *args, **kwargs): """call function or read cache""" - + # Generate Deterministic token token = dask.base.tokenize(self._func, args, kwargs) if self._name is None: - self._name = f'persisted_Dataset-{token}' + self._name = f'PersistedDataset-{token}' if self._path is None: self._path = settings['cache_dir'] self._check_token_assign_action(token) - if {'read_cache_trusted', 'read_cache_verified'}.intersection({self._actions[self._cache_file]}): + if {'read_cache_trusted', 'read_cache_verified'}.intersection( + {self._actions[self._cache_file]} + ): print(f'reading cached file: {self._cache_file}') - return xr.open_dataset(self._cache_file, **self._open_ds_kwargs) + if self._format == 'nc': + return xr.open_dataset(self._cache_file, **self._open_ds_kwargs) + elif self._format == 'zarr': + if 'consolidated' not in self._open_ds_kwargs: + zarr_kwargs = self._open_ds_kwargs.copy() + zarr_kwargs['consolidated'] = True + return xr.open_zarr(self._cache_file, **zarr_kwargs) elif {'create_cache', 'overwrite_cache'}.intersection({self._actions[self._cache_file]}): # generate dataset @@ -115,16 +135,21 @@ def __call__(self, *args, **kwargs): # write dataset print(f'writing cache file: {self._cache_file}') - ds.to_netcdf(self._cache_file) - return ds + if self._format == 'nc': + ds.to_netcdf(self._cache_file) + elif self._format == 'zarr': + ds.to_zarr(self._cache_file, consolidated=True) + + return ds @curry -def persist_ds(func, name=None, path=None, trust_cache=False, clobber=False, - format='nc', open_ds_kwargs={}): - """Wraps a function to produce a ``persisted_Dataset``. +def persist_ds( + func, name=None, path=None, trust_cache=False, clobber=False, format='nc', open_ds_kwargs={} +): + """Wraps a function to produce a ``PersistedDataset``. Parameters ---------- @@ -182,4 +207,4 @@ def persist_ds(func, name=None, path=None, trust_cache=False, clobber=False, if not callable(func): raise ValueError('func must be callable') - return persisted_Dataset(func, name, path, trust_cache, clobber, format, open_ds_kwargs) + return PersistedDataset(func, name, path, trust_cache, clobber, format, open_ds_kwargs) From 682c23c226da1e53f50b55eca84e1fb3719d9781 Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe Date: Tue, 19 Nov 2019 16:27:42 -0700 Subject: [PATCH 2/2] Update dependencies --- requirements-dev.txt | 1 + requirements.txt | 2 ++ 2 files changed, 3 insertions(+) diff --git a/requirements-dev.txt b/requirements-dev.txt index 3f11ff0..c7154f9 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,5 +3,6 @@ xarray dask toolz netCDF4 +zarr pytest pytest-cov diff --git a/requirements.txt b/requirements.txt index 7ee8d7a..71094d1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,5 @@ numpy xarray dask toolz +netCDF4 +zarr