diff --git a/lumicks/pylake/detail/caching.py b/lumicks/pylake/detail/caching.py index ca323ec05..85e31892b 100644 --- a/lumicks/pylake/detail/caching.py +++ b/lumicks/pylake/detail/caching.py @@ -1,5 +1,7 @@ +import sys + import numpy as np -from cachetools import LRUCache, cached +from cachetools import LRUCache, keys, cached, cachedmethod global_cache = False @@ -94,3 +96,59 @@ def read_lazy_cache(self, key, src_field): self._cache[key] = np.asarray(src_field) return self._cache[key] + + +def _getsize(x): + return x.nbytes if isinstance(x, np.ndarray) else sys.getsizeof(x) + + +_method_cache = LRUCache(maxsize=1 << 30, getsizeof=_getsize) # 1 GB of cache + + +def method_cache(name): + """A small convenience decorator to incorporate some really basic instance method memoization + + Note: When used on properties, this one should be included _after_ the @property decorator. + Data will be stored in the `_cache` variable of the instance. + + Parameters + ---------- + name : str + Name of the instance method to memo-ize. Suggestion: the instance method name. + + Examples + -------- + :: + + class Test: + def __init__(self): + self._cache = {} + ... + + @property + @method_cache("example_property") + def example_property(self): + return 10 + + @method_cache("example_method") + def example_method(self, arguments): + return 5 + + + test = Test() + test.example_property + test.example_method("hi") + test._cache + # test._cache will now show {('example_property',): 10, ('example_method', 'hi'): 5} + """ + + # cachetools>=5.0.0 passes self as first argument. We don't want to bump the reference count + # by including a reference to the object we're about to store the cache into, so we explicitly + # drop the first argument. Note that for the default key, they do the same in the package, but + # we can't use the default key, since it doesn't hash in the method name. + def key(self, *args, **kwargs): + return keys.hashkey(self._location, name, *args, **kwargs) + + return cachedmethod( + lambda self: _method_cache if global_cache and self._location else self._cache, key=key + ) diff --git a/lumicks/pylake/detail/confocal.py b/lumicks/pylake/detail/confocal.py index 64b3bfe3e..d8e27797e 100644 --- a/lumicks/pylake/detail/confocal.py +++ b/lumicks/pylake/detail/confocal.py @@ -10,8 +10,9 @@ from .image import reconstruct_image, reconstruct_image_sum from .mixin import PhotonCounts, ExcitationLaserPower +from .caching import method_cache from .plotting import parse_color_channel -from .utilities import method_cache, could_sum_overflow +from .utilities import could_sum_overflow from ..adjustments import no_adjustment from .imaging_mixins import TiffExport @@ -208,9 +209,11 @@ class BaseScan(PhotonCounts, ExcitationLaserPower): End point in the relevant info wave. metadata : ScanMetaData Metadata. + location : str | None + Path of the confocal object. """ - def __init__(self, name, file, start, stop, metadata): + def __init__(self, name, file, start, stop, metadata, location): self.start = start self.stop = stop self.name = name @@ -220,6 +223,7 @@ def __init__(self, name, file, start, stop, metadata): self._timestamp_factory = _default_timestamp_factory self._pixelsize_factory = _default_pixelsize_factory self._pixelcount_factory = _default_pixelcount_factory + self._location = location self._cache = {} def _has_default_factories(self): @@ -243,12 +247,13 @@ def from_dataset(cls, h5py_dset, file): start = h5py_dset.attrs["Start time (ns)"] stop = h5py_dset.attrs["Stop time (ns)"] name = h5py_dset.name.split("/")[-1] + location = file.h5.filename + h5py_dset.name try: metadata = ScanMetaData.from_json(h5py_dset[()]) except KeyError: raise KeyError(f"{cls.__name__} '{name}' is missing metadata and cannot be loaded") - return cls(name, file, start, stop, metadata) + return cls(name, file, start, stop, metadata, location) @property def file(self): @@ -269,6 +274,9 @@ def __copy__(self): start=self.start, stop=self.stop, metadata=self._metadata, + # If it has no location, it will be cached only locally. This is safer than implicitly + # caching it under the same location as the parent. + location=None, ) # Preserve custom factories @@ -512,5 +520,4 @@ def get_image(self, channel="rgb") -> np.ndarray: if channel not in ("red", "green", "blue"): return np.stack([self.get_image(color) for color in ("red", "green", "blue")], axis=-1) else: - # Make sure we don't return a reference to our cache return self._image(channel) diff --git a/lumicks/pylake/detail/tests/test_caching.py b/lumicks/pylake/detail/tests/test_caching.py new file mode 100644 index 000000000..3f85eb721 --- /dev/null +++ b/lumicks/pylake/detail/tests/test_caching.py @@ -0,0 +1,73 @@ +import pytest + +from lumicks.pylake.detail import caching + + +@pytest.mark.parametrize( + "location, use_global_cache", + [ + (None, False), + (None, True), + ("test", False), + ("test", True), + ], +) +def test_cache_method(location, use_global_cache): + calls = 0 + + def call(): + nonlocal calls + calls += 1 + + class Test: + def __init__(self, location): + self._cache = {} + self._location = location + + @property + @caching.method_cache("example_property") + def example_property(self): + call() + return 10 + + @caching.method_cache("example_method") + def example_method(self, argument=5): + call() + return argument + + old_cache = caching.global_cache + caching.set_cache_enabled(use_global_cache) + caching._method_cache.clear() + test = Test(location=location) + + cache_location = caching._method_cache if use_global_cache and location else test._cache + + assert len(cache_location) == 0 + assert test.example_property == 10 + assert len(cache_location) == 1 + assert calls == 1 + assert test.example_property == 10 + assert calls == 1 + assert len(cache_location) == 1 + + assert test.example_method() == 5 + assert calls == 2 + assert len(cache_location) == 2 + + assert test.example_method() == 5 + assert calls == 2 + assert len(cache_location) == 2 + + assert test.example_method(6) == 6 + assert calls == 3 + assert len(cache_location) == 3 + + assert test.example_method(6) == 6 + assert calls == 3 + assert len(cache_location) == 3 + + assert test.example_method() == 5 + assert calls == 3 + assert len(cache_location) == 3 + + caching.set_cache_enabled(old_cache) diff --git a/lumicks/pylake/detail/utilities.py b/lumicks/pylake/detail/utilities.py index 5816a7dcb..21b8456b6 100644 --- a/lumicks/pylake/detail/utilities.py +++ b/lumicks/pylake/detail/utilities.py @@ -2,60 +2,6 @@ import contextlib import numpy as np -import cachetools - - -def method_cache(name): - """A small convenience decorator to incorporate some really basic instance method memoization - - Note: When used on properties, this one should be included _after_ the @property decorator. - Data will be stored in the `_cache` variable of the instance. - - Parameters - ---------- - name : str - Name of the instance method to memo-ize. Suggestion: the instance method name. - - Examples - -------- - :: - - class Test: - def __init__(self): - self._cache = {} - ... - - @property - @method_cache("example_property") - def example_property(self): - return 10 - - @method_cache("example_method") - def example_method(self, arguments): - return 5 - - - test = Test() - test.example_property - test.example_method("hi") - test._cache - # test._cache will now show {('example_property',): 10, ('example_method', 'hi'): 5} - """ - if int(cachetools.__version__.split(".")[0]) < 5: - - def key(*args, **kwargs): - return cachetools.keys.hashkey(name, *args, **kwargs) - - else: - # cachetools>=5.0.0 started passing self as first argument. We don't want to bump the - # reference count by including a reference to the object we're about to store the cache - # into, so we explicitly drop the first argument. Note that for the default key, they - # do the same in the package, but we can't use the default key, since it doesn't hash - # in the method name. - def key(_, *args, **kwargs): - return cachetools.keys.hashkey(name, *args, **kwargs) - - return cachetools.cachedmethod(lambda self: self._cache, key=key) def use_docstring_from(copy_func): diff --git a/lumicks/pylake/kymo.py b/lumicks/pylake/kymo.py index 3643056e8..0b9811fc4 100644 --- a/lumicks/pylake/kymo.py +++ b/lumicks/pylake/kymo.py @@ -13,10 +13,10 @@ seek_timestamp_next_line, first_pixel_sample_indices, ) +from .detail.caching import method_cache from .detail.confocal import ScanAxis, ScanMetaData, ConfocalImage from .detail.plotting import get_axes, show_image from .detail.timeindex import to_timestamp -from .detail.utilities import method_cache from .detail.bead_cropping import find_beads_template, find_beads_brightness @@ -83,10 +83,22 @@ class Kymo(ConfocalImage): Coordinate position offset with respect to the original raw data. calibration : PositionCalibration Class defining calibration from microns to desired position units. + location : str | None + Path of the Kymo. """ - def __init__(self, name, file, start, stop, metadata, position_offset=0, calibration=None): - super().__init__(name, file, start, stop, metadata) + def __init__( + self, + name, + file, + start, + stop, + metadata, + location=None, + position_offset=0, + calibration=None, + ): + super().__init__(name, file, start, stop, metadata, location) self._line_time_factory = _default_line_time_factory self._line_timestamp_ranges_factory = _default_line_timestamp_ranges_factory self._position_offset = position_offset diff --git a/lumicks/pylake/low_level/low_level.py b/lumicks/pylake/low_level/low_level.py index 855d68473..56e0c8400 100644 --- a/lumicks/pylake/low_level/low_level.py +++ b/lumicks/pylake/low_level/low_level.py @@ -32,7 +32,9 @@ def create_confocal_object( metadata = ScanMetaData.from_json(json_metadata) file = ConfocalFileProxy(infowave, red_channel, green_channel, blue_channel) confocal_cls = {0: PointScan, 1: Kymo, 2: Scan} - return confocal_cls[metadata.num_axes](name, file, infowave.start, infowave.stop, metadata) + return confocal_cls[metadata.num_axes]( + name, file, infowave.start, infowave.stop, metadata, location=None + ) def make_continuous_slice(data, start, dt, y_label="y", name="") -> Slice: diff --git a/lumicks/pylake/scan.py b/lumicks/pylake/scan.py index 31cce0a6b..e12082030 100644 --- a/lumicks/pylake/scan.py +++ b/lumicks/pylake/scan.py @@ -5,9 +5,9 @@ from .adjustments import colormaps, no_adjustment from .detail.image import make_image_title, reconstruct_num_frames, first_pixel_sample_indices +from .detail.caching import method_cache from .detail.confocal import ConfocalImage from .detail.plotting import get_axes, show_image -from .detail.utilities import method_cache from .detail.imaging_mixins import FrameIndex, VideoExport @@ -26,10 +26,12 @@ class Scan(ConfocalImage, VideoExport, FrameIndex): End point in the relevant info wave. metadata : ScanMetaData Metadata. + location : str | None + Path of the Scan. """ - def __init__(self, name, file, start, stop, metadata): - super().__init__(name, file, start, stop, metadata) + def __init__(self, name, file, start, stop, metadata, location=None): + super().__init__(name, file, start, stop, metadata, location) if self._metadata.num_axes == 1: raise RuntimeError("1D scans are not supported") if self._metadata.num_axes > 2: diff --git a/lumicks/pylake/tests/test_utilities.py b/lumicks/pylake/tests/test_utilities.py index 5b90eae81..afd945e45 100644 --- a/lumicks/pylake/tests/test_utilities.py +++ b/lumicks/pylake/tests/test_utilities.py @@ -1,18 +1,11 @@ import re -import numpy as np import pytest import matplotlib as mpl from numpy.testing import assert_array_equal from lumicks.pylake.detail.confocal import timestamp_mean from lumicks.pylake.detail.utilities import * -from lumicks.pylake.detail.utilities import ( - method_cache, - will_mul_overflow, - could_sum_overflow, - replace_key_aliases, -) def test_first(): @@ -293,48 +286,3 @@ def test_ref_dict_freezing_fail(request, compare_to_reference_dict): ), ): compare_to_reference_dict({"a": 5, "b": 5}, file_name="ref_dict_freezing_None_2") - - -def test_cache_method(): - calls = 0 - - def call(): - nonlocal calls - calls += 1 - - class Test: - def __init__(self): - self._cache = {} - - @property - @method_cache("example_property") - def example_property(self): - call() - return 10 - - @method_cache("example_method") - def example_method(self, argument=5): - call() - return argument - - test = Test() - assert len(test._cache) == 0 - assert test.example_property == 10 - assert len(test._cache) == 1 - assert calls == 1 - assert test.example_property == 10 - assert calls == 1 - assert len(test._cache) == 1 - - assert test.example_method() == 5 - assert calls == 2 - assert len(test._cache) == 2 - assert test.example_method() == 5 - assert calls == 2 - assert len(test._cache) == 2 - assert test.example_method(6) == 6 - assert calls == 3 - assert len(test._cache) == 3 - assert test.example_method(6) == 6 - assert calls == 3 - assert len(test._cache) == 3