From a94eb76d4c944ffe2c8e7baf3d0bd253f5a7dd9b Mon Sep 17 00:00:00 2001 From: Joep Vanlier Date: Tue, 6 Aug 2024 15:25:01 +0200 Subject: [PATCH] channel: add global cache for channels --- lumicks/pylake/__init__.py | 1 + lumicks/pylake/channel.py | 84 ++++++++-------- lumicks/pylake/detail/caching.py | 96 +++++++++++++++++++ lumicks/pylake/file.py | 3 +- .../pylake/tests/test_channels/conftest.py | 14 +++ .../tests/test_channels/test_channels.py | 6 +- lumicks/pylake/tests/test_file/conftest.py | 12 +++ .../pylake/tests/test_file/test_caching.py | 65 +++++++++++++ pyproject.toml | 4 +- 9 files changed, 235 insertions(+), 50 deletions(-) create mode 100644 lumicks/pylake/detail/caching.py create mode 100644 lumicks/pylake/tests/test_file/test_caching.py diff --git a/lumicks/pylake/__init__.py b/lumicks/pylake/__init__.py index 1c936d22f..e7e5ab135 100644 --- a/lumicks/pylake/__init__.py +++ b/lumicks/pylake/__init__.py @@ -19,6 +19,7 @@ from .fitting.fit import FdFit from .image_stack import ImageStack, CorrelatedStack from .file_download import * +from .detail.caching import set_cache_enabled from .fitting.models import * from .fitting.parameter_trace import parameter_trace from .kymotracker.kymotracker import * diff --git a/lumicks/pylake/channel.py b/lumicks/pylake/channel.py index 83b63d679..ee67bb45d 100644 --- a/lumicks/pylake/channel.py +++ b/lumicks/pylake/channel.py @@ -6,6 +6,7 @@ import numpy as np import numpy.typing as npt +from .detail import caching from .detail.plotting import _annotate from .detail.timeindex import to_seconds, to_timestamp from .detail.utilities import downsample @@ -657,7 +658,7 @@ def range_selector(self, show=True, **kwargs) -> SliceRangeSelectorWidget: return SliceRangeSelectorWidget(self, show=show, **kwargs) -class Continuous: +class Continuous(caching.LazyCacheMixin): """A source of continuous data for a timeline slice Parameters @@ -671,8 +672,8 @@ class Continuous: """ def __init__(self, data, start, dt): + super().__init__() self._src_data = data - self._cached_data = None self.start = start self.stop = start + len(data) * dt self.dt = dt # ns @@ -693,7 +694,7 @@ def from_dataset(dset, y_label="y", calibration=None): start = dset.attrs["Start time (ns)"] dt = int(1e9 / dset.attrs["Sample rate (Hz)"]) # ns return Slice( - Continuous(dset, start, dt), + Continuous(caching.from_h5py(dset), start, dt), labels={"title": dset.name.strip("/"), "y": y_label}, calibration=calibration, ) @@ -719,9 +720,7 @@ def to_dataset(self, parent, name, **kwargs): @property def data(self) -> npt.ArrayLike: - if self._cached_data is None: - self._cached_data = np.asarray(self._src_data) - return self._cached_data + return self.read_lazy_cache("data", self._src_data) @property def timestamps(self) -> npt.ArrayLike: @@ -755,7 +754,7 @@ def downsampled_by(self, factor, reduce): ) -class TimeSeries: +class TimeSeries(caching.LazyCacheMixin): """A source of time series data for a timeline slice Parameters @@ -778,10 +777,9 @@ def __init__(self, data, timestamps): f"({len(timestamps)})." 
) + super().__init__() self._src_data = data - self._cached_data = None self._src_timestamps = timestamps - self._cached_timestamps = None def __len__(self): return len(self._src_data) @@ -796,32 +794,8 @@ def _apply_mask(self, mask): @staticmethod def from_dataset(dset, y_label="y", calibration=None) -> Slice: - class LazyLoadedCompoundField: - """Wrapper to enable lazy loading of HDF5 compound datasets - - Notes - ----- - We only need to support the methods `__array__()` and `__len__()`, as we only access - `LazyLoadedCompoundField` via the properties `TimeSeries.data`, `timestamps` and the - method `__len__()`. - - `LazyLoadCompoundField` might be replaced with `dset.fields(fieldname)` if and when the - returned `FieldsWrapper` object provides an `__array__()` method itself""" - - def __init__(self, dset, fieldname): - self._dset = dset - self._fieldname = fieldname - - def __array__(self): - """Get the data of the field as an array""" - return self._dset[self._fieldname] - - def __len__(self): - """Get the length of the underlying dataset""" - return len(self._dset) - - data = LazyLoadedCompoundField(dset, "Value") - timestamps = LazyLoadedCompoundField(dset, "Timestamp") + data = caching.from_h5py(dset, field="Value") + timestamps = caching.from_h5py(dset, field="Timestamp") return Slice( TimeSeries(data, timestamps), labels={"title": dset.name.strip("/"), "y": y_label}, @@ -850,15 +824,11 @@ def to_dataset(self, parent, name, **kwargs): @property def data(self) -> npt.ArrayLike: - if self._cached_data is None: - self._cached_data = np.asarray(self._src_data) - return self._cached_data + return self.read_lazy_cache("data", self._src_data) @property def timestamps(self) -> npt.ArrayLike: - if self._cached_timestamps is None: - self._cached_timestamps = np.asarray(self._src_timestamps) - return self._cached_timestamps + return self.read_lazy_cache("timestamps", self._src_timestamps) @property def start(self): @@ -893,7 +863,7 @@ def downsampled_by(self, factor, reduce): raise NotImplementedError("Downsampling is currently not available for time series data") -class TimeTags: +class TimeTags(caching.LazyCacheMixin): """A source of time tag data for a timeline slice Parameters @@ -907,13 +877,32 @@ class TimeTags: """ def __init__(self, data, start=None, stop=None): - self.data = np.asarray(data, dtype=np.int64) - self.start = start if start is not None else (self.data[0] if self.data.size > 0 else 0) - self.stop = stop if stop is not None else (self.data[-1] + 1 if self.data.size > 0 else 0) + super().__init__() + self._src_data = data + self._start = start + self._stop = stop def __len__(self): return self.data.size + @property + def start(self): + return ( + self._start if self._start is not None else (self.data[0] if self.data.size > 0 else 0) + ) + + @property + def stop(self): + return ( + self._stop + if self._stop is not None + else (self.data[-1] + 1 if self.data.size > 0 else 0) + ) + + @property + def data(self): + return self.read_lazy_cache("data", self._src_data) + def _with_data(self, data): raise NotImplementedError("Time tags do not currently support this operation") @@ -922,7 +911,10 @@ def _apply_mask(self, mask): @staticmethod def from_dataset(dset, y_label="y"): - return Slice(TimeTags(dset)) + return Slice( + TimeTags(caching.from_h5py(dset)), + labels={"title": dset.name.strip("/"), "y": y_label}, + ) def to_dataset(self, parent, name, **kwargs): """Save this to an h5 dataset diff --git a/lumicks/pylake/detail/caching.py b/lumicks/pylake/detail/caching.py new file 
mode 100644
index 000000000..ca323ec05
--- /dev/null
+++ b/lumicks/pylake/detail/caching.py
@@ -0,0 +1,96 @@
+import numpy as np
+from cachetools import LRUCache, cached
+
+global_cache = False
+
+
+def set_cache_enabled(enabled):
+    """Enable or disable the global cache
+
+    Pylake offers a global cache. When it is enabled, all `Slice` objects read their raw data
+    through a single shared LRU cache, so the same dataset is only held in memory once.
+
+    Parameters
+    ----------
+    enabled : bool
+        Whether caching should be enabled (it is off by default)
+    """
+    global global_cache
+    global_cache = enabled
+
+
+@cached(LRUCache(maxsize=1 << 30, getsizeof=lambda x: x.nbytes), info=True)  # 1 GiB of cache
+def _get_array(cache_object):
+    return cache_object.read_array()
+
+
+class LazyCache:
+    def __init__(self, location, dset):
+        """A lazy, globally cached wrapper around an object that is convertible to a numpy array"""
+        self._location = location
+        self._dset = dset
+
+    def __len__(self):
+        return len(self._dset)
+
+    def __hash__(self):
+        return hash(self._location)
+
+    @staticmethod
+    def from_h5py_dset(dset, field=None):
+        location = f"{dset.file.filename}{dset.name}"
+        if field:
+            location = f"{location}.{field}"
+            dset = dset.fields(field)
+
+        return LazyCache(location, dset)
+
+    def read_array(self):
+        # Note: we deliberately do _not_ forward extra arguments to `asarray`, since they would
+        # have to become part of the cache key and would needlessly grow the cache (callers
+        # sometimes defensively pass an explicit dtype). It is better to raise in that case and
+        # end up at this comment.
+        arr = np.asarray(self._dset)
+        arr.flags.writeable = False
+        return arr
+
+    def __eq__(self, other):
+        return self._location == other._location
+
+    def __array__(self):
+        return _get_array(self)
+
+
+def from_h5py(dset, field=None):
+    global global_cache
+    return (
+        LazyCache.from_h5py_dset(dset, field=field)
+        if global_cache
+        else dset.fields(field) if field else dset
+    )
+
+
+class LazyCacheMixin:
+    def __init__(self):
+        self._cache = {}
+
+    def read_lazy_cache(self, key, src_field):
+        """A small convenience method for implementing lazily cached properties.
+        When the global cache is disabled, data is stored in the instance's `_cache` dict.
+
+        Parameters
+        ----------
+        key : str
+            Key to use when caching this data
+        src_field : LazyCache or h5py.Dataset
+            Source field to read from
+        """
+        global global_cache
+
+        if global_cache:
+            return np.asarray(src_field)
+
+        if key not in self._cache:
+            self._cache[key] = np.asarray(src_field)
+
+        return self._cache[key]
diff --git a/lumicks/pylake/file.py b/lumicks/pylake/file.py
index 94d66a08d..55073e550 100644
--- a/lumicks/pylake/file.py
+++ b/lumicks/pylake/file.py
@@ -1,3 +1,4 @@
+import pathlib
 import warnings
 from typing import Dict
@@ -50,7 +51,7 @@ class File(Group, Force, DownsampledFD, BaselineCorrectedForce, PhotonCounts, Ph
     def __init__(self, filename, *, rgb_to_detectors=None):
         import h5py
 
-        super().__init__(h5py.File(filename, "r"), lk_file=self)
+        super().__init__(h5py.File(pathlib.Path(filename).absolute(), "r"), lk_file=self)
 
         self._check_file_format()
         self._rgb_to_detectors = self._get_detector_mapping(rgb_to_detectors)
diff --git a/lumicks/pylake/tests/test_channels/conftest.py b/lumicks/pylake/tests/test_channels/conftest.py
index cf01d6b4a..de6251892 100644
--- a/lumicks/pylake/tests/test_channels/conftest.py
+++ b/lumicks/pylake/tests/test_channels/conftest.py
@@ -31,3 +31,17 @@ def channel_h5_file(tmpdir_factory, request):
         mock_file.make_continuous_channel("Photon count", "Red", np.int64(20e9), freq, counts)
 
     return mock_file.file
+
+
+@pytest.fixture(autouse=True, scope="module", params=[False, True])
+def cache_setting(request):
+    from lumicks.pylake.detail.caching import global_cache, set_cache_enabled
+
+    # `global_cache` is a plain bool, so a direct reference is enough to snapshot it.
+    old_value = global_cache
+    try:
+        set_cache_enabled(request.param)
+        yield
+    finally:
+        # Restore whatever the setting was before this module's tests ran.
+        set_cache_enabled(old_value)
diff --git a/lumicks/pylake/tests/test_channels/test_channels.py b/lumicks/pylake/tests/test_channels/test_channels.py
index 349d941cc..16b3fba55 100644
--- a/lumicks/pylake/tests/test_channels/test_channels.py
+++ b/lumicks/pylake/tests/test_channels/test_channels.py
@@ -8,6 +8,7 @@ import matplotlib as mpl
 
 from lumicks.pylake import channel
+from lumicks.pylake.detail import caching
 from lumicks.pylake.low_level import make_continuous_slice
 from lumicks.pylake.calibration import ForceCalibrationList
@@ -893,7 +894,10 @@ def test_annotation_bad_item():
 def test_regression_lazy_loading(channel_h5_file):
     ch = channel.Continuous.from_dataset(channel_h5_file["Force HF"]["Force 1x"])
-    assert isinstance(ch._src._src_data, h5py.Dataset)
+    if caching.global_cache:
+        assert isinstance(ch._src._src_data._dset, h5py.Dataset)
+    else:
+        assert isinstance(ch._src._src_data, h5py.Dataset)
 
 
 @pytest.mark.parametrize(
diff --git a/lumicks/pylake/tests/test_file/conftest.py b/lumicks/pylake/tests/test_file/conftest.py
index a9d877f7a..c95f6568c 100644
--- a/lumicks/pylake/tests/test_file/conftest.py
+++ b/lumicks/pylake/tests/test_file/conftest.py
@@ -284,3 +284,15 @@ def h5_two_colors(tmpdir_factory, request):
     mock_file.make_continuous_channel("Photon count", "Blue", np.int64(20e9), freq, counts)
     mock_file.make_continuous_channel("Info wave", "Info wave", np.int64(20e9), freq, infowave)
     return mock_file.file
+
+
+@pytest.fixture(autouse=True, scope="module", params=[False, True])
+def cache_setting(request):
+    from lumicks.pylake.detail import caching
+
+    old_value = caching.global_cache
+    try:
+        caching.set_cache_enabled(request.param)
+        yield
+    finally:
+        caching.set_cache_enabled(old_value)
diff --git a/lumicks/pylake/tests/test_file/test_caching.py b/lumicks/pylake/tests/test_file/test_caching.py
new file mode 100644
index
000000000..da5a03321
--- /dev/null
+++ b/lumicks/pylake/tests/test_file/test_caching.py
@@ -0,0 +1,65 @@
+import pytest
+
+from lumicks import pylake
+from lumicks.pylake.detail.caching import _get_array
+
+
+def test_global_cache_continuous(h5_file):
+    pylake.set_cache_enabled(True)
+    _get_array.cache_clear()
+
+    # Load the file (never storing the file handle)
+    f1x1 = pylake.File.from_h5py(h5_file)["Force HF/Force 1x"]
+    f1x2 = pylake.File.from_h5py(h5_file).force1x
+    assert _get_array.cache_info().hits == 0  # No cache used yet (lazy loading)
+
+    # These should point to the same data
+    assert id(f1x1.data) == id(f1x2.data)
+    assert _get_array.cache_info().hits == 1
+    assert _get_array.cache_info().currsize == 40
+
+    with pytest.raises(ValueError, match="assignment destination is read-only"):
+        f1x1.data[5:100] = 3
+
+    file = pylake.File.from_h5py(h5_file)
+    assert id(file.force1x.data) == id(file.force1x.data)
+
+
+def test_global_cache_timeseries(h5_file):
+    pylake.set_cache_enabled(True)
+    _get_array.cache_clear()
+
+    f1x1 = pylake.File.from_h5py(h5_file).downsampled_force1x
+    f1x2 = pylake.File.from_h5py(h5_file).downsampled_force1x
+    assert _get_array.cache_info().hits == 0  # No cache used yet (lazy loading)
+
+    # These should point to the same data
+    assert id(f1x1.data) == id(f1x2.data)
+    assert _get_array.cache_info().hits == 1
+    assert _get_array.cache_info().currsize == 16
+    assert id(f1x1.timestamps) == id(f1x2.timestamps)
+    assert _get_array.cache_info().hits == 2
+    assert _get_array.cache_info().currsize == 32
+
+    with pytest.raises(ValueError, match="assignment destination is read-only"):
+        f1x1.data[5:100] = 3
+
+    with pytest.raises(ValueError, match="assignment destination is read-only"):
+        f1x1.timestamps[5:100] = 3
+
+
+def test_global_cache_timetags(h5_file):
+    pylake.set_cache_enabled(True)
+    if pylake.File.from_h5py(h5_file).format_version == 2:
+        _get_array.cache_clear()
+        tags1 = pylake.File.from_h5py(h5_file)["Photon Time Tags"]["Red"]
+        tags2 = pylake.File.from_h5py(h5_file)["Photon Time Tags"]["Red"]
+        assert _get_array.cache_info().hits == 0
+
+        # These should point to the same data
+        assert id(tags1.data) == id(tags2.data)
+        assert _get_array.cache_info().hits == 1
+        assert _get_array.cache_info().currsize == 72
+
+        with pytest.raises(ValueError, match="assignment destination is read-only"):
+            tags1.data[5:100] = 3
diff --git a/pyproject.toml b/pyproject.toml
index 057759a5f..c758b5b38 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,13 +26,13 @@ classifiers=[
 ]
 dependencies = [
     "pytest>=3.5",
-    "h5py>=3.4, <4",
+    "h5py>=3.8, <4",  # Minimum bound needed for using __array__ on Dataset.fields()
     "numpy>=1.24",  # 1.24 is needed for dtype in vstack/hstack (Dec 18th, 2022)
     "scipy>=1.9, <2",  # 1.9.0 needed for lazy imports (July 29th, 2022)
     "matplotlib>=3.8",
     "tifffile>=2022.7.28",
     "tabulate>=0.8.8, <0.9",
-    "cachetools>=3.1",
+    "cachetools>=5.3.0",  # 5.3.0 needed for the info argument to cached
     "deprecated>=1.2.8",
     "scikit-learn>=0.18.0",
     "scikit-image>=0.17.2",
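
Usage sketch (illustrative, not part of the patch; "example.h5" is a placeholder path): with the
global cache enabled, repeated reads of the same channel resolve to one shared, read-only array,
mirroring the behavior exercised by the tests above.

    import lumicks.pylake as lk

    lk.set_cache_enabled(True)  # opt in; the global cache is off by default

    # Two File objects reading the same dataset share a single cached array
    data_a = lk.File("example.h5").force1x.data  # first access reads from disk
    data_b = lk.File("example.h5").force1x.data  # second access hits the shared LRU cache
    assert data_a is data_b

    # Cached arrays are frozen so one consumer cannot mutate data seen by another
    data_a[:10] = 0  # raises ValueError: assignment destination is read-only

With the cache disabled (the default), each slice falls back to its own per-instance `_cache`
dict, preserving the previous lazy-loading behavior.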
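
For reference, a minimal standalone sketch of the byte-bounded LRU pattern that `_get_array`
relies on (the tiny `maxsize` and the `load` helper are invented for illustration; the patch
uses a 1 GiB budget keyed on dataset location):

    import numpy as np
    from cachetools import LRUCache, cached

    # maxsize counts bytes rather than entries: getsizeof reports each array's nbytes
    @cached(LRUCache(maxsize=64, getsizeof=lambda a: a.nbytes), info=True)
    def load(n):
        return np.zeros(n, dtype=np.uint8)

    load(32)
    load(32)                           # second call is served from the cache
    print(load.cache_info())           # CacheInfo(hits=1, misses=1, maxsize=64, currsize=32)
    load(64)                           # evicts the 32-byte array to stay within the budget
    print(load.cache_info().currsize)  # 64

This is also why the cachetools bound moves to 5.3.0: the `info=True` argument that exposes
`cache_info()` on the decorated function first appeared in that release.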