Commit a764787
Merge pull request #4 from andersy005/master
Add support for caching datasets as zarr stores
matt-long authored Nov 20, 2019
2 parents 1beadf3 + 682c23c commit a764787
Showing 5 changed files with 92 additions and 54 deletions.
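In short, this PR teaches persist_ds to cache a dataset as a zarr store in addition to a netCDF file. A minimal usage sketch, based on the new test_save_as_zarr test below; the dataset-producing function make_ds is illustrative:

import numpy as np
import xarray as xr
import xpersist as xp

def make_ds(scaleby):
    # stand-in for an expensive dataset-producing computation
    return xr.Dataset({'x': xr.DataArray(np.ones((50,)) * scaleby)})

# format='zarr' writes the cache as a consolidated zarr store;
# the default format='nc' keeps the existing netCDF behavior.
ds = xp.persist_ds(make_ds, name='my-dset', format='zarr')(10)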
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -3,5 +3,6 @@ xarray
 dask
 toolz
 netCDF4
+zarr
 pytest
 pytest-cov
2 changes: 2 additions & 0 deletions requirements.txt
@@ -2,3 +2,5 @@ numpy
 xarray
 dask
 toolz
+netCDF4
+zarr
Binary file modified tests/cached_data/test-dset.nc
Binary file not shown.
60 changes: 35 additions & 25 deletions tests/test_core.py
@@ -1,19 +1,20 @@
 import os
 import shutil
 from glob import glob
+from tempfile import TemporaryDirectory
 
 import numpy as np
+import pytest
 import xarray as xr
 
 import xpersist as xp
 
-import pytest
 
 here = os.path.abspath(os.path.dirname(__file__))
 xp.settings['cache_dir'] = os.path.join(here, 'cached_data')
 
 
 def rm_tmpfile():
-    for p in ['tmp-*.nc', 'persisted_Dataset-*.nc']:
+    for p in ['tmp-*.nc', 'PersistedDataset-*.nc']:
         for f in glob(os.path.join(here, 'cached_data', p)):
             os.remove(f)
@@ -24,26 +25,27 @@ def cleanup():
     yield
     rm_tmpfile()
 
+
 def func(scaleby):
-    return xr.Dataset({'x': xr.DataArray(np.ones((50,))*scaleby)})
+    return xr.Dataset({'x': xr.DataArray(np.ones((50,)) * scaleby)})
 
 
 # must be first test
 def test_xpersist_actions():
-    ds = xp.persist_ds(func, name='test-dset')(10)
-    file, action = xp.persisted_Dataset._actions.popitem()
+    _ = xp.persist_ds(func, name='test-dset')(10)
+    file, action = xp.PersistedDataset._actions.popitem()
     assert action == 'read_cache_trusted'
 
-    ds = xp.persist_ds(func, name='test-dset')(10)
-    file, action = xp.persisted_Dataset._actions.popitem()
+    _ = xp.persist_ds(func, name='test-dset')(10)
+    file, action = xp.PersistedDataset._actions.popitem()
     assert action == 'read_cache_verified'
 
-    ds = xp.persist_ds(func, name='test-dset')(11)
-    file, action = xp.persisted_Dataset._actions.popitem()
+    _ = xp.persist_ds(func, name='test-dset')(11)
+    file, action = xp.PersistedDataset._actions.popitem()
    assert action == 'overwrite_cache'
 
-    ds = xp.persist_ds(func, name='tmp-test-dset')(11)
-    file, action = xp.persisted_Dataset._actions.popitem()
+    _ = xp.persist_ds(func, name='tmp-test-dset')(11)
+    file, action = xp.PersistedDataset._actions.popitem()
     assert action == 'create_cache'
 
 
@@ -60,42 +62,50 @@ def test_make_cache_dir():
     shutil.rmtree(new)
     xp.settings['cache_dir'] = new
 
-    ds = xp.persist_ds(func, name='test-dset')(10)
+    _ = xp.persist_ds(func, name='test-dset')(10)
 
     assert os.path.exists(new)
 
     shutil.rmtree(new)
     xp.settings['cache_dir'] = old
 
 
 def test_xpersist_noname():
-    ds = xp.persist_ds(func)(10)
-    file, action = xp.persisted_Dataset._actions.popitem()
+    _ = xp.persist_ds(func)(10)
+    file, action = xp.PersistedDataset._actions.popitem()
     assert action == 'create_cache'
 
 
 def test_clobber():
-    ds = xp.persist_ds(func, name='test-dset')(10)
-    file, action = xp.persisted_Dataset._actions.popitem()
+    _ = xp.persist_ds(func, name='test-dset')(10)
+    file, action = xp.PersistedDataset._actions.popitem()
     assert action == 'read_cache_verified'
 
-    ds = xp.persist_ds(func, name='test-dset', clobber=True)(11)
-    file, action = xp.persisted_Dataset._actions.popitem()
+    _ = xp.persist_ds(func, name='test-dset', clobber=True)(11)
+    file, action = xp.PersistedDataset._actions.popitem()
     assert action == 'overwrite_cache'
 
 
 def test_trusted():
-    ds = xp.persist_ds(func, name='test-dset')(10)
-    file, action = xp.persisted_Dataset._actions.popitem()
+    _ = xp.persist_ds(func, name='test-dset')(10)
+    file, action = xp.PersistedDataset._actions.popitem()
     assert action == 'read_cache_verified'
 
-    ds = xp.persist_ds(func, name='test-dset', trust_cache=True)(11)
-    file, action = xp.persisted_Dataset._actions.popitem()
+    _ = xp.persist_ds(func, name='test-dset', trust_cache=True)(11)
+    file, action = xp.PersistedDataset._actions.popitem()
     assert action == 'read_cache_trusted'
 
 
 def test_validate_dset():
     dsp = xp.persist_ds(func, name='test-dset')(10)
-    file, action = xp.persisted_Dataset._actions.popitem()
+    file, action = xp.PersistedDataset._actions.popitem()
     ds = xr.open_dataset(file)
     xr.testing.assert_identical(dsp, ds)
+
+
+def test_save_as_zarr():
+    with TemporaryDirectory() as local_store:
+        dsp = xp.persist_ds(func, name='test-dset', path=local_store, format='zarr')(10)
+        zarr_store, action = xp.PersistedDataset._actions.popitem()
+        ds = xr.open_zarr(zarr_store, consolidated=True)
+        xr.testing.assert_identical(dsp, ds)
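The assertions above read PersistedDataset._actions, a class-level dict mapping each cache file to the most recent action taken on it. A sketch of the lifecycle a fresh cache name goes through, reusing func from this test module (the name 'tmp-demo' is illustrative):

# First call with an unseen name computes the dataset and creates the cache.
_ = xp.persist_ds(func, name='tmp-demo')(10)
cache_file, action = xp.PersistedDataset._actions.popitem()
assert action == 'create_cache'

# A repeat call hashes the same function and arguments to the same dask
# token, so the cached file is read back and verified, not recomputed.
_ = xp.persist_ds(func, name='tmp-demo')(10)
cache_file, action = xp.PersistedDataset._actions.popitem()
assert action == 'read_cache_verified'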
83 changes: 54 additions & 29 deletions xpersist/core.py
@@ -1,18 +1,19 @@
 import os
+import shutil
 
-from toolz import curry
-
-import xarray as xr
 import dask
+import xarray as xr
+from toolz import curry
 
 from . import settings
 
-__all__ = ["persisted_Dataset", "persist_ds"]
+__all__ = ['PersistedDataset', 'persist_ds']
 
 _actions = {'read_cache_trusted', 'read_cache_verified', 'overwrite_cache', 'create_cache'}
-_formats = {'nc'}
+_formats = {'nc', 'zarr'}
 
-class persisted_Dataset(object):
+
+class PersistedDataset(object):
     """
     Generate an `xarray.Dataset` from a function and cache the result to file.
     If the cache file exists, don't recompute, but read back in from file.
@@ -30,8 +31,16 @@ class persisted_Dataset(object):
     # class property
     _actions = {}
 
-    def __init__(self, func, name=None, path=None, trust_cache=False, clobber=False,
-                 format='nc', open_ds_kwargs={}):
+    def __init__(
+        self,
+        func,
+        name=None,
+        path=None,
+        trust_cache=False,
+        clobber=False,
+        format='nc',
+        open_ds_kwargs={},
+    ):
         """set instance attributes"""
         self._func = func
         self._name = name
@@ -52,34 +61,37 @@ def _check_token_assign_action(self, token):
 
         # if we don't yet know about this file, assume it's the right one;
         # this enables usage on first call in a Python session, for instance
-        known_cache = self._cache_file in persisted_Dataset._tokens
+        known_cache = self._cache_file in PersistedDataset._tokens
         if not known_cache or self._trust_cache and not self._clobber:
             print(f'assuming cache is correct')
-            persisted_Dataset._tokens[self._cache_file] = token
-            persisted_Dataset._actions[self._cache_file] = 'read_cache_trusted'
+            PersistedDataset._tokens[self._cache_file] = token
+            PersistedDataset._actions[self._cache_file] = 'read_cache_trusted'
 
         # if the cache file is present and we know about it,
         # check the token; if the token doesn't match, remove the file
         elif known_cache:
-            if token != persisted_Dataset._tokens[self._cache_file] or self._clobber:
+            if token != PersistedDataset._tokens[self._cache_file] or self._clobber:
                 print(f'name mismatch, removing: {self._cache_file}')
-                os.remove(self._cache_file)
-                persisted_Dataset._actions[self._cache_file] = 'overwrite_cache'
+                if self._format != 'zarr':
+                    os.remove(self._cache_file)
+                else:
+                    shutil.rmtree(self._cache_file, ignore_errors=True)
+                PersistedDataset._actions[self._cache_file] = 'overwrite_cache'
             else:
-                persisted_Dataset._actions[self._cache_file] = 'read_cache_verified'
+                PersistedDataset._actions[self._cache_file] = 'read_cache_verified'
 
         else:
-            persisted_Dataset._tokens[self._cache_file] = token
-            persisted_Dataset._actions[self._cache_file] = 'create_cache'
+            PersistedDataset._tokens[self._cache_file] = token
+            PersistedDataset._actions[self._cache_file] = 'create_cache'
             if os.path.dirname(self._cache_file) and not os.path.exists(self._path):
                 print(f'making {self._path}')
                 os.makedirs(self._path)
 
-        assert persisted_Dataset._actions[self._cache_file] in _actions
+        assert PersistedDataset._actions[self._cache_file] in _actions
 
     @property
     def _basename(self):
-        if self._name.endswith('.'+self._format):
+        if self._name.endswith('.' + self._format):
             return self._name
         else:
             return f'{self._name}.{self._format}'
@@ -95,36 +107,49 @@ def _cache_exists(self):
 
     def __call__(self, *args, **kwargs):
        """call function or read cache"""
 
+        # Generate Deterministic token
         token = dask.base.tokenize(self._func, args, kwargs)
         if self._name is None:
-            self._name = f'persisted_Dataset-{token}'
+            self._name = f'PersistedDataset-{token}'
 
         if self._path is None:
             self._path = settings['cache_dir']
 
         self._check_token_assign_action(token)
 
-        if {'read_cache_trusted', 'read_cache_verified'}.intersection({self._actions[self._cache_file]}):
+        if {'read_cache_trusted', 'read_cache_verified'}.intersection(
+            {self._actions[self._cache_file]}
+        ):
             print(f'reading cached file: {self._cache_file}')
-            return xr.open_dataset(self._cache_file, **self._open_ds_kwargs)
+            if self._format == 'nc':
+                return xr.open_dataset(self._cache_file, **self._open_ds_kwargs)
+            elif self._format == 'zarr':
+                if 'consolidated' not in self._open_ds_kwargs:
+                    zarr_kwargs = self._open_ds_kwargs.copy()
+                    zarr_kwargs['consolidated'] = True
+                return xr.open_zarr(self._cache_file, **zarr_kwargs)
 
         elif {'create_cache', 'overwrite_cache'}.intersection({self._actions[self._cache_file]}):
             # generate dataset
             ds = self._func(*args, **kwargs)
 
             # write dataset
             print(f'writing cache file: {self._cache_file}')
-            ds.to_netcdf(self._cache_file)
-
-            return ds
+            if self._format == 'nc':
+                ds.to_netcdf(self._cache_file)
+
+            elif self._format == 'zarr':
+                ds.to_zarr(self._cache_file, consolidated=True)
+
+            return ds
 
 
 @curry
-def persist_ds(func, name=None, path=None, trust_cache=False, clobber=False,
-               format='nc', open_ds_kwargs={}):
-    """Wraps a function to produce a ``persisted_Dataset``.
+def persist_ds(
+    func, name=None, path=None, trust_cache=False, clobber=False, format='nc', open_ds_kwargs={}
+):
+    """Wraps a function to produce a ``PersistedDataset``.
 
     Parameters
     ----------
@@ -182,4 +207,4 @@ def persist_ds(func, name=None, path=None, trust_cache=False, clobber=False,
     if not callable(func):
         raise ValueError('func must be callable')
 
-    return persisted_Dataset(func, name, path, trust_cache, clobber, format, open_ds_kwargs)
+    return PersistedDataset(func, name, path, trust_cache, clobber, format, open_ds_kwargs)
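Because a zarr store is a directory rather than a single file, the overwrite path above uses shutil.rmtree in place of os.remove, and reads default to consolidated metadata. A cached store can also be opened directly with xarray, as test_save_as_zarr does; a minimal sketch, assuming format='zarr' wrote the store to cached_data/test-dset.zarr:

import xarray as xr

# PersistedDataset writes zarr caches with consolidated metadata,
# so pass consolidated=True when opening the store by hand.
ds = xr.open_zarr('cached_data/test-dset.zarr', consolidated=True)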
