Skip to content

Commit

Permalink
kymo: centralize unit formatting in PositionUnit
Browse files Browse the repository at this point in the history
  • Loading branch information
rpauszek committed Feb 3, 2025
1 parent 98f7e2c commit d9ddbff
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 70 deletions.
21 changes: 15 additions & 6 deletions lumicks/pylake/kymo.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def plot(
axes.set_axis_off()

if scale_bar and not image_handle:
scale_bar._attach_scale_bar(axes, 60.0, 1.0, "s", self._calibration.unit_label)
scale_bar._attach_scale_bar(axes, 60.0, 1.0, "s", self._calibration.unit.label)

image = self._get_plot_data(channel, adjustment)

Expand Down Expand Up @@ -410,7 +410,7 @@ def plot(
**{**default_kwargs, **kwargs},
)
axes.set_xlabel("time (s)")
axes.set_ylabel(f"position ({self._calibration.unit_label})")
axes.set_ylabel(f"position ({self._calibration.unit.label})")
if show_title:
axes.set_title(self.name)

Expand Down Expand Up @@ -1107,6 +1107,7 @@ class PositionUnit(Enum):
um = UnitInfo(name="um", label=r"μm")
kbp = UnitInfo(name="kbp", label="kbp")
pixel = UnitInfo(name="pixel", label="pixels")
au = UnitInfo(name="au", label="au")

def __str__(self):
return self.value.name
Expand All @@ -1118,6 +1119,18 @@ def __hash__(self):
def label(self):
return self.value.label

def get_diffusion_labels(self) -> dict:
return {
"unit": f"{self}^2 / s",
"_unit_label": f"{self.label}²/s",
}

def get_squared_labels(self) -> dict:
return {
"unit": f"{self}^2",
"_unit_label": f"{self.label}²",
}


@dataclass(frozen=True)
class PositionCalibration:
Expand All @@ -1141,10 +1154,6 @@ def to_pixels(self, calibrated):
def pixelsize(self):
return np.abs(self.scale)

@property
def unit_label(self):
return self.unit.label

def downsample(self, factor):
return (
self
Expand Down
28 changes: 8 additions & 20 deletions lumicks/pylake/kymotracker/detail/msd_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import numpy as np
import numpy.typing as npt

from ...kymo import PositionUnit


@dataclass(frozen=True)
class DiffusionEstimate:
Expand Down Expand Up @@ -269,9 +271,7 @@ def calculate_msd(frame_idx, position, max_lag):
return frame_lags, msd_estimates


def calculate_ensemble_msd(
line_msds, time_step, unit="au", unit_label="au", min_count=2
) -> EnsembleMSD:
def calculate_ensemble_msd(line_msds, time_step, unit=PositionUnit.au, min_count=2) -> EnsembleMSD:
"""Calculate ensemble MSDs.
Parameters
Expand Down Expand Up @@ -305,9 +305,8 @@ def calculate_ensemble_msd(
variance=variance,
counts=counts,
effective_sample_size=effective_sample_size,
unit=f"{unit}^2",
_time_step=time_step,
_unit_label=f"{unit_label}²",
**unit.get_squared_labels(),
)


Expand Down Expand Up @@ -502,13 +501,7 @@ def fallback(warning_message):


def estimate_diffusion_constant_simple(
frame_idx,
coordinate,
time_step,
max_lag,
method,
unit="au",
unit_label="au",
frame_idx, coordinate, time_step, max_lag, method, unit=PositionUnit.au
):
r"""Estimate diffusion constant
Expand Down Expand Up @@ -616,8 +609,7 @@ def estimate_diffusion_constant_simple(
num_points=len(coordinate),
localization_variance=intercept / 2.0,
method=method,
unit=unit,
_unit_label=unit_label,
**unit.get_diffusion_labels(),
)


Expand Down Expand Up @@ -994,7 +986,6 @@ def estimate_diffusion_cve(
dt,
blur_constant,
unit,
unit_label,
localization_var=None,
var_of_localization_var=None,
) -> DiffusionEstimate:
Expand Down Expand Up @@ -1034,9 +1025,8 @@ def estimate_diffusion_cve(
num_points=len(coordinate),
localization_variance=localization_var,
method="cve",
unit=unit,
_unit_label=unit_label,
variance_of_localization_variance=var_of_localization_var,
**unit.get_diffusion_labels(),
)


Expand Down Expand Up @@ -1159,14 +1149,12 @@ def ensemble_ols(kymotracks, max_lag):
time_step = kymotracks._kymos[0].line_time_seconds
to_time = 1.0 / (2.0 * time_step)

src_calibration = kymotracks._kymos[0]._calibration
return DiffusionEstimate(
value=slope * to_time,
std_err=np.sqrt(var_slope / np.mean(ensemble_msd.effective_sample_size)) * to_time,
num_lags=optimal_lags,
num_points=sum(len(t) for t in kymotracks),
localization_variance=intercept / 2.0,
method="ensemble ols",
unit=f"{src_calibration.unit}^2 / s",
_unit_label=f"{src_calibration.unit_label}²/s",
**kymotracks._calibration.unit.get_diffusion_labels(),
)
27 changes: 11 additions & 16 deletions lumicks/pylake/kymotracker/kymotrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def export_kymotrackgroup_to_csv(
)

time_units = "seconds"
position_units = kymotrack_group._calibration_info.unit
position_units = kymotrack_group._calibration.unit

idx = np.hstack([np.full(len(track), idx) for idx, track in enumerate(kymotrack_group)])
coords_idx = np.hstack([track.coordinate_idx for track in kymotrack_group])
Expand Down Expand Up @@ -610,7 +610,7 @@ def plot_fit(self, node_idx, *, fit_kwargs=None, data_kwargs=None, show_data=Tru
)

plt.plot(*model_fit, **{"color": "C0"} | replace_key_aliases(fit_kwargs or {}, aliases))
plt.xlabel(f"Position [{self._kymo._calibration.unit_label}]")
plt.xlabel(f"Position [{self._kymo._calibration.unit.label}]")
plt.ylabel("Photon counts [#]")

def _check_ends_are_defined(self):
Expand Down Expand Up @@ -844,7 +844,7 @@ def plot(self, *, show_outline=True, show_labels=True, axes=None, **kwargs):
ax.plot(self.seconds, self.position, path_effects=[pe.Normal()], **kwargs)

if show_labels:
ax.set_ylabel(f"position ({self._kymo._calibration.unit_label})")
ax.set_ylabel(f"position ({self._kymo._calibration.unit.label})")
ax.set_xlabel("time (s)")

def msd(self, max_lag=None):
Expand Down Expand Up @@ -913,7 +913,7 @@ def plot_msd(self, max_lag=None, **kwargs):
lag_time, msd = self.msd(max_lag)
plt.plot(lag_time, msd, **kwargs)
plt.xlabel("Lag time [s]")
plt.ylabel(f"Mean Squared Displacement [{self._kymo._calibration.unit_label}²]")
plt.ylabel(f"Mean Squared Displacement [{self._kymo._calibration.unit.label}²]")

def estimate_diffusion(
self,
Expand Down Expand Up @@ -1039,10 +1039,6 @@ def estimate_diffusion(
)

frame_idx, positions = np.array(self.time_idx, dtype=int), np.array(self.position)
unit_labels = {
"unit": f"{self._kymo._calibration.unit}^2 / s",
"unit_label": f"{self._kymo._calibration.unit_label}²/s",
}

if method == "cve":
try:
Expand All @@ -1068,7 +1064,7 @@ def estimate_diffusion(
frame_idx,
positions,
self._line_time_seconds,
**unit_labels,
unit=self._kymo._calibration.unit,
blur_constant=blur,
localization_var=localization_variance,
var_of_localization_var=variance_of_localization_variance,
Expand All @@ -1095,7 +1091,7 @@ def estimate_diffusion(
self._line_time_seconds,
max_lag,
method,
**unit_labels,
unit=self._kymo._calibration.unit,
)


Expand Down Expand Up @@ -1216,7 +1212,7 @@ def _validate_single_linetime_pixelsize(self):
if len(pixel_sizes) == 1
else (
"All source kymographs must have the same pixel sizes, "
f"got {sorted(pixel_sizes)} {self._calibration_info.unit}."
f"got {sorted(pixel_sizes)} {self._calibration.unit}."
)
)

Expand Down Expand Up @@ -1281,7 +1277,7 @@ def _channel(self):
raise RuntimeError("No channel associated with this empty group (no tracks available)")

@property
def _calibration_info(self):
def _calibration(self):
try:
kymo = self._kymos[0]
return kymo._calibration
Expand Down Expand Up @@ -1473,7 +1469,7 @@ def plot(self, *, show_outline=True, show_labels=True, axes=None, **kwargs):
track.plot(show_outline=show_outline, show_labels=False, axes=ax, **kwargs)

if show_labels:
ax.set_ylabel(f"position ({self._calibration_info.unit_label})")
ax.set_ylabel(f"position ({self._calibration.unit.label})")
ax.set_xlabel("time (s)")

def _tracks_in_frame(self, frame_idx):
Expand Down Expand Up @@ -1902,7 +1898,7 @@ def plot_binding_histogram(self, kind, bins=10, **kwargs):
widths = np.diff(edges)
plt.bar(edges[:-1], counts, width=widths, align="edge", **kwargs)
plt.ylabel("Counts")
plt.xlabel(f"Position ({self._calibration_info.unit_label})")
plt.xlabel(f"Position ({self._calibration.unit.label})")

def _histogram_binding_profile(self, n_time_bins, bandwidth, n_position_points, roi=None):
"""Calculate a Kernel Density Estimate (KDE) of binding density along the tether for time bins.
Expand Down Expand Up @@ -2149,6 +2145,5 @@ def ensemble_msd(self, max_lag=None, min_count=2) -> EnsembleMSD:
line_msds=track_msds,
time_step=self._kymos[0].line_time_seconds,
min_count=min_count,
unit=self._calibration_info.unit,
unit_label=self._calibration_info.unit_label,
unit=self._calibration.unit,
)
Loading

0 comments on commit d9ddbff

Please sign in to comment.