From d9ddbffa2dfa4df6b43a8f538a1149f132865822 Mon Sep 17 00:00:00 2001 From: rpauszek Date: Thu, 30 Jan 2025 10:17:40 +0100 Subject: [PATCH] kymo: centralize unit formatting in PositionUnit --- lumicks/pylake/kymo.py | 21 ++++-- .../kymotracker/detail/msd_estimation.py | 28 +++----- lumicks/pylake/kymotracker/kymotrack.py | 27 ++++---- .../tests/test_derived_quantities/test_msd.py | 64 ++++++++++++------- .../pylake/nb_widgets/kymotracker_widgets.py | 10 +-- .../test_kymo_transforms.py | 2 +- 6 files changed, 82 insertions(+), 70 deletions(-) diff --git a/lumicks/pylake/kymo.py b/lumicks/pylake/kymo.py index 35c9aa947..3dbf2ab5b 100644 --- a/lumicks/pylake/kymo.py +++ b/lumicks/pylake/kymo.py @@ -382,7 +382,7 @@ def plot( axes.set_axis_off() if scale_bar and not image_handle: - scale_bar._attach_scale_bar(axes, 60.0, 1.0, "s", self._calibration.unit_label) + scale_bar._attach_scale_bar(axes, 60.0, 1.0, "s", self._calibration.unit.label) image = self._get_plot_data(channel, adjustment) @@ -410,7 +410,7 @@ def plot( **{**default_kwargs, **kwargs}, ) axes.set_xlabel("time (s)") - axes.set_ylabel(f"position ({self._calibration.unit_label})") + axes.set_ylabel(f"position ({self._calibration.unit.label})") if show_title: axes.set_title(self.name) @@ -1107,6 +1107,7 @@ class PositionUnit(Enum): um = UnitInfo(name="um", label=r"μm") kbp = UnitInfo(name="kbp", label="kbp") pixel = UnitInfo(name="pixel", label="pixels") + au = UnitInfo(name="au", label="au") def __str__(self): return self.value.name @@ -1118,6 +1119,18 @@ def __hash__(self): def label(self): return self.value.label + def get_diffusion_labels(self) -> dict: + return { + "unit": f"{self}^2 / s", + "_unit_label": f"{self.label}²/s", + } + + def get_squared_labels(self) -> dict: + return { + "unit": f"{self}^2", + "_unit_label": f"{self.label}²", + } + @dataclass(frozen=True) class PositionCalibration: @@ -1141,10 +1154,6 @@ def to_pixels(self, calibrated): def pixelsize(self): return np.abs(self.scale) - @property - def unit_label(self): - return self.unit.label - def downsample(self, factor): return ( self diff --git a/lumicks/pylake/kymotracker/detail/msd_estimation.py b/lumicks/pylake/kymotracker/detail/msd_estimation.py index 51719d01c..4f81aff52 100644 --- a/lumicks/pylake/kymotracker/detail/msd_estimation.py +++ b/lumicks/pylake/kymotracker/detail/msd_estimation.py @@ -5,6 +5,8 @@ import numpy as np import numpy.typing as npt +from ...kymo import PositionUnit + @dataclass(frozen=True) class DiffusionEstimate: @@ -269,9 +271,7 @@ def calculate_msd(frame_idx, position, max_lag): return frame_lags, msd_estimates -def calculate_ensemble_msd( - line_msds, time_step, unit="au", unit_label="au", min_count=2 -) -> EnsembleMSD: +def calculate_ensemble_msd(line_msds, time_step, unit=PositionUnit.au, min_count=2) -> EnsembleMSD: """Calculate ensemble MSDs. Parameters @@ -305,9 +305,8 @@ def calculate_ensemble_msd( variance=variance, counts=counts, effective_sample_size=effective_sample_size, - unit=f"{unit}^2", _time_step=time_step, - _unit_label=f"{unit_label}²", + **unit.get_squared_labels(), ) @@ -502,13 +501,7 @@ def fallback(warning_message): def estimate_diffusion_constant_simple( - frame_idx, - coordinate, - time_step, - max_lag, - method, - unit="au", - unit_label="au", + frame_idx, coordinate, time_step, max_lag, method, unit=PositionUnit.au ): r"""Estimate diffusion constant @@ -616,8 +609,7 @@ def estimate_diffusion_constant_simple( num_points=len(coordinate), localization_variance=intercept / 2.0, method=method, - unit=unit, - _unit_label=unit_label, + **unit.get_diffusion_labels(), ) @@ -994,7 +986,6 @@ def estimate_diffusion_cve( dt, blur_constant, unit, - unit_label, localization_var=None, var_of_localization_var=None, ) -> DiffusionEstimate: @@ -1034,9 +1025,8 @@ def estimate_diffusion_cve( num_points=len(coordinate), localization_variance=localization_var, method="cve", - unit=unit, - _unit_label=unit_label, variance_of_localization_variance=var_of_localization_var, + **unit.get_diffusion_labels(), ) @@ -1159,7 +1149,6 @@ def ensemble_ols(kymotracks, max_lag): time_step = kymotracks._kymos[0].line_time_seconds to_time = 1.0 / (2.0 * time_step) - src_calibration = kymotracks._kymos[0]._calibration return DiffusionEstimate( value=slope * to_time, std_err=np.sqrt(var_slope / np.mean(ensemble_msd.effective_sample_size)) * to_time, @@ -1167,6 +1156,5 @@ def ensemble_ols(kymotracks, max_lag): num_points=sum(len(t) for t in kymotracks), localization_variance=intercept / 2.0, method="ensemble ols", - unit=f"{src_calibration.unit}^2 / s", - _unit_label=f"{src_calibration.unit_label}²/s", + **kymotracks._calibration.unit.get_diffusion_labels(), ) diff --git a/lumicks/pylake/kymotracker/kymotrack.py b/lumicks/pylake/kymotracker/kymotrack.py index cffc06423..f02241e81 100644 --- a/lumicks/pylake/kymotracker/kymotrack.py +++ b/lumicks/pylake/kymotracker/kymotrack.py @@ -106,7 +106,7 @@ def export_kymotrackgroup_to_csv( ) time_units = "seconds" - position_units = kymotrack_group._calibration_info.unit + position_units = kymotrack_group._calibration.unit idx = np.hstack([np.full(len(track), idx) for idx, track in enumerate(kymotrack_group)]) coords_idx = np.hstack([track.coordinate_idx for track in kymotrack_group]) @@ -610,7 +610,7 @@ def plot_fit(self, node_idx, *, fit_kwargs=None, data_kwargs=None, show_data=Tru ) plt.plot(*model_fit, **{"color": "C0"} | replace_key_aliases(fit_kwargs or {}, aliases)) - plt.xlabel(f"Position [{self._kymo._calibration.unit_label}]") + plt.xlabel(f"Position [{self._kymo._calibration.unit.label}]") plt.ylabel("Photon counts [#]") def _check_ends_are_defined(self): @@ -844,7 +844,7 @@ def plot(self, *, show_outline=True, show_labels=True, axes=None, **kwargs): ax.plot(self.seconds, self.position, path_effects=[pe.Normal()], **kwargs) if show_labels: - ax.set_ylabel(f"position ({self._kymo._calibration.unit_label})") + ax.set_ylabel(f"position ({self._kymo._calibration.unit.label})") ax.set_xlabel("time (s)") def msd(self, max_lag=None): @@ -913,7 +913,7 @@ def plot_msd(self, max_lag=None, **kwargs): lag_time, msd = self.msd(max_lag) plt.plot(lag_time, msd, **kwargs) plt.xlabel("Lag time [s]") - plt.ylabel(f"Mean Squared Displacement [{self._kymo._calibration.unit_label}²]") + plt.ylabel(f"Mean Squared Displacement [{self._kymo._calibration.unit.label}²]") def estimate_diffusion( self, @@ -1039,10 +1039,6 @@ def estimate_diffusion( ) frame_idx, positions = np.array(self.time_idx, dtype=int), np.array(self.position) - unit_labels = { - "unit": f"{self._kymo._calibration.unit}^2 / s", - "unit_label": f"{self._kymo._calibration.unit_label}²/s", - } if method == "cve": try: @@ -1068,7 +1064,7 @@ def estimate_diffusion( frame_idx, positions, self._line_time_seconds, - **unit_labels, + unit=self._kymo._calibration.unit, blur_constant=blur, localization_var=localization_variance, var_of_localization_var=variance_of_localization_variance, @@ -1095,7 +1091,7 @@ def estimate_diffusion( self._line_time_seconds, max_lag, method, - **unit_labels, + unit=self._kymo._calibration.unit, ) @@ -1216,7 +1212,7 @@ def _validate_single_linetime_pixelsize(self): if len(pixel_sizes) == 1 else ( "All source kymographs must have the same pixel sizes, " - f"got {sorted(pixel_sizes)} {self._calibration_info.unit}." + f"got {sorted(pixel_sizes)} {self._calibration.unit}." ) ) @@ -1281,7 +1277,7 @@ def _channel(self): raise RuntimeError("No channel associated with this empty group (no tracks available)") @property - def _calibration_info(self): + def _calibration(self): try: kymo = self._kymos[0] return kymo._calibration @@ -1473,7 +1469,7 @@ def plot(self, *, show_outline=True, show_labels=True, axes=None, **kwargs): track.plot(show_outline=show_outline, show_labels=False, axes=ax, **kwargs) if show_labels: - ax.set_ylabel(f"position ({self._calibration_info.unit_label})") + ax.set_ylabel(f"position ({self._calibration.unit.label})") ax.set_xlabel("time (s)") def _tracks_in_frame(self, frame_idx): @@ -1902,7 +1898,7 @@ def plot_binding_histogram(self, kind, bins=10, **kwargs): widths = np.diff(edges) plt.bar(edges[:-1], counts, width=widths, align="edge", **kwargs) plt.ylabel("Counts") - plt.xlabel(f"Position ({self._calibration_info.unit_label})") + plt.xlabel(f"Position ({self._calibration.unit.label})") def _histogram_binding_profile(self, n_time_bins, bandwidth, n_position_points, roi=None): """Calculate a Kernel Density Estimate (KDE) of binding density along the tether for time bins. @@ -2149,6 +2145,5 @@ def ensemble_msd(self, max_lag=None, min_count=2) -> EnsembleMSD: line_msds=track_msds, time_step=self._kymos[0].line_time_seconds, min_count=min_count, - unit=self._calibration_info.unit, - unit_label=self._calibration_info.unit_label, + unit=self._calibration.unit, ) diff --git a/lumicks/pylake/kymotracker/tests/test_derived_quantities/test_msd.py b/lumicks/pylake/kymotracker/tests/test_derived_quantities/test_msd.py index deeb3c2bc..2faed269a 100644 --- a/lumicks/pylake/kymotracker/tests/test_derived_quantities/test_msd.py +++ b/lumicks/pylake/kymotracker/tests/test_derived_quantities/test_msd.py @@ -3,6 +3,7 @@ import pytest import matplotlib.pyplot as plt +from lumicks.pylake.kymo import PositionUnit from lumicks.pylake.detail.utilities import temp_seed from lumicks.pylake.simulation.diffusion import _simulate_diffusion_1d from lumicks.pylake.kymotracker.detail.msd_estimation import * @@ -49,11 +50,27 @@ def test_estimate(frame_idx, coordinate, time_step, max_lag, diffusion_const): time_step, max_lag, "ols", - "au", + PositionUnit.au, ) np.testing.assert_allclose(float(diffusion_est), diffusion_const) +def test_bad_unit(): + frame_idx = np.array([1, 2, 3, 4, 5]) + coordinate = np.array([-1.0, 1.0, -1.0, -3.0, -5.0]) + dt = 0.5 + max_lag = 50 + + with pytest.raises( + AttributeError, match="'str' object has no attribute 'get_diffusion_labels'" + ): + estimate_diffusion_constant_simple(frame_idx, coordinate, dt, max_lag, "ols", unit="um") + with pytest.raises( + AttributeError, match="'str' object has no attribute 'get_diffusion_labels'" + ): + estimate_diffusion_cve(frame_idx, coordinate, dt, 0, unit="um") + + def test_maxlag_asserts(): # Max_lag has to be bigger than 2 with pytest.raises(ValueError): @@ -223,7 +240,7 @@ def test_diffusion_estimate_ols( with temp_seed(0): trace = _simulate_diffusion_1d(diffusion, num_points, time_step, obs_noise) diffusion_est = estimate_diffusion_constant_simple( - np.arange(num_points), trace, time_step, max_lag, "ols", "mu^2/s", r"$\mu^2/s$" + np.arange(num_points), trace, time_step, max_lag, "ols", PositionUnit.au ) np.testing.assert_allclose(float(diffusion_est), diff_est) @@ -233,8 +250,8 @@ def test_diffusion_estimate_ols( np.testing.assert_allclose(diffusion_est.std_err, std_err_est) np.testing.assert_allclose(diffusion_est.localization_variance, loc_variance) assert diffusion_est.method == "ols" - assert diffusion_est.unit == "mu^2/s" - assert diffusion_est._unit_label == r"$\mu^2/s$" + assert diffusion_est.unit == "au^2 / s" + assert diffusion_est._unit_label == "au²/s" @pytest.mark.parametrize( @@ -268,7 +285,7 @@ def test_regression_ols_with_skipped_frames( with pytest.warns(RuntimeWarning, match="Your tracks have missing frames"): diffusion_est = estimate_diffusion_constant_simple( - frame_idx, trace, time_step, max_lag, "ols", "mu^2/s", r"$\mu^2/s$" + frame_idx, trace, time_step, max_lag, "ols", PositionUnit.au ) np.testing.assert_allclose(float(diffusion_est), diff_est) @@ -333,7 +350,7 @@ def test_diffusion_estimate_gls( with temp_seed(0): trace = _simulate_diffusion_1d(diffusion, num_points, time_step, obs_noise) diffusion_est = estimate_diffusion_constant_simple( - np.arange(num_points), trace, time_step, max_lag, "gls", "mu^2/s", r"$\mu^2/s$" + np.arange(num_points), trace, time_step, max_lag, "gls", PositionUnit.au ) np.testing.assert_allclose(float(diffusion_est), diff_est) @@ -343,25 +360,27 @@ def test_diffusion_estimate_gls( np.testing.assert_allclose(diffusion_est.std_err, std_err_est) np.testing.assert_allclose(diffusion_est.localization_variance, loc_variance) assert diffusion_est.method == "gls" - assert diffusion_est.unit == "mu^2/s" - assert diffusion_est._unit_label == r"$\mu^2/s$" + assert diffusion_est.unit == "au^2 / s" + assert diffusion_est._unit_label == "au²/s" def test_bad_input(): with pytest.raises(ValueError, match="Invalid method selected."): - estimate_diffusion_constant_simple(np.arange(5), np.arange(5), 1, 2, "glo", "unit") + estimate_diffusion_constant_simple(np.arange(5), np.arange(5), 1, 2, "glo", PositionUnit.au) with pytest.raises( ValueError, match="You need at least two lags to estimate a diffusion constant" ): - estimate_diffusion_constant_simple(np.arange(5), np.arange(5), 1, 1, "gls", "unit") + estimate_diffusion_constant_simple(np.arange(5), np.arange(5), 1, 1, "gls", PositionUnit.au) def test_singular_handling(): with temp_seed(0): trace = _simulate_diffusion_1d(0, 30, 3, 0) with pytest.warns(RuntimeWarning, match="Covariance matrix is singular"): - estimate_diffusion_constant_simple(np.arange(len(trace)), trace, 1, 3, "gls", "unit") + estimate_diffusion_constant_simple( + np.arange(len(trace)), trace, 1, 3, "gls", PositionUnit.au + ) @pytest.mark.parametrize( @@ -474,8 +493,7 @@ def test_estimate_diffusion_cve( trace, time_step, blur_constant, - "mu^2/s", - r"$\mu^2/s$", + PositionUnit.au, localization_var, var_of_localization_var, ) @@ -493,8 +511,8 @@ def test_estimate_diffusion_cve( else: assert diffusion_est.variance_of_localization_variance is None assert diffusion_est.method == "cve" - assert diffusion_est.unit == "mu^2/s" - assert diffusion_est._unit_label == r"$\mu^2/s$" + assert diffusion_est.unit == "au^2 / s" + assert diffusion_est._unit_label == "au²/s" @pytest.mark.parametrize( @@ -646,7 +664,7 @@ def test_ensemble_msd(): ] # By default, the single lag rho (5) should be ignored - result = calculate_ensemble_msd(track_msds, 1.0, unit="what_a_unit", unit_label="label_ahoy") + result = calculate_ensemble_msd(track_msds, 1.0, unit=PositionUnit.um) np.testing.assert_allclose(result.lags, frame_diffs) np.testing.assert_allclose(result.msd, frame_diffs**2) num_means = np.array([3, 3, 3, 2]) # number of means contributing to the estimate @@ -655,8 +673,8 @@ def test_ensemble_msd(): # Tracks are equal length, so the effective sample size is just the means that contributed np.testing.assert_allclose(result.effective_sample_size, num_means) np.testing.assert_allclose(result.sem, np.sqrt(0.02 / ((num_means - 1) * num_means))) - assert result.unit == "what_a_unit^2" - assert result._unit_label == "label_ahoy²" + assert result.unit == "um^2" + assert result._unit_label == r"μm²" def test_ensemble_msd_unequal_points(): @@ -676,6 +694,8 @@ def test_ensemble_msd_unequal_points(): # ESS is less than 2 since we used weighting np.testing.assert_allclose(result.effective_sample_size, np.ones(5) * 9 / 5) np.testing.assert_allclose(result.sem, np.ones(5) * np.sqrt(5 / 2)) + assert result.unit == "au^2" + assert result._unit_label == "au²" def test_ensemble_msd_little_data(): @@ -686,23 +706,23 @@ def test_ensemble_msd_little_data(): with pytest.raises( ValueError, match="Need more than one average to compute a weighted variance" ): - calculate_ensemble_msd([trk1, trk1, trk2], 1.0, unit="au", unit_label="au", min_count=0) + calculate_ensemble_msd([trk1, trk1, trk2], 1.0, min_count=0) for msds in ([trk1], []): with pytest.raises( ValueError, match="You need at least two tracks to compute the ensemble MSD" ): - calculate_ensemble_msd(msds, 1.0, unit="au", unit_label="au", min_count=0) + calculate_ensemble_msd(msds, 1.0, min_count=0) def test_ensemble_msd_plot(): """Test whether the plot spins up""" frame_diffs = np.arange(1, 5, 1) trk1 = [frame_diffs, frame_diffs**2, np.arange(len(frame_diffs), 0, -1)] - calculate_ensemble_msd([trk1, trk1, trk1], 1.0, unit="au", unit_label="label_unit").plot() + calculate_ensemble_msd([trk1, trk1, trk1], 1.0, PositionUnit.kbp).plot() axis = plt.gca() lines = axis.lines[0] np.testing.assert_allclose(lines.get_xdata(), frame_diffs) np.testing.assert_allclose(lines.get_ydata(), frame_diffs**2) assert axis.xaxis.get_label().get_text() == "Time [s]" - assert axis.yaxis.get_label().get_text() == "Squared Displacement [label_unit²]" + assert axis.yaxis.get_label().get_text() == "Squared Displacement [kbp²]" diff --git a/lumicks/pylake/nb_widgets/kymotracker_widgets.py b/lumicks/pylake/nb_widgets/kymotracker_widgets.py index 32511220f..2f46530a9 100644 --- a/lumicks/pylake/nb_widgets/kymotracker_widgets.py +++ b/lumicks/pylake/nb_widgets/kymotracker_widgets.py @@ -847,7 +847,7 @@ def _get_default_parameters(kymo, channel): r"track. Larger values will result in a wider range in which points are added to a " r"track.", abridged_name="Search range", - display_unit=kymo._calibration.unit_label, + display_unit=kymo._calibration.unit.label, ), "window": KymotrackerParameter( "Maximum gap", @@ -891,7 +891,7 @@ def _get_default_parameters(kymo, channel): r"roughly the width of the point spread function. Setting it larger rejects more " r"noise, but at the cost of potentially merging tracks that are close together.", abridged_name="Spot size", - display_unit=kymo._calibration.unit_label, + display_unit=kymo._calibration.unit.label, ), "filter_width": KymotrackerParameter( "Width of the Gaussian filter to apply", @@ -908,7 +908,7 @@ def _get_default_parameters(kymo, channel): r"should be chosen to match the point spread function." ), abridged_name="Filter width", - display_unit=kymo._calibration.unit_label, + display_unit=kymo._calibration.unit.label, ), "adjacency_filter": KymotrackerParameter( "Adjacency filter", @@ -936,7 +936,7 @@ def _get_default_parameters(kymo, channel): r"in future scan lines to connect. Points within a certain distance from the expected " r"future position are connected.", abridged_name="Velocity", - display_unit=f"{kymo._calibration.unit_label}/s", + display_unit=f"{kymo._calibration.unit.label}/s", ), "diffusion": KymotrackerParameter( "Diffusion", @@ -948,7 +948,7 @@ def _get_default_parameters(kymo, channel): r"When tracking, the algorithm searches for points in future frames to connect. Points " r"within a certain distance from the expected future position are connected. The " r"diffusion parameter determines how quickly this distance grows over time.", - display_unit=f"{kymo._calibration.unit_label}²/s", + display_unit=f"{kymo._calibration.unit.label}²/s", ), "sigma_cutoff": KymotrackerParameter( "Sigma cutoff", diff --git a/lumicks/pylake/tests/test_imaging_confocal/test_kymo_transforms.py b/lumicks/pylake/tests/test_imaging_confocal/test_kymo_transforms.py index 9bc9f8df3..bee5c0978 100644 --- a/lumicks/pylake/tests/test_imaging_confocal/test_kymo_transforms.py +++ b/lumicks/pylake/tests/test_imaging_confocal/test_kymo_transforms.py @@ -143,7 +143,7 @@ def test_enum_in_calibration(): PositionCalibration("kbp", scale=0.42) c = PositionCalibration(PositionUnit.um, scale=0.42) - assert c.unit_label == PositionUnit.um.label + assert c.unit.label == PositionUnit.um.label def test_coordinate_transforms():