channel: add optional global cache
JoepVanlier committed Oct 4, 2024
1 parent 96551bc commit b14ffc8
Showing 6 changed files with 188 additions and 37 deletions.
1 change: 1 addition & 0 deletions lumicks/pylake/__init__.py
@@ -1,5 +1,6 @@
from . import simulation
from .file import *
from .channel import set_cache_enabled
from .scalebar import ScaleBar
from .__about__ import (
__doc__,
148 changes: 114 additions & 34 deletions lumicks/pylake/channel.py
@@ -5,12 +5,80 @@

import numpy as np
import numpy.typing as npt
from cachetools import LRUCache, cached

from .detail.plotting import _annotate
from .detail.timeindex import to_seconds, to_timestamp
from .detail.utilities import downsample
from .nb_widgets.range_selector import SliceRangeSelectorWidget

global_cache = False


def set_cache_enabled(enabled):
"""Enable or disable the global cache
Pylake offers a global cache. When the global cache is enabled, all `Slice` objects come from
the same cache.
Parameters
----------
enabled : bool
Whether the caching should be enabled (by default it is off)
"""
global global_cache
global_cache = enabled


@cached(LRUCache(maxsize=1 << 30, getsizeof=lambda x: x.nbytes), info=True) # 1 GB of cache
def _get_array(cache_object):
return cache_object.read_array()

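For context, a minimal usage sketch of the new switch (``example.h5`` is a hypothetical file path; `_get_array` is the private cached reader defined above):

import lumicks.pylake as lk
from lumicks.pylake.channel import _get_array

lk.set_cache_enabled(True)  # route all Slice reads through the shared LRU cache

# Opening the same channel twice yields one cached, read-only array
force_a = lk.File("example.h5").force1x
force_b = lk.File("example.h5").force1x
assert force_a.data is force_b.data
print(_get_array.cache_info())  # hits, misses and current size in bytes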

class LazyCache:
def __init__(self, location, dset, nbytes):
"""A lazy globally cached wrapper around an object that is convertible to a numpy array"""
self._location = location
self._dset = dset
self._nbytes = nbytes

def __len__(self):
return len(self._dset)

@property
def nbytes(self):
return self._nbytes

def __hash__(self):
return hash(self._location)

@staticmethod
def from_h5py_dset(dset, field=None):
location = f"{dset.file.filename}{dset.name}"
if field:
location = f"{location}.{field}"
dset = dset.fields(field)
item_size = dset.read_dtype.itemsize
else:
item_size = dset.dtype.itemsize

return LazyCache(location, dset, nbytes=item_size * len(dset))

def read_array(self):
        # Note: we deliberately do _not_ allow additional arguments to `asarray`, since we
        # would have to include them in the cache key and, unless strictly necessary, they
        # would needlessly grow the cache (e.g. when callers defensively pass an explicit
        # dtype). It's better to raise in that case and end up at this comment.
arr = np.asarray(self._dset)
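        # The same array instance is handed to every consumer of this cache entry,
        # so mark it read-only to prevent one caller mutating another's data.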
arr.flags.writeable = False
return arr

def __eq__(self, other):
return self._location == other._location

def __array__(self):
return _get_array(self)

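The cache is keyed purely on `LazyCache.__hash__` and `__eq__`, i.e. on the file path plus dataset (and optional field) name. A toy sketch of the `cachetools` mechanism this relies on (illustrative names, not pylake code):

from cachetools import LRUCache, cached

class Key:
    # Stand-in for LazyCache: hash/equality on a location string define the cache key
    def __init__(self, location):
        self._location = location

    def __hash__(self):
        return hash(self._location)

    def __eq__(self, other):
        return self._location == other._location

@cached(LRUCache(maxsize=128), info=True)
def load(key):
    return list(range(10))  # stand-in for an expensive disk read

assert load(Key("file.h5/Force 1x")) is load(Key("file.h5/Force 1x"))
print(load.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=128, currsize=1)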

class Slice:
"""A lazily evaluated slice of a timeline/HDF5 channel
@@ -693,7 +761,7 @@ def from_dataset(dset, y_label="y", calibration=None):
start = dset.attrs["Start time (ns)"]
dt = int(1e9 / dset.attrs["Sample rate (Hz)"]) # ns
return Slice(
Continuous(dset, start, dt),
Continuous(LazyCache.from_h5py_dset(dset) if global_cache else dset, start, dt),
labels={"title": dset.name.strip("/"), "y": y_label},
calibration=calibration,
)
@@ -719,9 +787,12 @@ def to_dataset(self, parent, name, **kwargs):

@property
def data(self) -> npt.ArrayLike:
if self._cached_data is None:
self._cached_data = np.asarray(self._src_data)
return self._cached_data
if global_cache:
return np.asarray(self._src_data) # Reads from cache if it exists
else:
if self._cached_data is None:
self._cached_data = np.asarray(self._src_data)
return self._cached_data

@property
def timestamps(self) -> npt.ArrayLike:
@@ -796,32 +867,14 @@ def _apply_mask(self, mask):

@staticmethod
def from_dataset(dset, y_label="y", calibration=None) -> Slice:
class LazyLoadedCompoundField:
"""Wrapper to enable lazy loading of HDF5 compound datasets
Notes
-----
We only need to support the methods `__array__()` and `__len__()`, as we only access
`LazyLoadedCompoundField` via the properties `TimeSeries.data`, `timestamps` and the
method `__len__()`.
`LazyLoadCompoundField` might be replaced with `dset.fields(fieldname)` if and when the
returned `FieldsWrapper` object provides an `__array__()` method itself"""

def __init__(self, dset, fieldname):
self._dset = dset
self._fieldname = fieldname

def __array__(self):
"""Get the data of the field as an array"""
return self._dset[self._fieldname]

def __len__(self):
"""Get the length of the underlying dataset"""
return len(self._dset)

data = LazyLoadedCompoundField(dset, "Value")
timestamps = LazyLoadedCompoundField(dset, "Timestamp")
data = (
LazyCache.from_h5py_dset(dset, field="Value") if global_cache else dset.fields("Value")
)
timestamps = (
LazyCache.from_h5py_dset(dset, field="Timestamp")
if global_cache
else dset.fields("Timestamp")
)
return Slice(
TimeSeries(data, timestamps),
labels={"title": dset.name.strip("/"), "y": y_label},
@@ -850,12 +903,18 @@ def to_dataset(self, parent, name, **kwargs):

@property
def data(self) -> npt.ArrayLike:
if global_cache:
return np.asarray(self._src_data)

if self._cached_data is None:
self._cached_data = np.asarray(self._src_data)
return self._cached_data

@property
def timestamps(self) -> npt.ArrayLike:
if global_cache:
return np.asarray(self._src_timestamps)

if self._cached_timestamps is None:
self._cached_timestamps = np.asarray(self._src_timestamps)
return self._cached_timestamps
@@ -907,13 +966,31 @@ class TimeTags:
"""

def __init__(self, data, start=None, stop=None):
self.data = np.asarray(data, dtype=np.int64)
self.start = start if start is not None else (self.data[0] if self.data.size > 0 else 0)
self.stop = stop if stop is not None else (self.data[-1] + 1 if self.data.size > 0 else 0)
self._src_data = data
self._start = start
self._stop = stop

def __len__(self):
return self.data.size

@property
def start(self):
return (
self._start if self._start is not None else (self.data[0] if self.data.size > 0 else 0)
)

@property
def stop(self):
return (
self._stop
if self._stop is not None
else (self.data[-1] + 1 if self.data.size > 0 else 0)
)

@property
def data(self):
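        # Unlike the old eager np.asarray(..., dtype=np.int64) in __init__, this routes
        # the read through __array__, so it can hit the global cache when enabled.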
return np.asarray(self._src_data)

def _with_data(self, data):
raise NotImplementedError("Time tags do not currently support this operation")

@@ -922,7 +999,10 @@ def _apply_mask(self, mask):

@staticmethod
def from_dataset(dset, y_label="y"):
return Slice(TimeTags(dset))
return Slice(
TimeTags(LazyCache.from_h5py_dset(dset) if global_cache else dset),
labels={"title": dset.name.strip("/"), "y": y_label},
)

def to_dataset(self, parent, name, **kwargs):
"""Save this to an h5 dataset
3 changes: 2 additions & 1 deletion lumicks/pylake/file.py
@@ -1,3 +1,4 @@
import pathlib
import warnings
from typing import Dict

@@ -50,7 +51,7 @@ class File(Group, Force, DownsampledFD, BaselineCorrectedForce, PhotonCounts, Ph
def __init__(self, filename, *, rgb_to_detectors=None):
import h5py

super().__init__(h5py.File(filename, "r"), lk_file=self)
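        # Resolving to an absolute path presumably keeps the global cache key, which is
        # derived from dset.file.filename in LazyCache, stable across working directories.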
super().__init__(h5py.File(pathlib.Path(filename).absolute(), "r"), lk_file=self)
self._check_file_format()
self._rgb_to_detectors = self._get_detector_mapping(rgb_to_detectors)

6 changes: 5 additions & 1 deletion lumicks/pylake/tests/test_channels/test_channels.py
@@ -7,6 +7,7 @@
import pytest
import matplotlib as mpl

import lumicks.pylake.channel
from lumicks.pylake import channel
from lumicks.pylake.low_level import make_continuous_slice
from lumicks.pylake.calibration import ForceCalibrationList
@@ -893,7 +894,10 @@ def test_annotation_bad_item():

def test_regression_lazy_loading(channel_h5_file):
ch = channel.Continuous.from_dataset(channel_h5_file["Force HF"]["Force 1x"])
assert isinstance(ch._src._src_data, h5py.Dataset)
if lumicks.pylake.channel.global_cache:
assert isinstance(ch._src._src_data._dset, h5py.Dataset)
else:
assert isinstance(ch._src._src_data, h5py.Dataset)


@pytest.mark.parametrize(
65 changes: 65 additions & 0 deletions lumicks/pylake/tests/test_file/test_caching.py
@@ -0,0 +1,65 @@
import pytest

from lumicks import pylake
from lumicks.pylake.channel import _get_array


def test_global_cache_continuous(h5_file):
pylake.set_cache_enabled(True)
_get_array.cache_clear()

# Load the file (never storing the file handle)
f1x1 = pylake.File.from_h5py(h5_file)["Force HF/Force 1x"]
f1x2 = pylake.File.from_h5py(h5_file).force1x
assert _get_array.cache_info().hits == 0 # No cache used yet (lazy loading)

# These should point to the same data
assert id(f1x1.data) == id(f1x2.data)
assert _get_array.cache_info().hits == 1
assert _get_array.cache_info().currsize == 40

with pytest.raises(ValueError, match="assignment destination is read-only"):
f1x1.data[5:100] = 3

file = pylake.File.from_h5py(h5_file)
assert id(file.force1x.data) == id(file.force1x.data)


def test_global_cache_timeseries(h5_file):
pylake.set_cache_enabled(True)
_get_array.cache_clear()

f1x1 = pylake.File.from_h5py(h5_file).downsampled_force1x
f1x2 = pylake.File.from_h5py(h5_file).downsampled_force1x
assert _get_array.cache_info().hits == 0 # No cache used yet (lazy loading)

# These should point to the same data
assert id(f1x1.data) == id(f1x2.data)
assert _get_array.cache_info().hits == 1
assert _get_array.cache_info().currsize == 16
assert id(f1x1.timestamps) == id(f1x2.timestamps)
assert _get_array.cache_info().hits == 2
assert _get_array.cache_info().currsize == 32

with pytest.raises(ValueError, match="assignment destination is read-only"):
f1x1.data[5:100] = 3

with pytest.raises(ValueError, match="assignment destination is read-only"):
f1x1.timestamps[5:100] = 3


def test_global_cache_timetags(h5_file):
pylake.set_cache_enabled(True)
if pylake.File.from_h5py(h5_file).format_version == 2:
_get_array.cache_clear()
tags1 = pylake.File.from_h5py(h5_file)["Photon Time Tags"]["Red"]
tags2 = pylake.File.from_h5py(h5_file)["Photon Time Tags"]["Red"]
assert _get_array.cache_info().hits == 0

# These should point to the same data
assert id(tags1.data) == id(tags2.data)
assert _get_array.cache_info().hits == 1
assert _get_array.cache_info().currsize == 72

with pytest.raises(ValueError, match="assignment destination is read-only"):
tags1.data[5:100] = 3
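
These tests flip the global switch on without restoring it; a sketch of an autouse fixture (hypothetical, not part of this commit) that would reset the default:

import pytest

from lumicks import pylake

@pytest.fixture(autouse=True)
def _reset_global_cache():
    # Hypothetical teardown: caching is off by default, so restore that after each test
    yield
    pylake.set_cache_enabled(False)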
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -26,7 +26,7 @@ classifiers=[
]
dependencies = [
"pytest>=3.5",
"h5py>=3.4, <4",
"h5py>=3.8, <4", # Minimum bound needed for using __array__ on Dataset.fields()
"numpy>=1.24, <2", # 1.24 is needed for dtype in vstack/hstack (Dec 18th, 2022)
"scipy>=1.9, <2", # 1.9.0 needed for lazy imports (July 29th, 2022)
"matplotlib>=3.8",