caching: centralize caching #688

Draft: wants to merge 2 commits into base: main

Changes from 1 commit
channel: add global cache for channels
JoepVanlier committed Dec 11, 2024
commit b5984efd3a80746be360e3d7281ded467513c04e
1 change: 1 addition & 0 deletions lumicks/pylake/__init__.py
@@ -19,6 +19,7 @@
from .fitting.fit import FdFit
from .image_stack import ImageStack, CorrelatedStack
from .file_download import *
from .detail.caching import set_cache_enabled
from .fitting.models import *
from .fitting.parameter_trace import parameter_trace
from .kymotracker.kymotracker import *
84 changes: 38 additions & 46 deletions lumicks/pylake/channel.py
@@ -6,6 +6,7 @@
import numpy as np
import numpy.typing as npt

from .detail import caching
from .detail.plotting import _annotate
from .detail.timeindex import to_seconds, to_timestamp
from .detail.utilities import downsample, convert_to_scalar
@@ -681,7 +682,7 @@ def range_selector(self, show=True, **kwargs) -> SliceRangeSelectorWidget:
return SliceRangeSelectorWidget(self, show=show, **kwargs)


class Continuous:
class Continuous(caching.LazyCacheMixin):
"""A source of continuous data for a timeline slice

Parameters
@@ -695,8 +696,8 @@ class Continuous:
"""

def __init__(self, data, start, dt):
super().__init__()
self._src_data = data
self._cached_data = None
self.start = start
self.stop = start + len(data) * dt
self.dt = dt # ns
@@ -717,7 +718,7 @@ def from_dataset(dset, y_label="y", calibration=None):
start = dset.attrs["Start time (ns)"]
dt = int(1e9 / dset.attrs["Sample rate (Hz)"]) # ns
return Slice(
Continuous(dset, start, dt),
Continuous(caching.from_h5py(dset), start, dt),
labels={"title": dset.name.strip("/"), "y": y_label},
calibration=calibration,
)
@@ -743,9 +744,7 @@ def to_dataset(self, parent, name, **kwargs):

@property
def data(self) -> npt.ArrayLike:
if self._cached_data is None:
self._cached_data = np.asarray(self._src_data)
return self._cached_data
return self.read_lazy_cache("data", self._src_data)

@property
def timestamps(self) -> npt.ArrayLike:
@@ -779,7 +778,7 @@ def downsampled_by(self, factor, reduce):
)


class TimeSeries:
class TimeSeries(caching.LazyCacheMixin):
"""A source of time series data for a timeline slice

Parameters
@@ -802,10 +801,9 @@ def __init__(self, data, timestamps):
f"({len(timestamps)})."
)

super().__init__()
self._src_data = data
self._cached_data = None
self._src_timestamps = timestamps
self._cached_timestamps = None

def __len__(self):
return len(self._src_data)
@@ -820,32 +818,8 @@ def _apply_mask(self, mask):

@staticmethod
def from_dataset(dset, y_label="y", calibration=None) -> Slice:
class LazyLoadedCompoundField:
"""Wrapper to enable lazy loading of HDF5 compound datasets

Notes
-----
We only need to support the methods `__array__()` and `__len__()`, as we only access
`LazyLoadedCompoundField` via the properties `TimeSeries.data`, `timestamps` and the
method `__len__()`.

`LazyLoadCompoundField` might be replaced with `dset.fields(fieldname)` if and when the
returned `FieldsWrapper` object provides an `__array__()` method itself"""

def __init__(self, dset, fieldname):
self._dset = dset
self._fieldname = fieldname

def __array__(self):
"""Get the data of the field as an array"""
return self._dset[self._fieldname]

def __len__(self):
"""Get the length of the underlying dataset"""
return len(self._dset)

data = LazyLoadedCompoundField(dset, "Value")
timestamps = LazyLoadedCompoundField(dset, "Timestamp")
data = caching.from_h5py(dset, field="Value")
timestamps = caching.from_h5py(dset, field="Timestamp")
return Slice(
TimeSeries(data, timestamps),
labels={"title": dset.name.strip("/"), "y": y_label},
@@ -874,15 +848,11 @@ def to_dataset(self, parent, name, **kwargs):

@property
def data(self) -> npt.ArrayLike:
if self._cached_data is None:
self._cached_data = np.asarray(self._src_data)
return self._cached_data
return self.read_lazy_cache("data", self._src_data)

@property
def timestamps(self) -> npt.ArrayLike:
if self._cached_timestamps is None:
self._cached_timestamps = np.asarray(self._src_timestamps)
return self._cached_timestamps
return self.read_lazy_cache("timestamps", self._src_timestamps)

@property
def start(self):
@@ -917,7 +887,7 @@ def downsampled_by(self, factor, reduce):
raise NotImplementedError("Downsampling is currently not available for time series data")


class TimeTags:
class TimeTags(caching.LazyCacheMixin):
"""A source of time tag data for a timeline slice

Parameters
@@ -931,13 +901,32 @@ class TimeTags:
"""

def __init__(self, data, start=None, stop=None):
self.data = np.asarray(data, dtype=np.int64)
self.start = start if start is not None else (self.data[0] if self.data.size > 0 else 0)
self.stop = stop if stop is not None else (self.data[-1] + 1 if self.data.size > 0 else 0)
super().__init__()
self._src_data = data
self._start = start
self._stop = stop

def __len__(self):
return self.data.size

@property
def start(self):
return (
self._start if self._start is not None else (self.data[0] if self.data.size > 0 else 0)
)

@property
def stop(self):
return (
self._stop
if self._stop is not None
else (self.data[-1] + 1 if self.data.size > 0 else 0)
)

@property
def data(self):
return self.read_lazy_cache("data", self._src_data)

def _with_data(self, data):
raise NotImplementedError("Time tags do not currently support this operation")
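
The net effect of this hunk: constructing a `TimeTags` no longer reads the source eagerly; `start` and `stop` are now derived from `data` on first access. A minimal sketch of the new behavior, using an in-memory array in place of an HDF5 dataset:

import numpy as np
from lumicks.pylake.channel import TimeTags

tags = TimeTags(np.arange(10, 20, dtype=np.int64))  # stores the source; nothing is read yet
assert tags.start == 10 and tags.stop == 20  # bounds are derived from the data on first access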

@@ -946,7 +935,10 @@ def _apply_mask(self, mask):

@staticmethod
def from_dataset(dset, y_label="y"):
return Slice(TimeTags(dset))
return Slice(
TimeTags(caching.from_h5py(dset)),
labels={"title": dset.name.strip("/"), "y": y_label},
)

def to_dataset(self, parent, name, **kwargs):
"""Save this to an h5 dataset
96 changes: 96 additions & 0 deletions lumicks/pylake/detail/caching.py
@@ -0,0 +1,96 @@
import numpy as np
from cachetools import LRUCache, cached

global_cache = False


def set_cache_enabled(enabled):
"""Enable or disable the global cache

Pylake offers a global cache. When the global cache is enabled, the data backing all `Slice`
objects is served from a single shared cache.

Parameters
----------
enabled : bool
Whether the caching should be enabled (by default it is off)
"""
global global_cache
global_cache = enabled
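
A minimal usage sketch (the `__init__.py` hunk above re-exports `set_cache_enabled` from the package root):

import lumicks.pylake as lk

lk.set_cache_enabled(True)   # subsequent Slice data reads go through the shared global cache
lk.set_cache_enabled(False)  # revert to the default per-instance caching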


@cached(LRUCache(maxsize=1 << 30, getsizeof=lambda x: x.nbytes), info=True) # 1 GB of cache
def _get_array(cache_object):
return cache_object.read_array()
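
Because the decorator is constructed with `info=True`, `cachetools` attaches a `cache_info()` method to the wrapped function, which gives a quick way to verify that the global cache is actually being hit (the numbers below are illustrative):

from lumicks.pylake.detail.caching import _get_array

print(_get_array.cache_info())  # e.g. CacheInfo(hits=3, misses=1, maxsize=1073741824, currsize=4000000)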


class LazyCache:
def __init__(self, location, dset):
"""A lazy globally cached wrapper around an object that is convertible to a numpy array"""
self._location = location
self._dset = dset

def __len__(self):
return len(self._dset)

def __hash__(self):
return hash(self._location)

@staticmethod
def from_h5py_dset(dset, field=None):
location = f"{dset.file.filename}{dset.name}"
if field:
location = f"{location}.{field}"
dset = dset.fields(field)

return LazyCache(location, dset)

def read_array(self):
# Note that we deliberately do _not_ allow additional arguments to asarray, since we would
# have to include them in the cache key and, unless necessary, they would needlessly grow
# the cache (because of sometimes defensively adding an explicit dtype). It's better to
# raise in this case and end up at this comment.
arr = np.asarray(self._dset)
arr.flags.writeable = False
return arr

def __eq__(self, other):
return self._location == other._location

def __array__(self):
return _get_array(self)
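
Since `__hash__` and `__eq__` key on the location string alone, two wrappers built for the same dataset in the same file act as one cache key, so the array is read from disk only once. A rough sketch of that behavior (the file name is hypothetical; the channel path follows the tests below):

import numpy as np
import h5py

from lumicks.pylake.detail import caching

caching.global_cache = True  # normally toggled via set_cache_enabled(True)
with h5py.File("example.h5", "r") as f:
    a = caching.from_h5py(f["Force HF/Force 1x"])
    b = caching.from_h5py(f["Force HF/Force 1x"])
    assert a == b and hash(a) == hash(b)  # same location, same cache entry
    np.asarray(a)  # first access reads from disk and populates the LRU cache
    np.asarray(b)  # served from the cache; no second disk read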


def from_h5py(dset, field=None):
global global_cache
return (
LazyCache.from_h5py_dset(dset, field=field)
if global_cache
else dset.fields(field) if field else dset
)


class LazyCacheMixin:
def __init__(self):
self._cache = {}

def read_lazy_cache(self, key, src_field):
"""A small convenience method to incorporate a lazy cache for properties.
Data will be stored in the `_cache` dictionary of the instance.

Parameters
----------
key : str
Key to use when caching this data
src_field : LazyCache or h5py.Dataset
Source field to read from
"""
global global_cache

if global_cache:
return np.asarray(src_field)

if key not in self._cache:
self._cache[key] = np.asarray(src_field)

return self._cache[key]
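
To illustrate the mixin contract mirrored by `Continuous`, `TimeSeries`, and `TimeTags` above: derive from `LazyCacheMixin`, call `super().__init__()` to set up the per-instance `_cache` dict, and route property reads through `read_lazy_cache`. A hypothetical consumer:

import numpy as np

from lumicks.pylake.detail import caching

class MyChannelSource(caching.LazyCacheMixin):  # hypothetical example class
    def __init__(self, data):
        super().__init__()
        self._src_data = data

    @property
    def data(self):
        # Served by the global LRU cache when enabled, otherwise memoized on this instance
        return self.read_lazy_cache("data", self._src_data)
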
3 changes: 2 additions & 1 deletion lumicks/pylake/file.py
@@ -1,3 +1,4 @@
import pathlib
import warnings
from typing import Dict

@@ -50,7 +51,7 @@ class File(Group, Force, DownsampledFD, BaselineCorrectedForce, PhotonCounts, Ph
def __init__(self, filename, *, rgb_to_detectors=None):
import h5py

super().__init__(h5py.File(filename, "r"), lk_file=self)
super().__init__(h5py.File(pathlib.Path(filename).absolute(), "r"), lk_file=self)
self._check_file_format()
self._rgb_to_detectors = self._get_detector_mapping(rgb_to_detectors)
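
Presumably the switch to an absolute path here keeps the global cache consistent: `LazyCache` keys on `dset.file.filename`, so opening one file through different relative spellings would otherwise yield distinct cache entries. A small sketch of that assumption (file name hypothetical):

import pathlib

# Both spellings resolve to the same absolute location, hence the same cache key.
assert pathlib.Path("data.h5").absolute() == pathlib.Path("./data.h5").absolute()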

14 changes: 14 additions & 0 deletions lumicks/pylake/tests/test_channels/conftest.py
@@ -31,3 +31,17 @@ def channel_h5_file(tmpdir_factory, request):
mock_file.make_continuous_channel("Photon count", "Red", np.int64(20e9), freq, counts)

return mock_file.file


@pytest.fixture(autouse=True, scope="module", params=[False, True])
def cache_setting(request):
from copy import deepcopy

from lumicks.pylake.detail.caching import global_cache, set_cache_enabled

old_value = deepcopy(global_cache)
try:
set_cache_enabled(request.param)
yield
finally:
set_cache_enabled(old_value)
6 changes: 5 additions & 1 deletion lumicks/pylake/tests/test_channels/test_channels.py
@@ -8,6 +8,7 @@
import matplotlib as mpl

from lumicks.pylake import channel
from lumicks.pylake.detail import caching
from lumicks.pylake.low_level import make_continuous_slice
from lumicks.pylake.calibration import ForceCalibrationList

@@ -893,7 +894,10 @@ def test_annotation_bad_item():

def test_regression_lazy_loading(channel_h5_file):
ch = channel.Continuous.from_dataset(channel_h5_file["Force HF"]["Force 1x"])
assert isinstance(ch._src._src_data, h5py.Dataset)
if caching.global_cache:
assert isinstance(ch._src._src_data._dset, h5py.Dataset)
else:
assert isinstance(ch._src._src_data, h5py.Dataset)


@pytest.mark.parametrize(
12 changes: 12 additions & 0 deletions lumicks/pylake/tests/test_file/conftest.py
@@ -284,3 +284,15 @@ def h5_two_colors(tmpdir_factory, request):
mock_file.make_continuous_channel("Photon count", "Blue", np.int64(20e9), freq, counts)
mock_file.make_continuous_channel("Info wave", "Info wave", np.int64(20e9), freq, infowave)
return mock_file.file


@pytest.fixture(autouse=True, scope="module", params=[False, True])
def cache_setting(request):
from lumicks.pylake.detail import caching

old_value = caching.global_cache
try:
caching.set_cache_enabled(request.param)
yield
finally:
caching.set_cache_enabled(old_value)