Skip to content

Commit

Permalink
progress on estimators
Browse files Browse the repository at this point in the history
  • Loading branch information
vilim committed Mar 6, 2022
1 parent 6b566f5 commit 3c41e64
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 24 deletions.
44 changes: 22 additions & 22 deletions stytra/collectors/accumulators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,35 @@


class Accumulator(QObject):
def __init__(self, experiment, name="", max_history_if_not_running=1000):
def __init__(self, name="", max_trimmed_len=1000, trim = False):
super().__init__()
self.name = name
#self.exp = experiment
self.stored_data = []
self.times = []
self.max_history_if_not_running = max_history_if_not_running
self.max_trimmed_len = max_trimmed_len
self._trim = trim #

@property
def trim(self) -> bool:
return self._trim

def trim_data(self):
if self.trim and len(self.times) > self.max_trimmed_len * 1.5:
self.times[: -self.max_trimmed_len] = []
self.stored_data[: -self.max_trimmed_len] = []

@property
def t0(self) -> float:
raise NotImplementedError

def is_empty(self) -> bool:
return len(self.stored_data) == 0


class DataFrameAccumulator(Accumulator):
"""Abstract class for accumulating streams of data.
It is use to save or plot in real time data from stimulus logs or
It is used to save or plot in real time data from stimulus logs or
behavior tracking. Data is stored in a list in the stored_data
attribute.
Expand Down Expand Up @@ -134,14 +150,6 @@ def reset(self, monitored_headers=None):

self._header_dict = None

def trim_data(self):
if (
not self.exp.protocol_runner.running
and len(self.times) > self.max_history_if_not_running * 1.5
):
self.times[: -self.max_history_if_not_running] = []
self.stored_data[: -self.max_history_if_not_running] = []

def get_fps(self):
""" """
try:
Expand Down Expand Up @@ -229,9 +237,6 @@ def save(self, path, format="csv"):
saved_filename = save_df(df, path, format)
return basename(saved_filename)

def is_empty(self):
return len(self.stored_data) == 0


class QueueDataAccumulator(DataFrameAccumulator):
"""General class for retrieving data from a Queue.
Expand All @@ -248,9 +253,9 @@ class QueueDataAccumulator(DataFrameAccumulator):
data_queue : NamedTupleQueue
queue from witch to retrieve data.
output_queue:Optional[NamedTupleQueue]
an optinal queue to forward the data to
an optional queue to forward the data to
header_list : list of str
headers for the data to stored.
headers for the data to be stored.
"""

Expand Down Expand Up @@ -307,11 +312,6 @@ def __init__(self, *args, goal_framerate=None, **kwargs):
super().__init__(*args, **kwargs)
self.goal_framerate = goal_framerate

def trim_data(self):
if len(self.times) > self.max_history_if_not_running * 1.5:
self.times[: -self.max_history_if_not_running] = []
self.stored_data[: -self.max_history_if_not_running] = []

def reset(self):
self.times = []
self.stored_data = []
Expand Down
8 changes: 7 additions & 1 deletion stytra/experiments/tracking_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,15 @@ def __init__(self, *args, tracking, recording=None, second_output_queue=None, **
if est is not None:
self.estimator_process = EstimatorProcess(est_type, self.tracking_output_queue, self.finished_sig)
self.estimator_log = EstimatorLog(experiment=self)
self.estimator = est(self.acc_tracking, experiment=self, **tracking.get("estimator_params", {}))
self.estimator = est(self.acc_tracking, experiment=self)
first_est_params = tracking.get("estimator_params", None)
if first_est_params is not None:
self.estimator_process.estimator_parameter_queue.put(first_est_params)

self.estimator_log.sig_acc_init.connect(self.refresh_plots)
tracking_output_queue = self.estimator_process.tracking_output_queue
self.protocol_runner.attach_estimator_queue(self.est)
self.estimator_process.start()
else:
self.estimator = None
tracking_output_queue = self.tracking_output_queue
Expand Down
16 changes: 15 additions & 1 deletion stytra/stimulation/estimator_process.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from multiprocessing import Event, Process
from multiprocessing import Event, Process, Queue
from queue import Empty
from typing import Type

from stytra.collectors import QueueDataAccumulator
Expand All @@ -11,19 +12,32 @@ def __init__(
self,
estimator_cls: Type[Estimator],
tracking_queue: NamedTupleQueue,
estimator_parameter_queue: Queue,
finished_signal: Event,
):
super().__init__()
self.tracking_queue = tracking_queue
self.tracking_output_queue = NamedTupleQueue()
self.estimator_parameter_queue = estimator_parameter_queue
self.estimator_queue = NamedTupleQueue()
self.tracking_accumulator = QueueDataAccumulator(self.tracking_queue, self.tracking_output_queue)
self.finished_signal = finished_signal
self.estimator_cls = estimator_cls


def update_estimator_params(self, estimator):
while True:
try:
param_dict = self.estimator_parameter_queue.get(timeout=0.0001)
estimator.update_params(param_dict)
except Empty:
break


def run(self):
estimator = self.estimator_cls(self.tracking_accumulator, self.estimator_queue)

while not self.finished_signal.is_set():
self.update_estimator_params(estimator)
self.tracking_accumulator.update_list()
estimator.update()
4 changes: 4 additions & 0 deletions stytra/stimulation/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def __init__(self, acc_tracking: QueueDataAccumulator, output_queue: NamedTupleQ
self.cam_to_proj = cam_to_proj
self._output_type = None

def update_params(self, **params):
for key, value in params.items():
setattr(self, key, value)

def update(self):
raise NotImplementedError

Expand Down

0 comments on commit 3c41e64

Please sign in to comment.