Skip to content

Commit

Permalink
estimator cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
vilim committed Dec 28, 2021
1 parent 77587c6 commit 3439e78
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 170 deletions.
255 changes: 86 additions & 169 deletions stytra/stimulation/estimators.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import numpy as np
import datetime
from collections import namedtuple
from typing import NamedTuple, Optional, Tuple

import numpy as np

from stytra.collectors import QueueDataAccumulator
from stytra.collectors.namedtuplequeue import NamedTupleQueue
from stytra.utilities import reduce_to_pi
from collections import namedtuple


class Estimator:
Expand All @@ -15,11 +18,19 @@ class Estimator:

def __init__(self, acc_tracking: QueueDataAccumulator, experiment):
self.exp = experiment
self.log = experiment.estimator_log
self.acc_tracking = acc_tracking
self.output_queue = NamedTupleQueue()
self._output_type = None

def update(self):
raise NotImplementedError

def reset(self):
self.log.reset()
pass


class VigorEstimate(NamedTuple):
vigor: float


class VigorMotionEstimator(Estimator):
Expand All @@ -34,27 +45,15 @@ def __init__(self, *args, vigor_window=0.050, base_gain=-12, **kwargs):
self.vigor_window = vigor_window
self.last_dt = 1 / 500.0
self.base_gain = base_gain
self._output_type = namedtuple("s", "vigor")

def get_velocity(self, lag=0):
"""
Parameters
----------
lag :
(Default value = 0)
self._output_type = namedtuple("vigor_estimate", ("vigor",))

Returns
-------
"""
def get_vigor(self):
vigor_n_samples = max(int(round(self.vigor_window / self.last_dt)), 2)
n_samples_lag = max(int(round(lag / self.last_dt)), 0)
if not self.acc_tracking.stored_data:
return 0
past_tail_motion = self.acc_tracking.get_last_n(
vigor_n_samples + n_samples_lag
)[0:vigor_n_samples]
past_tail_motion = self.acc_tracking.get_last_n(vigor_n_samples)[
0:vigor_n_samples
]
end_t = past_tail_motion.t.iloc[-1]
start_t = past_tail_motion.t.iloc[0]
new_dt = (end_t - start_t) / vigor_n_samples
Expand All @@ -63,10 +62,15 @@ def get_velocity(self, lag=0):
vigor = np.nanstd(np.array(past_tail_motion.tail_sum))
if np.isnan(vigor):
vigor = 0
return end_t, vigor

def update(self):
end_t, vigor = self.get_vigor()
self.output_queue.put(end_t, self._output_type(vigor))

if len(self.log.times) == 0 or self.log.times[-1] < end_t:
self.log.update_list(end_t, self._output_type(vigor))
return vigor * self.base_gain

class BoutEstimate(NamedTuple):
is_bouting: bool


class BoutsEstimator(VigorMotionEstimator):
Expand All @@ -78,138 +82,59 @@ def __init__(
self.vigor_window = vigor_window
self.min_interbout = min_interbout
self.last_bout_t = None
self._output_type = namedtuple("bouts", ("is_bouting",))

def bout_occured(self):
if self.get_velocity() > self.base_gain * self.bout_threshold:
def update(self):
end_t, vigor = self.get_vigor()
is_bouting = False
if vigor > self.base_gain * self.bout_threshold:
if (
self.last_bout_t is None
or (datetime.datetime.now() - self.last_bout_t).total_seconds()
> self.min_interbout
):
self.last_bout_t = datetime.datetime.now()
return True
return False
is_bouting = True
self.output_queue.put(end_t, self._output_type(is_bouting))


class TailSumEstimator(Estimator):
def __init__(
self,
*args,
vigor_window=0.050,
theta_window=0.07,
base_gain=-30,
bout_threshold=0.05,
min_interbout=0.1,
**kwargs
):
super().__init__(*args, **kwargs)
self.vigor_window = vigor_window
self.theta_window = theta_window
self.last_dt = 1 / 500.0
self.base_gain = base_gain
self._output_type = namedtuple("s", ("vigor", "theta", "bout_on"))
self.bout_threshold = bout_threshold
self.vigor_window = vigor_window
self.min_interbout = min_interbout
self.last_bout_t = None
self.prev_time_on = False
self.bout_onset = 0
self.bout_on = 0
self.theta_provided = True
self.last_bout_t = 0
self.last_vigor = 0
self.last_bout_on = 0

self.tail_th = 0

def bout_occured(self):
if self.bout_on:
if (
self.last_bout_t is None
or (datetime.datetime.now() - self.last_bout_t).total_seconds()
> self.min_interbout
):
self.last_bout_t = datetime.datetime.now()
return True
return False
class EmbeddedBoutEstimate(NamedTuple):
vigor: float
theta: float
bout_on: bool

def get_vel_and_theta(self, lag=0):
"""

Parameters
----------
lag :
(Default value = 0)
class PositionEstimate(NamedTuple):
x: float
y: float
theta: float

Returns
-------

"""
# Vigor (copypasted from VigorEstimator method for simplicity)
vigor_n_samples = max(int(round(self.vigor_window / self.last_dt)), 2)
n_samples_lag = max(int(round(lag / self.last_dt)), 0)
if not self.acc_tracking.stored_data:
return 0, 0, 0
past_tail_motion = self.acc_tracking.get_last_n(
vigor_n_samples + n_samples_lag
)[0:vigor_n_samples]
end_t = past_tail_motion.t.iloc[-1]
start_t = past_tail_motion.t.iloc[0]
new_dt = (end_t - start_t) / vigor_n_samples
if new_dt > 0:
self.last_dt = new_dt
vigor = np.nanstd(np.array(past_tail_motion.tail_sum))
def _propagate_change_above_threshold(
current_estimate: PositionEstimate,
previous_estimate: Optional[PositionEstimate],
thresholds: PositionEstimate,
) -> PositionEstimate:
"""Return updated components of a position if the component changed enough, otherwise return the old component"""
if previous_estimate is None:
return current_estimate

if vigor is not None:
self.bout_on = int(vigor > self.bout_threshold)
else:
self.bout_on = int(self.last_vigor > self.bout_threshold)

if self.bout_onset == 0:
if self.bout_on and not self.last_bout_on:
self.bout_onset = 1
self.bout_start_t = datetime.datetime.now()

else:
self.theta_provided = False
self.bout_onset = 0

if (
not self.theta_provided
): # and (datetime.datetime.now() - self.bout_start_t).total_seconds() > 0.07:
# Tail theta:
th_n_samples = max(int(round(self.theta_window / self.last_dt)), 2)
n_samples_lag = max(int(round(lag / self.last_dt)), 0)

past_tail_motion = self.acc_tracking.get_last_n(
th_n_samples + n_samples_lag
)[0:th_n_samples]
self.tail_th = np.nanmean(
np.array(past_tail_motion.tail_sum) - past_tail_motion.tail_sum.iloc[0]
)
self.theta_provided = True
else:
self.tail_th = self.tail_th * (3 / 4)
rn = np.random.randint(0, 1) / 100
on_ns = self.bout_onset + rn
if len(self.log.times) == 0 or self.log.times[-1] < end_t:
self.log.update_list(end_t, self._output_type(vigor, self.tail_th, on_ns))

if vigor is not None:
self.last_vigor = vigor

self.last_bout_on = self.bout_on

return vigor * self.base_gain, -self.tail_th * 3, self.bout_on


def rot_mat(theta):
"""The rotation matrix for an angle theta"""
return np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
return PositionEstimate(
x=current_estimate.x
if abs(current_estimate.x - previous_estimate.x) > thresholds.x
else previous_estimate.x,
y=current_estimate.x
if abs(current_estimate.y - previous_estimate.y) > thresholds.y
else previous_estimate.y,
theta=current_estimate.x
if abs(reduce_to_pi(current_estimate.theta - previous_estimate.theta))
> thresholds.theta
else previous_estimate.theta,
)


class PositionEstimator(Estimator):
def __init__(self, *args, change_thresholds=None, velocity_window=10, **kwargs):
def __init__(self, *args, change_thresholds:Optional[PositionEstimate]=None, velocity_window:int=10, **kwargs):
"""Uses the projector-to-camera calibration to give fish position in
scree coordinates. If change_thresholds are set, update only the fish
position after there is a big enough change (which prevents small
Expand All @@ -223,14 +148,12 @@ def __init__(self, *args, change_thresholds=None, velocity_window=10, **kwargs):
super().__init__(*args, **kwargs)
self.calibrator = self.exp.calibrator
self.last_location = None
self.past_values = None
self.previous_position = None

self.velocity_window = velocity_window
self.change_thresholds = change_thresholds
if change_thresholds is not None:
self.change_thresholds = np.array(change_thresholds)

self._output_type = namedtuple("f", ["x", "y", "theta"])
self._output_type = PositionEstimate

def get_camera_position(self):
past_coords = {
Expand All @@ -248,17 +171,10 @@ def get_velocity(self):
)
return np.sqrt(np.sum(vel ** 2))

def get_istantaneous_velocity(self):
vel_xy = self.acc_tracking.get_last_n(self.velocity_window)[
["f0_vx", "f0_vy"]
].values
return np.sqrt(np.sum(vel_xy ** 2))

def reset(self):
super().reset()
self.past_values = None
self.previous_position = None

def get_position(self):
def get_position(self) -> Tuple[float, PositionEstimate]:
if len(self.acc_tracking.stored_data) == 0 or not np.isfinite(
self.acc_tracking.stored_data[-1].f0_x
):
Expand Down Expand Up @@ -286,23 +202,21 @@ def get_position(self):
else:
x, y, theta = past_coords.f0_x, past_coords.f0_y, past_coords.f0_theta

c_values = np.array((y, x, theta))
current_position = PositionEstimate(x, y, theta)

if self.change_thresholds is not None:
if self.previous_position is None:
self.previous_position = current_position

if self.past_values is None:
self.past_values = np.array(c_values)
else:
deltas = c_values - self.past_values
deltas[2] = reduce_to_pi(deltas[2])
sel = np.abs(deltas) > self.change_thresholds
self.past_values[sel] = c_values[sel]
c_values = self.past_values
current_position = _propagate_change_above_threshold(
current_position, self.previous_position, self.change_thresholds
)
self.previous_position = current_position

logout = self._output_type(*c_values)
self.log.update_list(t, logout)
return t, current_position

return c_values
def update(self):
self.output_queue.put(*self.get_position())


class SimulatedPositionEstimator(Estimator):
Expand All @@ -319,15 +233,18 @@ def __init__(self, *args, motion, **kwargs):
"""
super().__init__(*args, **kwargs)
self.motion = motion
self._output_type = namedtuple("f", ["x", "y", "theta"])
self._output_type = PositionEstimate

def get_position(self):
def get_position(self) -> Tuple[float, PositionEstimate]:
t = (datetime.datetime.now() - self.exp.t0).total_seconds()

kt = tuple(
np.interp(t, self.motion.t, self.motion[p]) for p in ("y", "x", "theta")
kt = PositionEstimate(
*(np.interp(t, self.motion.t, self.motion[p]) for p in ("x", "y", "theta"))
)
return kt
return t, kt

def update(self):
self.output_queue.put(*self.get_position())


estimator_dict = dict(
Expand Down
2 changes: 1 addition & 1 deletion stytra/stimulation/stimuli/closed_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
self.prev_bout_t = 0

def get_fish_vel(self):
"""Function that update estimated fish velocty. Change to add lag or
"""Function that update estimated fish velocity. Change to add lag or
shunting.
"""
self.fish_vel = self._experiment.estimator.get_velocity()
Expand Down
1 change: 1 addition & 0 deletions stytra/stimulation/stimuli/generic_stimuli.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(self, duration=0.0):
self._elapsed = 0.0 # time from the beginning of the stimulus
self.name = "undefined"
self._experiment = None
self._input_queue = None
self.real_time_start = None
self.real_time_stop = None

Expand Down
5 changes: 5 additions & 0 deletions stytra/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,8 @@ def save_df(df, path, fileformat):
else:
raise (NotImplementedError(fileformat + " is not an implemented log format"))
return outpath.name


def rot_mat(theta):
"""The rotation matrix for an angle theta"""
return np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])

0 comments on commit 3439e78

Please sign in to comment.