Skip to content

Commit

Permalink
estimator in progress
Browse files Browse the repository at this point in the history
  • Loading branch information
vilim committed Jan 23, 2022
1 parent 3439e78 commit eb18b54
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 50 deletions.
33 changes: 24 additions & 9 deletions stytra/collectors/accumulators.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from PyQt5.QtCore import QObject, pyqtSignal
import datetime
import numpy as np
Expand All @@ -7,14 +9,15 @@
from bisect import bisect_right
from os.path import basename

from stytra.collectors.namedtuplequeue import NamedTupleQueue
from stytra.utilities import save_df


class Accumulator(QObject):
def __init__(self, experiment, name="", max_history_if_not_running=1000):
super().__init__()
self.name = name
self.exp = experiment
#self.exp = experiment
self.stored_data = []
self.times = []
self.max_history_if_not_running = max_history_if_not_running
Expand Down Expand Up @@ -76,20 +79,23 @@ def __getitem__(self, item):
def t(self):
return np.array(self.times)

def values_at_abs_time(self, time):
def values_at_abs_time(self, time, t0):
"""Finds the values in the accumulator closest to the datetime time
Parameters
----------
time : datetime
time to search for
t0:
reference time 0
Returns
-------
namedtuple of values
"""
find_time = (time - self.exp.t0).total_seconds()
find_time = (time - t0).total_seconds()
i = bisect_right(self.times, find_time)
return self.stored_data[i - 1]

Expand Down Expand Up @@ -239,31 +245,40 @@ class QueueDataAccumulator(DataFrameAccumulator):
Parameters
----------
data_queue : (multiprocessing.Queue object)
data_queue : NamedTupleQueue
queue from witch to retrieve data.
output_queue:Optional[NamedTupleQueue]
an optinal queue to forward the data to
header_list : list of str
headers for the data to stored.
Returns
-------
"""

def __init__(self, data_queue, **kwargs):
def __init__(
self,
data_queue: NamedTupleQueue,
output_queue: Optional[NamedTupleQueue] = None,
**kwargs
):
""" """
super().__init__(**kwargs)

# Store externally the starting time make us free to keep
# only time differences in milliseconds in the list (faster)
self.starting_time = None
self.data_queue = data_queue
self.output_queue = output_queue

def update_list(self):
"""Upon calling put all available data into a list."""
while True:
try:
# Get data from queue:
t, data = self.data_queue.get(timeout=0.001)

if self.output_queue is not None:
self.output_queue.put(t, data)

newtype = False
if len(self.stored_data) == 0 or type(data) != type(
self.stored_data[-1]
Expand Down Expand Up @@ -313,7 +328,7 @@ def __init__(self, *args, queue, **kwargs):
super().__init__(*args, **kwargs)
self.queue = queue

def update_list(self):
def update_list(self, fps):
while True:
try:
# Get data from queue:
Expand Down
2 changes: 1 addition & 1 deletion stytra/examples/custom_tracking_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def retrieve_image(self):
# To match tracked points and frame displayed looks for matching
# timestamps of the displayed frame and of tracked queue:
retrieved_data = self.experiment.acc_tracking.values_at_abs_time(
self.current_frame_time
self.current_frame_time, self.experiment.t0
)

# Check for valid data to be displayed:
Expand Down
54 changes: 23 additions & 31 deletions stytra/experiments/tracking_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
EstimatorLog,
FramerateQueueAccumulator,
)
from stytra.stimulation.estimator_process import EstimatorProcess
from stytra.tracking.tracking_process import TrackingProcess
from stytra.tracking.pipelines import Pipeline
from stytra.collectors.namedtuplequeue import NamedTupleQueue
Expand Down Expand Up @@ -191,9 +192,7 @@ class TrackingExperiment(CameraVisualExperiment):
"""

def __init__(
self, *args, tracking, recording=None, second_output_queue=None, **kwargs
):
def __init__(self, *args, tracking, recording=None, second_output_queue=None, **kwargs):
"""
:param tracking_method: class with the parameters for tracking (instance
of TrackingMethod class, defined in the child);
Expand All @@ -210,14 +209,10 @@ def __init__(
super().__init__(*args, **kwargs)
self.arguments.update(locals())

self.recording_event = (
Event() if (recording is not None or recording is False) else None
)
self.recording_event = Event() if (recording is not None or recording is False) else None

self.pipeline_cls = (
pipeline_dict.get(tracking["method"], None)
if isinstance(tracking["method"], str)
else tracking["method"]
pipeline_dict.get(tracking["method"], None) if isinstance(tracking["method"], str) else tracking["method"]
)

self.frame_dispatcher = TrackingProcess(
Expand All @@ -237,20 +232,6 @@ def __init__(
assert isinstance(self.pipeline, Pipeline)
self.pipeline.setup(tree=self.dc)

self.acc_tracking = QueueDataAccumulator(
name="tracking",
experiment=self,
data_queue=self.tracking_output_queue,
monitored_headers=self.pipeline.headers_to_plot,
)
self.acc_tracking.sig_acc_init.connect(self.refresh_plots)

# Data accumulator is updated with GUI timer:
self.gui_timer.timeout.connect(self.acc_tracking.update_list)

# Tracking is reset at experiment start:
self.protocol_runner.sig_protocol_started.connect(self.acc_tracking.reset)

# start frame dispatcher process:
self.frame_dispatcher.start()

Expand All @@ -263,15 +244,28 @@ def __init__(
est = est_type

if est is not None:
self.estimator_process = EstimatorProcess(est_type, self.tracking_output_queue, self.finished_sig)
self.estimator_log = EstimatorLog(experiment=self)
self.estimator = est(
self.acc_tracking,
experiment=self,
**tracking.get("estimator_params", {})
)
self.estimator = est(self.acc_tracking, experiment=self, **tracking.get("estimator_params", {}))
self.estimator_log.sig_acc_init.connect(self.refresh_plots)
tracking_output_queue = self.estimator_process.tracking_output_queue
else:
self.estimator = None
tracking_output_queue = self.tracking_output_queue

self.acc_tracking = QueueDataAccumulator(
name="tracking",
experiment=self,
data_queue=tracking_output_queue,
monitored_headers=self.pipeline.headers_to_plot,
)
self.acc_tracking.sig_acc_init.connect(self.refresh_plots)

# Data accumulator is updated with GUI timer:
self.gui_timer.timeout.connect(self.acc_tracking.update_list)

# Tracking is reset at experiment start:
self.protocol_runner.sig_protocol_started.connect(self.acc_tracking.reset)

self.acc_tracking_framerate = FramerateQueueAccumulator(
self,
Expand Down Expand Up @@ -376,9 +370,7 @@ def end_protocol(self, save=True):
def save_data(self):
"""Save tail position and dynamic parameters and terminate."""

self.window_main.camera_display.save_image(
name=self.filename_base() + "img.png"
)
self.window_main.camera_display.save_image(name=self.filename_base() + "img.png")
self.dc.add_static_data(self.filename_prefix() + "img.png", "tracking/image")

# Save log and estimators:
Expand Down
6 changes: 3 additions & 3 deletions stytra/gui/camera_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def retrieve_image(self):
# To match tracked points and frame displayed looks for matching
# timestamps from the two different queues:
retrieved_data = self.experiment.acc_tracking.values_at_abs_time(
self.current_frame_time
self.current_frame_time, self.experiment.t0
)
# Check for data to be displayed:
# Retrieve tail angles from tail
Expand Down Expand Up @@ -442,7 +442,7 @@ def retrieve_image(self):
# To match tracked points and frame displayed looks for matching
# timestamps from the two different queues:
retrieved_data = self.experiment.acc_tracking.values_at_abs_time(
self.current_frame_time
self.current_frame_time, self.experiment.t0
)
# Check for data to be displayed:

Expand Down Expand Up @@ -622,7 +622,7 @@ def retrieve_image(self):
return

current_data = self.experiment.acc_tracking.values_at_abs_time(
self.current_frame_time
self.current_frame_time, self.experiment.t0
)

n_fish = self.tracking_params.n_fish_max
Expand Down
29 changes: 29 additions & 0 deletions stytra/stimulation/estimator_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from multiprocessing import Event, Process
from typing import Type

from stytra.collectors import QueueDataAccumulator
from stytra.collectors.namedtuplequeue import NamedTupleQueue
from stytra.stimulation.estimators import Estimator


class EstimatorProcess(Process):
def __init__(
self,
estimator_cls: Type[Estimator],
tracking_queue: NamedTupleQueue,
finished_signal: Event,
):
super().__init__()
self.tracking_queue = tracking_queue
self.tracking_output_queue = NamedTupleQueue()
self.estimator_queue = NamedTupleQueue()
self.tracking_accumulator = QueueDataAccumulator(self.tracking_queue, self.tracking_output_queue)
self.finished_signal = finished_signal
self.estimator_cls = estimator_cls

def run(self):
estimator = self.estimator_cls(self.tracking_accumulator, self.estimator_queue)

while not self.finished_signal.is_set():
self.tracking_accumulator.update_list()
estimator.update()
10 changes: 5 additions & 5 deletions stytra/stimulation/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ class Estimator:
stream of the tracking pipelines (position in pixels, tail angles, etc.).
"""

def __init__(self, acc_tracking: QueueDataAccumulator, experiment):
self.exp = experiment
def __init__(self, acc_tracking: QueueDataAccumulator, output_queue: NamedTupleQueue, cam_to_proj=None):
self.acc_tracking = acc_tracking
self.output_queue = NamedTupleQueue()
self.output_queue = output_queue
self.cam_to_proj = cam_to_proj
self._output_type = None

def update(self):
Expand Down Expand Up @@ -184,8 +184,8 @@ def get_position(self) -> Tuple[float, PositionEstimate]:
past_coords = self.acc_tracking.stored_data[-1]
t = self.acc_tracking.times[-1]

if not self.calibrator.cam_to_proj is None:
projmat = np.array(self.calibrator.cam_to_proj)
if not self.cam_to_proj is None:
projmat = np.array(self.cam_to_proj)
if projmat.shape != (2, 3):
projmat = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])

Expand Down
2 changes: 1 addition & 1 deletion stytra/tracking/tracking_process.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from queue import Empty, Full
from multiprocessing import Event, Value
from multiprocessing import Event

from stytra.utilities import FrameProcess
from arrayqueues.shared_arrays import TimestampedArrayQueue
Expand Down

0 comments on commit eb18b54

Please sign in to comment.